361 lines
15 KiB
Python
361 lines
15 KiB
Python
"""Tests for domain/ai/interrupts.py - LangGraph human-in-the-loop interrupts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Final
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.ai.interrupts import (
|
|
DEFAULT_ANNOTATION_OPTIONS,
|
|
DEFAULT_SENSITIVE_OPTIONS,
|
|
DEFAULT_WEB_SEARCH_OPTIONS,
|
|
InterruptAction,
|
|
InterruptConfig,
|
|
InterruptRequest,
|
|
InterruptResponse,
|
|
InterruptType,
|
|
create_annotation_interrupt,
|
|
create_sensitive_action_interrupt,
|
|
create_web_search_interrupt,
|
|
)
|
|
|
|
# Prefix of the error text raised when a frozen field is reassigned; used as the
# ``match=`` pattern for ``pytest.raises(AttributeError, ...)``. Assumes the
# interrupt models are frozen dataclasses (FrozenInstanceError text) — confirm.
FROZEN_ASSIGNMENT_MESSAGE: Final[str] = "cannot assign to field"

# Expected sizes of the default option collections:
# web search / sensitive -> approve + reject; annotation -> approve + reject + modify.
EXPECTED_WEB_OPTIONS_COUNT: Final[int] = 2
EXPECTED_ANNOTATION_OPTIONS_COUNT: Final[int] = 3

# Timeout the permissive config fixture is compared against.
SAMPLE_TIMEOUT: Final[float] = 30.0
|
|
|
|
|
|
class TestInterruptType:
    """Pin the string value of every InterruptType member."""

    def test_web_search_approval_value(self) -> None:
        """WEB_SEARCH_APPROVAL compares equal to ``"web_search_approval"``."""
        expected_text = "web_search_approval"
        assert InterruptType.WEB_SEARCH_APPROVAL == expected_text

    def test_annotation_approval_value(self) -> None:
        """ANNOTATION_APPROVAL compares equal to ``"annotation_approval"``."""
        expected_text = "annotation_approval"
        assert InterruptType.ANNOTATION_APPROVAL == expected_text

    def test_sensitive_action_value(self) -> None:
        """SENSITIVE_ACTION compares equal to ``"sensitive_action"``."""
        expected_text = "sensitive_action"
        assert InterruptType.SENSITIVE_ACTION == expected_text
|
|
|
|
|
|
class TestInterruptAction:
    """Pin the string value of every InterruptAction member."""

    def test_approve_value(self) -> None:
        """APPROVE compares equal to ``"approve"``."""
        approve_text = "approve"
        assert InterruptAction.APPROVE == approve_text

    def test_reject_value(self) -> None:
        """REJECT compares equal to ``"reject"``."""
        reject_text = "reject"
        assert InterruptAction.REJECT == reject_text

    def test_modify_value(self) -> None:
        """MODIFY compares equal to ``"modify"``."""
        modify_text = "modify"
        assert InterruptAction.MODIFY == modify_text
|
|
|
|
|
|
class TestDefaultOptions:
    """Verify the default option collections offered for each interrupt kind."""

    def test_web_search_options_has_approve_and_reject(self) -> None:
        """Web search approval offers approve and reject (and, by count, nothing else)."""
        assert len(DEFAULT_WEB_SEARCH_OPTIONS) == EXPECTED_WEB_OPTIONS_COUNT, (
            "Web search options should include expected number of choices"
        )
        assert InterruptAction.APPROVE.value in DEFAULT_WEB_SEARCH_OPTIONS, (
            "Web search options should include approve"
        )
        assert InterruptAction.REJECT.value in DEFAULT_WEB_SEARCH_OPTIONS, (
            "Web search options should include reject"
        )

    def test_annotation_options_has_approve_reject_modify(self) -> None:
        """Annotation approval offers approve, reject, and modify."""
        assert len(DEFAULT_ANNOTATION_OPTIONS) == EXPECTED_ANNOTATION_OPTIONS_COUNT, (
            "Annotation options should include expected number of choices"
        )
        assert InterruptAction.APPROVE.value in DEFAULT_ANNOTATION_OPTIONS, (
            "Annotation options should include approve"
        )
        assert InterruptAction.REJECT.value in DEFAULT_ANNOTATION_OPTIONS, (
            "Annotation options should include reject"
        )
        # Modify is what distinguishes the annotation set from the two-choice sets,
        # so it must be asserted explicitly rather than inferred from the count.
        assert InterruptAction.MODIFY.value in DEFAULT_ANNOTATION_OPTIONS, (
            "Annotation options should include modify"
        )

    def test_sensitive_options_has_approve_and_reject(self) -> None:
        """Sensitive actions offer approve and reject (and, by count, nothing else)."""
        # The sensitive set has the same approve/reject shape as the web-search set,
        # so the web-search count constant is reused as the expected size.
        assert len(DEFAULT_SENSITIVE_OPTIONS) == EXPECTED_WEB_OPTIONS_COUNT, (
            "Sensitive options should include expected number of choices"
        )
        assert InterruptAction.APPROVE.value in DEFAULT_SENSITIVE_OPTIONS, (
            "Sensitive options should include approve"
        )
        assert InterruptAction.REJECT.value in DEFAULT_SENSITIVE_OPTIONS, (
            "Sensitive options should include reject"
        )
|
|
|
|
|
|
class TestInterruptConfig:
    """Check InterruptConfig values supplied by fixtures and that it is immutable."""

    def test_default_config_values(self, default_interrupt_config: InterruptConfig) -> None:
        """The default fixture disallows ignore/modify and has no timeout."""
        cfg = default_interrupt_config
        assert cfg.allow_ignore is False, "Default allow_ignore should be False"
        assert cfg.allow_modify is False, "Default allow_modify should be False"
        assert cfg.timeout_seconds is None, "Default timeout should be None"

    def test_config_with_all_options(self, permissive_interrupt_config: InterruptConfig) -> None:
        """The permissive fixture allows ignore/modify and carries SAMPLE_TIMEOUT."""
        cfg = permissive_interrupt_config
        assert cfg.allow_ignore is True, "Permissive allow_ignore should be True"
        assert cfg.allow_modify is True, "Permissive allow_modify should be True"
        assert cfg.timeout_seconds == SAMPLE_TIMEOUT, "Timeout should match"

    def test_interrupt_config_is_frozen(self, default_interrupt_config: InterruptConfig) -> None:
        """Reassigning a field raises AttributeError with the frozen-field message."""
        cfg = default_interrupt_config
        with pytest.raises(AttributeError, match=FROZEN_ASSIGNMENT_MESSAGE):
            cfg.allow_ignore = True
|
|
|
|
|
|
class TestInterruptRequest:
    """Exercise InterruptRequest construction, defaults, payload, and immutability."""

    @staticmethod
    def _web_search_request(message: str) -> InterruptRequest:
        """Build a web-search request supplying only the required fields."""
        return InterruptRequest(
            interrupt_type=InterruptType.WEB_SEARCH_APPROVAL,
            message=message,
        )

    def test_request_creation_with_required_fields(self) -> None:
        """Required fields are stored as given."""
        req = self._web_search_request("Allow web search?")
        assert req.interrupt_type == InterruptType.WEB_SEARCH_APPROVAL, (
            "Interrupt type should match"
        )
        assert req.message == "Allow web search?", "Message should match"

    def test_request_default_context_is_empty_dict(self) -> None:
        """Omitting context yields an empty dict."""
        req = self._web_search_request("Test")
        assert req.context == {}

    def test_request_default_options_are_web_search(self) -> None:
        """Omitting options yields the web-search defaults."""
        req = self._web_search_request("Test")
        assert req.options == DEFAULT_WEB_SEARCH_OPTIONS

    def test_request_to_payload_includes_all_fields(
        self,
        web_search_interrupt_request: InterruptRequest,
    ) -> None:
        """The request payload exposes every top-level field key."""
        payload = web_search_interrupt_request.to_request_payload()
        # Ordered mapping of required key -> failure message, checked in sequence.
        required = {
            "interrupt_type": "Payload should include interrupt type",
            "message": "Payload should include message",
            "context": "Payload should include context",
            "options": "Payload should include options",
            "config": "Payload should include config",
            "request_id": "Payload should include request id",
        }
        for key, failure in required.items():
            assert key in payload, failure

    def test_request_payload_config_is_dict(
        self,
        web_search_interrupt_request: InterruptRequest,
    ) -> None:
        """The nested config is serialized as a dict with every config field."""
        config = web_search_interrupt_request.to_request_payload()["config"]
        assert isinstance(config, dict), "Config payload should be a dictionary"
        for field_name in ("allow_ignore", "allow_modify", "timeout_seconds"):
            assert field_name in config, f"Config should include {field_name}"

    def test_interrupt_request_is_frozen(
        self,
        web_search_interrupt_request: InterruptRequest,
    ) -> None:
        """Reassigning a field raises AttributeError with the frozen-field message."""
        req = web_search_interrupt_request
        with pytest.raises(AttributeError, match=FROZEN_ASSIGNMENT_MESSAGE):
            req.message = "Changed"
|
|
|
|
|
|
class TestInterruptResponse:
    """Exercise InterruptResponse action flags, payload shape, and immutability."""

    @pytest.mark.parametrize(
        ("action", "is_approved", "is_rejected", "is_modified"),
        [
            pytest.param(InterruptAction.APPROVE, True, False, False, id="approve"),
            pytest.param(InterruptAction.REJECT, False, True, False, id="reject"),
            pytest.param(InterruptAction.MODIFY, False, False, True, id="modify"),
        ],
    )
    def test_response_action_flags(
        self,
        sample_request_id: str,
        action: InterruptAction,
        is_approved: bool,
        is_rejected: bool,
        is_modified: bool,
    ) -> None:
        """Exactly the flag matching the chosen action is set."""
        resp = InterruptResponse(action=action, request_id=sample_request_id)
        assert resp.action == action, "Response action should match the input action"
        assert resp.is_approved is is_approved, "Approved flag should match expected value"
        assert resp.is_rejected is is_rejected, "Rejected flag should match expected value"
        assert resp.is_modified is is_modified, "Modified flag should match expected value"

    def test_response_modified_value_preserved(
        self,
        modified_response: InterruptResponse,
    ) -> None:
        """The modified_response fixture keeps its modified value, including a query."""
        value = modified_response.modified_value
        assert value is not None, "Modified response should include a modified value"
        assert "query" in value, "Modified response should preserve query in modified value"

    def test_response_to_payload_minimal(self, approved_response: InterruptResponse) -> None:
        """A response without optional fields omits their keys from the payload."""
        body = approved_response.to_response_payload()
        assert body["action"] == InterruptAction.APPROVE, "Payload should include approve action"
        assert "request_id" in body, "Payload should include request id"
        assert "modified_value" not in body, "Payload should omit modified value"
        assert "user_message" not in body, "Payload should omit user message"

    def test_response_to_payload_with_optional_fields(
        self,
        modified_response: InterruptResponse,
    ) -> None:
        """A response with optional fields includes their keys in the payload."""
        body = modified_response.to_response_payload()
        assert body["action"] == InterruptAction.MODIFY, "Payload should include modify action"
        assert "modified_value" in body, "Payload should include modified value"
        assert "user_message" in body, "Payload should include user message"

    def test_response_is_frozen(self, approved_response: InterruptResponse) -> None:
        """Reassigning a field raises AttributeError with the frozen-field message."""
        resp = approved_response
        with pytest.raises(AttributeError, match=FROZEN_ASSIGNMENT_MESSAGE):
            resp.action = InterruptAction.REJECT
|
|
|
|
|
|
class TestCreateWebSearchInterrupt:
    """Verify create_web_search_interrupt builds a correctly shaped request."""

    def test_creates_web_search_type(self, sample_request_id: str) -> None:
        """The factory tags the request as a web-search approval."""
        request = create_web_search_interrupt(
            query="test query",
            request_id=sample_request_id,
        )
        assert request.interrupt_type == InterruptType.WEB_SEARCH_APPROVAL

    def test_query_in_context(self, sample_request_id: str) -> None:
        """The raw query is stored in the request context."""
        query = "What is the capital of France?"
        request = create_web_search_interrupt(
            query=query,
            request_id=sample_request_id,
        )
        assert request.context["query"] == query

    def test_message_contains_query_preview(self, sample_request_id: str) -> None:
        """A short query appears in the human-readable message."""
        query = "test search query"
        request = create_web_search_interrupt(
            query=query,
            request_id=sample_request_id,
        )
        assert query in request.message

    def test_truncates_long_query_in_message(self, sample_request_id: str) -> None:
        """A 150-character query is shortened rather than embedded verbatim."""
        long_query = "x" * 150
        request = create_web_search_interrupt(
            query=long_query,
            request_id=sample_request_id,
        )
        # The length bound alone passes even without truncation whenever the
        # message template is shorter than 100 characters, so also assert the
        # full query is absent from the message.
        assert long_query not in request.message, (
            "Message should not embed the full long query"
        )
        assert len(request.message) < len(long_query) + 100, (
            "Message length should stay bounded"
        )

    def test_default_options_without_modify(self, sample_request_id: str) -> None:
        """Without allow_modify the web-search defaults are offered."""
        request = create_web_search_interrupt(
            query="test",
            request_id=sample_request_id,
        )
        assert request.options == DEFAULT_WEB_SEARCH_OPTIONS

    def test_annotation_options_with_modify(self, sample_request_id: str) -> None:
        """With allow_modify the annotation options (incl. modify) are offered."""
        request = create_web_search_interrupt(
            query="test",
            request_id=sample_request_id,
            allow_modify=True,
        )
        assert request.options == DEFAULT_ANNOTATION_OPTIONS, (
            "Allowing modify should use annotation options"
        )
        assert request.config.allow_modify is True, "Config should allow modify"
|
|
|
|
|
|
class TestCreateAnnotationInterrupt:
    """Verify create_annotation_interrupt builds a correctly shaped request."""

    @staticmethod
    def _request_for(
        annotations: list[dict[str, object]],
        request_id: str,
    ) -> InterruptRequest:
        """Invoke the factory under test with the given annotations."""
        return create_annotation_interrupt(
            annotations=annotations,
            request_id=request_id,
        )

    def test_creates_annotation_type(self, sample_request_id: str) -> None:
        """The factory tags the request as an annotation approval."""
        req = self._request_for([{"text": "test"}], sample_request_id)
        assert req.interrupt_type == InterruptType.ANNOTATION_APPROVAL

    def test_annotations_in_context(self, sample_request_id: str) -> None:
        """The context carries the annotations and their count."""
        items: list[dict[str, object]] = [{"text": "action item"}]
        req = self._request_for(items, sample_request_id)
        assert req.context["annotations"] == items, (
            "Context should include provided annotations"
        )
        assert req.context["count"] == 1, "Context count should match annotations"

    def test_message_contains_count(self, sample_request_id: str) -> None:
        """The message mentions how many annotations await approval."""
        req = self._request_for([{"text": "a"}, {"text": "b"}], sample_request_id)
        assert "2" in req.message

    def test_uses_annotation_options(self, sample_request_id: str) -> None:
        """The annotation defaults are offered."""
        req = self._request_for([{"text": "test"}], sample_request_id)
        assert req.options == DEFAULT_ANNOTATION_OPTIONS

    def test_config_allows_modify_and_ignore(self, sample_request_id: str) -> None:
        """Annotation requests allow both modifying and ignoring."""
        req = self._request_for([{"text": "test"}], sample_request_id)
        assert req.config.allow_modify is True, "Config should allow modify"
        assert req.config.allow_ignore is True, "Config should allow ignore"
|
|
|
|
|
|
class TestCreateSensitiveActionInterrupt:
    """Verify create_sensitive_action_interrupt builds a correctly shaped request."""

    @staticmethod
    def _delete_request(request_id: str) -> InterruptRequest:
        """Build the sample "delete all data" sensitive-action request."""
        return create_sensitive_action_interrupt(
            action_name="delete",
            action_description="Delete all data",
            request_id=request_id,
        )

    def test_creates_sensitive_action_type(self, sample_request_id: str) -> None:
        """The factory tags the request as a sensitive action."""
        req = self._delete_request(sample_request_id)
        assert req.interrupt_type == InterruptType.SENSITIVE_ACTION

    def test_action_details_in_context(self, sample_request_id: str) -> None:
        """Action name and description are stored in the context."""
        ctx = self._delete_request(sample_request_id).context
        assert ctx["action_name"] == "delete", "Context should include action name"
        assert ctx["description"] == "Delete all data", (
            "Context should include action description"
        )

    def test_message_contains_action_name(self, sample_request_id: str) -> None:
        """The message names the action being requested."""
        req = create_sensitive_action_interrupt(
            action_name="export_data",
            action_description="Export sensitive data",
            request_id=sample_request_id,
        )
        assert "export_data" in req.message

    def test_uses_sensitive_options(self, sample_request_id: str) -> None:
        """The sensitive-action defaults are offered."""
        req = self._delete_request(sample_request_id)
        assert req.options == DEFAULT_SENSITIVE_OPTIONS

    def test_config_does_not_allow_ignore(self, sample_request_id: str) -> None:
        """Sensitive actions cannot be ignored."""
        req = self._delete_request(sample_request_id)
        assert req.config.allow_ignore is False
|