702 lines
27 KiB
Python
702 lines
27 KiB
Python
"""Tests for OidcMixin gRPC endpoints.
|
|
|
|
Tests cover:
|
|
- RegisterOidcProvider: provider registration with validation
|
|
- ListOidcProviders: listing with optional filtering
|
|
- GetOidcProvider: single provider retrieval
|
|
- UpdateOidcProvider: provider configuration updates
|
|
- DeleteOidcProvider: provider removal
|
|
- RefreshOidcDiscovery: discovery refresh operations
|
|
- ListOidcPresets: available provider presets
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from typing import TYPE_CHECKING
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.auth.oidc import (
|
|
ClaimMapping,
|
|
OidcDiscoveryConfig,
|
|
OidcProviderConfig,
|
|
OidcProviderPreset,
|
|
)
|
|
from noteflow.grpc.mixins._types import GrpcContext
|
|
from noteflow.grpc.mixins.oidc import OidcMixin
|
|
from noteflow.grpc.proto import noteflow_pb2
|
|
from noteflow.infrastructure.auth.oidc_discovery import OidcDiscoveryError
|
|
|
|
if TYPE_CHECKING:
|
|
from datetime import datetime
|
|
|
|
|
|
class MockServicerHost(OidcMixin):
    """Bare concrete host that lets the OidcMixin RPC handlers run under test.

    At runtime only ``__init__`` is defined here; everything else is inherited
    from ``OidcMixin``. The ``async def`` declarations below sit behind
    ``TYPE_CHECKING`` and are never executed — they only give static type
    checkers precise request/response types for the inherited handlers.
    """

    def __init__(self) -> None:
        """Build the host without creating an OIDC service (it is created lazily)."""

    if TYPE_CHECKING:
        # Signature-only declarations for the type checker; the real
        # implementations come from OidcMixin.

        async def RegisterOidcProvider(
            self,
            request: noteflow_pb2.RegisterOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.OidcProviderProto: ...

        async def ListOidcProviders(
            self,
            request: noteflow_pb2.ListOidcProvidersRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.ListOidcProvidersResponse: ...

        async def GetOidcProvider(
            self,
            request: noteflow_pb2.GetOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.OidcProviderProto: ...

        async def UpdateOidcProvider(
            self,
            request: noteflow_pb2.UpdateOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.OidcProviderProto: ...

        async def DeleteOidcProvider(
            self,
            request: noteflow_pb2.DeleteOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.DeleteOidcProviderResponse: ...

        async def RefreshOidcDiscovery(
            self,
            request: noteflow_pb2.RefreshOidcDiscoveryRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.RefreshOidcDiscoveryResponse: ...

        async def ListOidcPresets(
            self,
            request: noteflow_pb2.ListOidcPresetsRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.ListOidcPresetsResponse: ...
|
|
|
|
|
|
@pytest.fixture
def oidc_servicer() -> MockServicerHost:
    """Provide a fresh OidcMixin host instance for each test."""
    servicer = MockServicerHost()
    return servicer
|
|
|
|
|
|
@pytest.fixture
def sample_provider(sample_datetime: datetime) -> OidcProviderConfig:
    """Build an enabled Authentik-preset provider with populated discovery metadata.

    IDs are random per invocation; both timestamps come from the shared
    ``sample_datetime`` fixture.
    """
    standard_scopes = ("openid", "profile", "email")
    discovery = OidcDiscoveryConfig(
        issuer="https://auth.example.com",
        authorization_endpoint="https://auth.example.com/authorize",
        token_endpoint="https://auth.example.com/token",
        userinfo_endpoint="https://auth.example.com/userinfo",
        jwks_uri="https://auth.example.com/.well-known/jwks.json",
        scopes_supported=standard_scopes,
        claims_supported=("sub", "email", "name"),
    )

    return OidcProviderConfig(
        id=uuid4(),
        workspace_id=uuid4(),
        name="Test Authentik",
        preset=OidcProviderPreset.AUTHENTIK,
        issuer_url="https://auth.example.com",
        client_id="test-client-id",
        enabled=True,
        discovery=discovery,
        claim_mapping=ClaimMapping(),
        scopes=standard_scopes,
        require_email_verified=True,
        allowed_groups=(),
        created_at=sample_datetime,
        updated_at=sample_datetime,
    )
|
|
|
|
|
|
@pytest.fixture
def disabled_provider(
    sample_provider: OidcProviderConfig, sample_datetime: datetime
) -> OidcProviderConfig:
    """Return a switched-off twin of ``sample_provider``.

    All configuration is copied from ``sample_provider`` except ``enabled``,
    which is forced to ``False``; both timestamps use ``sample_datetime``.
    """
    base = sample_provider
    return OidcProviderConfig(
        # The one field that differs from the source provider.
        enabled=False,
        # Identity.
        id=base.id,
        workspace_id=base.workspace_id,
        name=base.name,
        preset=base.preset,
        # Connection settings.
        issuer_url=base.issuer_url,
        client_id=base.client_id,
        discovery=base.discovery,
        # Claim and access policy.
        claim_mapping=base.claim_mapping,
        scopes=base.scopes,
        require_email_verified=base.require_email_verified,
        allowed_groups=base.allowed_groups,
        # Timestamps.
        created_at=sample_datetime,
        updated_at=sample_datetime,
    )
|
|
|
|
|
|
class TestRegisterOidcProvider:
    """RegisterOidcProvider: success path, validation warnings, and abort paths."""

    @staticmethod
    async def _assert_register_aborts(
        servicer: MockServicerHost,
        context: MagicMock,
        request: noteflow_pb2.RegisterOidcProviderRequest,
    ) -> None:
        """Call RegisterOidcProvider expecting the abort path, then check one abort call."""
        with pytest.raises(AssertionError, match="Unreachable"):
            await servicer.RegisterOidcProvider(request, context)
        context.abort.assert_called_once()

    async def test_registers_provider_successfully(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A fully specified request yields the registered provider's proto."""
        service = MagicMock()
        service.register_provider = AsyncMock(return_value=(sample_provider, []))
        service.registry = MagicMock()
        request = noteflow_pb2.RegisterOidcProviderRequest(
            workspace_id=str(sample_provider.workspace_id),
            name="Test Authentik",
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
            preset="authentik",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RegisterOidcProvider(request, mock_grpc_context)

        assert response.id == str(sample_provider.id), "should return provider id"
        assert response.name == "Test Authentik", "should return provider name"
        assert response.preset == "authentik", "should return preset"
        service.register_provider.assert_called_once()

    async def test_returns_warnings_from_validation(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """Warnings produced during registration are surfaced on the response."""
        validation_warnings = ["Scope 'groups' not supported by provider"]
        service = MagicMock()
        service.register_provider = AsyncMock(
            return_value=(sample_provider, validation_warnings)
        )
        service.registry = MagicMock()
        request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RegisterOidcProvider(request, mock_grpc_context)

        assert len(response.warnings) == 1, "should include warnings"
        assert "groups" in response.warnings[0], "should include warning message"

    async def test_rejects_missing_name(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A request without a name is aborted."""
        request = noteflow_pb2.RegisterOidcProviderRequest(
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
        )
        await self._assert_register_aborts(oidc_servicer, mock_grpc_context, request)

    async def test_rejects_missing_issuer_url(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A request without an issuer_url is aborted."""
        request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            client_id="test-client-id",
        )
        await self._assert_register_aborts(oidc_servicer, mock_grpc_context, request)

    async def test_rejects_invalid_issuer_url_scheme(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """An issuer_url with a non-HTTP(S) scheme (here ``ftp``) is aborted."""
        request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="ftp://auth.example.com",
            client_id="test-client-id",
        )
        await self._assert_register_aborts(oidc_servicer, mock_grpc_context, request)

    async def test_rejects_missing_client_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A request without a client_id is aborted."""
        request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="https://auth.example.com",
        )
        await self._assert_register_aborts(oidc_servicer, mock_grpc_context, request)

    async def test_handles_discovery_error(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """An OidcDiscoveryError raised by the service turns into an abort."""
        service = MagicMock()
        service.register_provider = AsyncMock(
            side_effect=OidcDiscoveryError("Connection failed")
        )
        request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            await self._assert_register_aborts(oidc_servicer, mock_grpc_context, request)
|
|
|
|
|
|
class TestListOidcProviders:
    """ListOidcProviders: response shape and filters forwarded to the registry."""

    @staticmethod
    def _service_listing(providers: list[OidcProviderConfig]) -> MagicMock:
        """Return a fake OIDC service whose registry lists exactly ``providers``."""
        service = MagicMock()
        service.registry.list_providers = MagicMock(return_value=providers)
        return service

    async def test_lists_all_providers(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """Every provider returned by the registry appears in the response."""
        service = self._service_listing([sample_provider])

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.ListOidcProviders(
                noteflow_pb2.ListOidcProvidersRequest(), mock_grpc_context
            )

        assert response.total_count == 1, "should return total count"
        assert len(response.providers) == 1, "should return providers list"
        first = response.providers[0]
        assert first.name == "Test Authentik", "should return correct provider name"

    async def test_filters_by_workspace_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A workspace_id in the request is parsed and passed to the registry."""
        workspace_id = uuid4()
        service = self._service_listing([])
        request = noteflow_pb2.ListOidcProvidersRequest(workspace_id=str(workspace_id))

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            await oidc_servicer.ListOidcProviders(request, mock_grpc_context)

        service.registry.list_providers.assert_called_once_with(
            workspace_id=workspace_id,
            enabled_only=False,
        )

    async def test_filters_enabled_only(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """enabled_only=True is forwarded and workspace_id defaults to None."""
        service = self._service_listing([])
        request = noteflow_pb2.ListOidcProvidersRequest(enabled_only=True)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            await oidc_servicer.ListOidcProviders(request, mock_grpc_context)

        service.registry.list_providers.assert_called_once_with(
            workspace_id=None,
            enabled_only=True,
        )

    async def test_returns_empty_list_when_no_providers(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """An empty registry produces an empty, zero-count response."""
        service = self._service_listing([])

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.ListOidcProviders(
                noteflow_pb2.ListOidcProvidersRequest(), mock_grpc_context
            )

        assert response.total_count == 0, "total count should be zero when no providers"
        assert len(response.providers) == 0, "providers list should be empty"
|
|
|
|
|
|
class TestGetOidcProvider:
    """GetOidcProvider: lookup by ID and rejection of malformed IDs."""

    async def test_returns_provider_by_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A known provider ID returns that provider's proto."""
        provider_id = str(sample_provider.id)
        service = MagicMock()
        service.registry.get_provider = MagicMock(return_value=sample_provider)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.GetOidcProvider(
                noteflow_pb2.GetOidcProviderRequest(provider_id=provider_id),
                mock_grpc_context,
            )

        assert response.id == provider_id, "should return correct provider ID"
        assert response.name == "Test Authentik", "should return correct provider name"

    async def test_aborts_on_invalid_provider_id_format(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A provider_id that is not a valid UUID is aborted."""
        malformed = noteflow_pb2.GetOidcProviderRequest(provider_id="not-a-uuid")

        with pytest.raises(AssertionError, match="Unreachable"):
            await oidc_servicer.GetOidcProvider(malformed, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
|
|
|
|
|
|
class TestUpdateOidcProvider:
    """UpdateOidcProvider: partial updates applied to an existing provider."""

    @staticmethod
    def _service_with(provider: OidcProviderConfig) -> MagicMock:
        """Return a fake OIDC service whose registry resolves to ``provider``."""
        service = MagicMock()
        service.registry.get_provider = MagicMock(return_value=provider)
        return service

    async def test_updates_provider_name(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A new name in the request is reflected in the response."""
        service = self._service_with(sample_provider)
        request = noteflow_pb2.UpdateOidcProviderRequest(
            provider_id=str(sample_provider.id),
            name="Updated Name",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.UpdateOidcProvider(request, mock_grpc_context)

        assert response.name == "Updated Name", "should return updated name"

    async def test_updates_provider_scopes(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A replacement scope list is reflected in the response."""
        service = self._service_with(sample_provider)
        request = noteflow_pb2.UpdateOidcProviderRequest(
            provider_id=str(sample_provider.id),
            scopes=["openid", "profile", "email", "groups"],
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.UpdateOidcProvider(request, mock_grpc_context)

        assert "groups" in response.scopes, "should include added scope"

    async def test_enables_provider(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        disabled_provider: OidcProviderConfig,
    ) -> None:
        """enabled=True switches a disabled provider on."""
        service = self._service_with(disabled_provider)
        request = noteflow_pb2.UpdateOidcProviderRequest(
            provider_id=str(disabled_provider.id),
            enabled=True,
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.UpdateOidcProvider(request, mock_grpc_context)

        assert response.enabled is True, "provider should be enabled"
|
|
|
|
|
|
class TestOidcProviderNotFound:
    """Get/Update/Delete must abort when the target provider does not exist."""

    @staticmethod
    def _registry_lookup_misses(service: MagicMock) -> None:
        """Make registry.get_provider report no match."""
        service.registry.get_provider = MagicMock(return_value=None)

    @staticmethod
    def _registry_removal_fails(service: MagicMock) -> None:
        """Make registry.remove_provider report that nothing was removed."""
        service.registry.remove_provider = MagicMock(return_value=False)

    @pytest.mark.parametrize(
        ("method_name", "proto_request", "configure"),
        [
            pytest.param(
                "GetOidcProvider",
                noteflow_pb2.GetOidcProviderRequest(provider_id=str(uuid4())),
                _registry_lookup_misses,
                id="get",
            ),
            pytest.param(
                "UpdateOidcProvider",
                noteflow_pb2.UpdateOidcProviderRequest(
                    provider_id=str(uuid4()),
                    name="New Name",
                ),
                _registry_lookup_misses,
                id="update",
            ),
            pytest.param(
                "DeleteOidcProvider",
                noteflow_pb2.DeleteOidcProviderRequest(provider_id=str(uuid4())),
                _registry_removal_fails,
                id="delete",
            ),
        ],
    )
    async def test_provider_not_found_aborts(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        method_name: str,
        proto_request: object,
        configure: Callable[[MagicMock], None],
    ) -> None:
        """Each RPC aborts exactly once when its registry call finds nothing."""
        service = MagicMock()
        configure(service)
        rpc = getattr(oidc_servicer, method_name)

        with (
            patch.object(oidc_servicer, "get_oidc_service", return_value=service),
            pytest.raises(AssertionError, match="Unreachable"),
        ):
            await rpc(proto_request, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
|
|
|
|
|
|
class TestDeleteOidcProvider:
    """DeleteOidcProvider: successful removal and malformed-ID rejection."""

    async def test_deletes_provider_successfully(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """An existing provider is removed by its parsed UUID and success is reported."""
        provider_id = uuid4()
        service = MagicMock()
        service.registry.remove_provider = MagicMock(return_value=True)
        request = noteflow_pb2.DeleteOidcProviderRequest(provider_id=str(provider_id))

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.DeleteOidcProvider(request, mock_grpc_context)

        assert response.success is True, "delete should return success=True"
        service.registry.remove_provider.assert_called_once_with(provider_id)

    async def test_aborts_on_invalid_provider_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A provider_id that is not a valid UUID is aborted."""
        malformed = noteflow_pb2.DeleteOidcProviderRequest(provider_id="invalid-uuid")

        with pytest.raises(AssertionError, match="Unreachable"):
            await oidc_servicer.DeleteOidcProvider(malformed, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
|
|
|
|
|
|
class TestRefreshOidcDiscovery:
    """RefreshOidcDiscovery: single-provider and refresh-all modes."""

    async def test_refreshes_single_provider(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A successful single refresh is counted and keyed by provider ID."""
        provider_key = str(sample_provider.id)
        service = MagicMock()
        service.registry.get_provider = MagicMock(return_value=sample_provider)
        service.registry.refresh_discovery = AsyncMock()
        request = noteflow_pb2.RefreshOidcDiscoveryRequest(provider_id=provider_key)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RefreshOidcDiscovery(request, mock_grpc_context)

        assert response.success_count == 1, "should report one success"
        assert response.failure_count == 0, "should report no failures"
        assert provider_key in response.results, "provider ID should be in results"

    async def test_reports_single_provider_failure(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A discovery error on a single refresh is counted and its message reported."""
        provider_key = str(sample_provider.id)
        service = MagicMock()
        service.registry.get_provider = MagicMock(return_value=sample_provider)
        service.registry.refresh_discovery = AsyncMock(
            side_effect=OidcDiscoveryError("Network error")
        )
        request = noteflow_pb2.RefreshOidcDiscoveryRequest(provider_id=provider_key)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RefreshOidcDiscovery(request, mock_grpc_context)

        assert response.success_count == 0, "should report no success"
        assert response.failure_count == 1, "should report one failure"
        assert "Network error" in response.results[provider_key], (
            "error message should be in results"
        )

    async def test_refreshes_all_providers(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """Without a provider_id, every provider is refreshed and tallied."""
        ok_id, failing_id = uuid4(), uuid4()
        # None marks a successful refresh; a string carries the failure message.
        outcomes = {ok_id: None, failing_id: "Connection refused"}
        service = MagicMock()
        service.refresh_all_discovery = AsyncMock(return_value=outcomes)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RefreshOidcDiscovery(
                noteflow_pb2.RefreshOidcDiscoveryRequest(), mock_grpc_context
            )

        assert response.success_count == 1, "should count successes"
        assert response.failure_count == 1, "should count failures"
        assert response.results[str(ok_id)] == "", (
            "successful provider should have empty error"
        )
        assert "Connection refused" in response.results[str(failing_id)], (
            "failed provider should have error message"
        )

    async def test_aborts_when_single_provider_not_found(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """Refreshing an unknown provider ID is aborted."""
        service = MagicMock()
        service.registry.get_provider = MagicMock(return_value=None)
        request = noteflow_pb2.RefreshOidcDiscoveryRequest(provider_id=str(uuid4()))

        with (
            patch.object(oidc_servicer, "get_oidc_service", return_value=service),
            pytest.raises(AssertionError, match="Unreachable"),
        ):
            await oidc_servicer.RefreshOidcDiscovery(request, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
|
|
|
|
|
|
class TestListOidcPresets:
    """ListOidcPresets: catalogue of built-in provider presets."""

    async def test_returns_all_presets(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """The well-known presets are present and the catalogue has at least six entries."""
        response = await oidc_servicer.ListOidcPresets(
            noteflow_pb2.ListOidcPresetsRequest(), mock_grpc_context
        )

        available = {entry.preset for entry in response.presets}
        assert "authentik" in available, "should include Authentik preset"
        assert "keycloak" in available, "should include Keycloak preset"
        assert "auth0" in available, "should include Auth0 preset"
        assert len(response.presets) >= 6, "should have at least 6 presets"

    async def test_presets_include_required_fields(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """The Authentik preset carries a display name, description and default scopes."""
        response = await oidc_servicer.ListOidcPresets(
            noteflow_pb2.ListOidcPresetsRequest(), mock_grpc_context
        )

        # Take the first matching entry, mirroring next(...) semantics.
        authentik_preset = None
        for candidate in response.presets:
            if candidate.preset == "authentik":
                authentik_preset = candidate
                break

        assert authentik_preset is not None, "should find Authentik preset"
        assert authentik_preset.display_name, "preset should have display name"
        assert authentik_preset.description, "preset should have description"
        assert len(authentik_preset.default_scopes) > 0, "preset should have scopes"
|