Refactor and Enhance Codebase for Improved Modularity and Type Safety
- Analyzed and documented architectural improvements based on the `repomix-output.md`, focusing on reducing code duplication and enhancing type safety. - Introduced a new `ConditionalCompositeAction` class to support runtime conditional execution of action steps. - Refactored existing action and browser helper methods to improve modularity and maintainability. - Updated GraphQL client integration for better connection pooling and resource management. - Enhanced error handling and diagnostics across various components, ensuring clearer feedback during execution. - Removed outdated playbook actions and streamlined the action registry for better clarity and performance. - Updated configuration files to reflect changes in browser host management and session handling. - Added new tests to validate the refactored components and ensure robust functionality.
This commit is contained in:
@@ -1,15 +1,7 @@
|
||||
hosts:
|
||||
demo-cdp:
|
||||
kind: cdp
|
||||
host: 192.168.50.185
|
||||
port: 9223
|
||||
demo-extension:
|
||||
kind: extension
|
||||
port: 17373
|
||||
support-cdp:
|
||||
kind: cdp
|
||||
host: 192.168.50.108
|
||||
port: 9223
|
||||
support-extension:
|
||||
kind: extension
|
||||
port: 17374
|
||||
@@ -17,5 +9,5 @@ hosts:
|
||||
kind: cdp
|
||||
host: browserless.lab # goes through Traefik
|
||||
port: 80 # Traefik web entrypoint
|
||||
cdp_url: ws://browserless.lab:80/ # explicit endpoint to avoid 0.0.0.0 from /json/version
|
||||
browser: chromium
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from guide.app.actions.auth.login import LoginAsPersonaAction
|
||||
from guide.app.actions.auth.request_otp import RequestOtpAction
|
||||
|
||||
__all__ = ["LoginAsPersonaAction"]
|
||||
__all__ = ["LoginAsPersonaAction", "RequestOtpAction"]
|
||||
|
||||
@@ -1,31 +1,16 @@
|
||||
"""Login action with session persistence support."""
|
||||
|
||||
from typing import ClassVar, override
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.actions.base import DemoAction, register_action
|
||||
from guide.app.auth import (
|
||||
SessionManager,
|
||||
detect_current_persona,
|
||||
login_with_otp_url,
|
||||
)
|
||||
from guide.app import errors
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.models.domain import ActionContext, ActionResult
|
||||
from guide.app.models.personas import PersonaStore
|
||||
|
||||
|
||||
def _extract_base_url(url: str) -> str:
|
||||
"""Extract base URL (scheme + netloc) from a full URL.
|
||||
|
||||
Args:
|
||||
url: Full URL possibly with path, query, fragment.
|
||||
|
||||
Returns:
|
||||
Base URL like 'https://stg.raindrop.com'.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
from guide.app.models.personas import PersonaResolver, PersonaStore
|
||||
|
||||
|
||||
@register_action
|
||||
@@ -46,68 +31,59 @@ class LoginAsPersonaAction(DemoAction):
|
||||
|
||||
_personas: PersonaStore
|
||||
_session_manager: SessionManager
|
||||
_persona_resolver: PersonaResolver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persona_store: PersonaStore,
|
||||
session_manager: SessionManager,
|
||||
persona_resolver: PersonaResolver,
|
||||
) -> None:
|
||||
"""Initialize login action.
|
||||
|
||||
Args:
|
||||
persona_store: Store for looking up personas.
|
||||
session_manager: Manager for session persistence.
|
||||
persona_resolver: Resolver for email-based persona lookup.
|
||||
"""
|
||||
self._personas = persona_store
|
||||
self._session_manager = session_manager
|
||||
self._persona_resolver = persona_resolver
|
||||
|
||||
@override
|
||||
async def run(self, page: PageLike, context: ActionContext) -> ActionResult:
|
||||
"""Execute login action with session handling.
|
||||
|
||||
Flow:
|
||||
1. Try to restore existing session (if not forcing fresh)
|
||||
2. Validate restored session via DOM check
|
||||
3. If invalid, perform fresh login via OTP URL
|
||||
4. Save session after successful login
|
||||
"""
|
||||
if context.persona_id is None:
|
||||
raise errors.PersonaError("persona_id is required for login action")
|
||||
1. Resolve persona from persona_id or email param
|
||||
2. Try to restore existing session (if not forcing fresh)
|
||||
3. Validate restored session via DOM check
|
||||
4. If invalid, perform fresh login via OTP URL
|
||||
5. Save session after successful login
|
||||
|
||||
persona = self._personas.get(context.persona_id)
|
||||
Supports two resolution modes:
|
||||
- Traditional: context.persona_id set (existing behavior)
|
||||
- Email-based: context.params["email"] with persona_id=None (for n8n flows)
|
||||
"""
|
||||
email = context.params.get("email")
|
||||
|
||||
if context.persona_id is not None:
|
||||
persona = self._personas.get(context.persona_id)
|
||||
elif email and isinstance(email, str):
|
||||
persona = self._persona_resolver.resolve_by_email(email)
|
||||
else:
|
||||
raise errors.ActionExecutionError(
|
||||
"Either persona_id or email param is required for login action",
|
||||
details={"provided_params": list(context.params.keys())},
|
||||
)
|
||||
otp_url = context.params.get("url")
|
||||
force_fresh = context.params.get("force_fresh_login", False)
|
||||
|
||||
# 1. Try to restore existing session (if not forcing fresh)
|
||||
if not force_fresh:
|
||||
session = self._session_manager.load_session(persona.id)
|
||||
if session:
|
||||
validation = self._session_manager.validate_offline(session)
|
||||
if validation.is_valid:
|
||||
# Extract base URL from origin (OTP URL contains expired codes)
|
||||
base_url = _extract_base_url(session.origin_url)
|
||||
|
||||
# Navigate to base URL first (required for localStorage to work)
|
||||
_ = await page.goto(base_url)
|
||||
|
||||
# Inject session localStorage
|
||||
await self._session_manager.inject_local_storage(page, session)
|
||||
|
||||
# Navigate again to pick up the injected session
|
||||
_ = await page.goto(base_url)
|
||||
|
||||
current = await detect_current_persona(page)
|
||||
if current and current.lower() == persona.email.lower():
|
||||
return ActionResult(
|
||||
details={
|
||||
"persona_id": persona.id,
|
||||
"status": "session_restored",
|
||||
"remaining_seconds": validation.remaining_seconds,
|
||||
}
|
||||
)
|
||||
|
||||
# Session invalid - invalidate it
|
||||
_ = self._session_manager.invalidate(persona.id)
|
||||
restored = await self._session_manager.restore_session(page, persona)
|
||||
if restored:
|
||||
return restored
|
||||
|
||||
# 2. Perform fresh login via OTP URL
|
||||
if not otp_url or not isinstance(otp_url, str):
|
||||
|
||||
532
src/guide/app/actions/auth/request_otp.py
Normal file
532
src/guide/app/actions/auth/request_otp.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""OTP request action with webhook callback flow.
|
||||
|
||||
Triggers OTP email by interacting with login page, sends webhook to n8n,
|
||||
waits for n8n to find the OTP URL and call back, then completes login.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Literal, cast, override
|
||||
|
||||
import httpx
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.actions.base import DemoAction, register_action
|
||||
from guide.app.auth import (
|
||||
SessionManager,
|
||||
get_otp_callback_store,
|
||||
login_with_otp_url,
|
||||
login_with_verification_code,
|
||||
)
|
||||
from guide.app.browser.helpers import PageHelpers
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import AppSettings
|
||||
from guide.app.models.domain import ActionContext, ActionResult
|
||||
from guide.app.models.personas import DemoPersona, PersonaResolver
|
||||
from guide.app.strings.selectors.login import LoginSelectors
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OtpCredential:
|
||||
"""OTP credential returned by n8n webhook.
|
||||
|
||||
Can be either a magic link URL or a verification code.
|
||||
"""
|
||||
|
||||
type: Literal["url", "code"]
|
||||
value: str
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def trigger_otp_email(
|
||||
page: PageLike,
|
||||
email: str,
|
||||
login_url: str,
|
||||
) -> bool:
|
||||
"""Navigate to login page and trigger OTP email.
|
||||
|
||||
Handles both dropdown and text input modes for email field.
|
||||
|
||||
Args:
|
||||
page: Browser page instance.
|
||||
email: Email address to request OTP for.
|
||||
login_url: Base URL of the application (login page will be at /login).
|
||||
|
||||
Returns:
|
||||
True if OTP was triggered successfully.
|
||||
"""
|
||||
helpers = PageHelpers(page)
|
||||
|
||||
# Navigate to login page
|
||||
full_login_url = f"{login_url.rstrip('/')}/login"
|
||||
_logger.info("Navigating to login page: %s", full_login_url)
|
||||
_ = await page.goto(full_login_url, wait_until="networkidle")
|
||||
_logger.info("Page loaded, current URL: %s", page.url)
|
||||
await helpers.wait_for_stable()
|
||||
|
||||
# Wait for login container with longer timeout
|
||||
_logger.info("Waiting for login container: %s", LoginSelectors.LOGIN_CONTAINER)
|
||||
_ = await page.wait_for_selector(
|
||||
LoginSelectors.LOGIN_CONTAINER, state="visible", timeout=30000
|
||||
)
|
||||
|
||||
# Check if email field is a dropdown or text input
|
||||
email_text_input = page.locator(LoginSelectors.EMAIL_TEXT_INPUT)
|
||||
email_dropdown = page.locator(LoginSelectors.EMAIL_DROPDOWN)
|
||||
|
||||
if await email_dropdown.count() > 0:
|
||||
# Dropdown mode - check if correct email is already selected
|
||||
current_value = await email_dropdown.text_content()
|
||||
if current_value and email.lower() in current_value.lower():
|
||||
_logger.info("Email already selected in dropdown: %s", email)
|
||||
else:
|
||||
# Try to select from dropdown or clear and type
|
||||
_logger.info("Email dropdown detected, attempting to select: %s", email)
|
||||
|
||||
# Click the dropdown to open options
|
||||
await email_dropdown.click()
|
||||
await helpers.wait_for_network_idle()
|
||||
|
||||
# Look for matching option (escape quotes in email for selector safety)
|
||||
escaped_email = email.replace('"', '\\"').replace("'", "\\'")
|
||||
option = page.locator(f'li:has-text("{escaped_email}")')
|
||||
if await option.count() > 0:
|
||||
await option.first.click()
|
||||
_logger.info("Selected email from dropdown")
|
||||
else:
|
||||
# Email not in dropdown - need to clear and type
|
||||
_logger.info(
|
||||
"Email not in dropdown options, clearing to enable text input"
|
||||
)
|
||||
# Close dropdown by clicking elsewhere
|
||||
await page.click(LoginSelectors.LOGIN_CONTAINER)
|
||||
await helpers.wait_for_network_idle()
|
||||
|
||||
# Click clear button to enable text input
|
||||
clear_btn = page.locator(LoginSelectors.EMAIL_CLEAR_BUTTON)
|
||||
if await clear_btn.count() > 0:
|
||||
await clear_btn.click()
|
||||
await helpers.wait_for_network_idle()
|
||||
|
||||
# Now fill the text input
|
||||
await page.fill(LoginSelectors.EMAIL_TEXT_INPUT, email)
|
||||
_logger.info("Filled email in text input after clearing")
|
||||
else:
|
||||
_logger.warning("Could not find clear button to enable text input")
|
||||
return False
|
||||
|
||||
elif await email_text_input.count() > 0:
|
||||
# Text input mode - fill directly
|
||||
_logger.info("Email text input detected, filling: %s", email)
|
||||
await page.fill(LoginSelectors.EMAIL_TEXT_INPUT, email)
|
||||
else:
|
||||
_logger.error("Could not find email field (dropdown or text input)")
|
||||
return False
|
||||
|
||||
# Click login button to trigger OTP
|
||||
_ = await page.wait_for_selector(
|
||||
LoginSelectors.LOGIN_BUTTON, state="visible", timeout=5000
|
||||
)
|
||||
|
||||
_logger.info("Clicking login button to trigger OTP email")
|
||||
await page.click(LoginSelectors.LOGIN_BUTTON)
|
||||
await helpers.wait_for_network_idle()
|
||||
|
||||
# Check for errors
|
||||
error_el = page.locator(LoginSelectors.ERROR_MESSAGE)
|
||||
if await error_el.count() > 0:
|
||||
error_text = await error_el.text_content()
|
||||
if error_text and error_text.strip():
|
||||
_logger.error("Login error: %s", error_text)
|
||||
return False
|
||||
|
||||
_logger.info("OTP email triggered successfully for: %s", email)
|
||||
return True
|
||||
|
||||
|
||||
async def send_otp_webhook(
|
||||
webhook_url: str,
|
||||
correlation_id: str,
|
||||
email: str,
|
||||
callback_url: str,
|
||||
timeout: float = 120.0,
|
||||
) -> OtpCredential | None:
|
||||
"""Send webhook to n8n to notify OTP was requested.
|
||||
|
||||
Supports two response modes:
|
||||
1. Synchronous: n8n returns OTP credential (URL or code) in the HTTP response body
|
||||
2. Async callback: n8n returns empty/ack response, calls back later
|
||||
|
||||
Args:
|
||||
webhook_url: n8n webhook URL.
|
||||
correlation_id: Unique ID to correlate callback.
|
||||
email: Email OTP was requested for.
|
||||
callback_url: URL for n8n to send OTP URL back (for async mode).
|
||||
timeout: Request timeout in seconds (default 120s for sync mode).
|
||||
|
||||
Returns:
|
||||
OtpCredential (url or code) if included in response, None if webhook failed.
|
||||
"""
|
||||
payload = {
|
||||
"event": "otp_requested",
|
||||
"correlation_id": correlation_id,
|
||||
"email": email,
|
||||
"callback_url": callback_url,
|
||||
}
|
||||
|
||||
_logger.info(
|
||||
"Sending OTP webhook to n8n: correlation_id=%s, email=%s",
|
||||
correlation_id,
|
||||
email,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
||||
_logger.info("Sending POST to webhook_url: %s", webhook_url)
|
||||
response = await client.post(webhook_url, json=payload)
|
||||
_logger.info("Webhook response status: %s", response.status_code)
|
||||
_ = response.raise_for_status()
|
||||
_logger.info("OTP webhook sent successfully")
|
||||
|
||||
# Try to extract OTP credential from response body
|
||||
credential = _extract_otp_credential_from_response(response)
|
||||
if credential:
|
||||
_logger.info(
|
||||
"Received OTP %s in webhook response (sync mode)", credential.type
|
||||
)
|
||||
return credential
|
||||
except httpx.HTTPError as exc:
|
||||
_logger.error(
|
||||
"Failed to send OTP webhook: %s (type: %s)", exc, type(exc).__name__
|
||||
)
|
||||
return None
|
||||
except Exception as exc:
|
||||
_logger.error(
|
||||
"Unexpected error sending webhook: %s (type: %s)", exc, type(exc).__name__
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_credential_from_dict(data: dict[str, object]) -> OtpCredential | None:
|
||||
"""Extract OTP credential from a dictionary.
|
||||
|
||||
Priority: URL first (navigate to magic link), then code (fallback).
|
||||
The URL flow will request the code from n8n if needed after navigation.
|
||||
|
||||
Args:
|
||||
data: Dictionary to search for credential fields.
|
||||
|
||||
Returns:
|
||||
OtpCredential if found, None otherwise.
|
||||
"""
|
||||
# Check for URL fields first - navigate to magic link before using code
|
||||
if (otp_url := data.get("otp_url")) and isinstance(otp_url, str):
|
||||
return OtpCredential(type="url", value=otp_url)
|
||||
if (access_url := data.get("access_url")) and isinstance(access_url, str):
|
||||
return OtpCredential(type="url", value=access_url)
|
||||
|
||||
# Fallback: verification code (only if no URL provided)
|
||||
if (code := data.get("verification_code")) and isinstance(code, str):
|
||||
return OtpCredential(type="code", value=code)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_otp_credential_from_response(
|
||||
response: httpx.Response,
|
||||
) -> OtpCredential | None:
|
||||
"""Extract OTP credential (URL or verification code) from n8n webhook response.
|
||||
|
||||
Handles multiple response formats:
|
||||
- {"verification_code": "..."}
|
||||
- {"otp_url": "..."}
|
||||
- {"access_url": "..."}
|
||||
- {"output": {"verification_code": "..."}}
|
||||
- {"output": {"otp_url": "..."}}
|
||||
- {"output": {"access_url": "..."}}
|
||||
|
||||
Returns:
|
||||
OtpCredential with type "code" or "url", or None if not found.
|
||||
"""
|
||||
try:
|
||||
raw: object = cast(object, response.json())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
|
||||
data = cast(dict[str, object], raw)
|
||||
|
||||
# Check top-level fields first
|
||||
if credential := _extract_credential_from_dict(data):
|
||||
return credential
|
||||
|
||||
# Check nested output object (n8n format)
|
||||
output = data.get("output")
|
||||
if isinstance(output, dict):
|
||||
output_dict = cast(dict[str, object], output)
|
||||
return _extract_credential_from_dict(output_dict)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@register_action
|
||||
class RequestOtpAction(DemoAction):
|
||||
"""Request OTP with webhook callback flow and session persistence.
|
||||
|
||||
Complete flow:
|
||||
1. Try to restore existing session from disk
|
||||
2. If no valid session, navigate to login page
|
||||
3. Enter email and click login to trigger OTP
|
||||
4. Send webhook to n8n with correlation_id
|
||||
5. Wait for n8n to call back with OTP URL
|
||||
6. Complete login with OTP URL
|
||||
7. Save session to disk for future requests
|
||||
|
||||
Request params:
|
||||
email: Email address to request OTP for (required)
|
||||
callback_base_url: Base URL for callback endpoint (optional, defaults to localhost)
|
||||
force_fresh_login: Skip session restoration (optional, default: false)
|
||||
|
||||
Requires:
|
||||
- RAINDROP_DEMO_N8N_WEBHOOK_URL environment variable
|
||||
"""
|
||||
|
||||
id: ClassVar[str] = "auth.request_otp"
|
||||
description: ClassVar[str] = (
|
||||
"Request OTP email and wait for callback with magic link."
|
||||
)
|
||||
category: ClassVar[str] = "auth"
|
||||
|
||||
_settings: AppSettings
|
||||
_session_manager: SessionManager
|
||||
_persona_resolver: PersonaResolver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: AppSettings,
|
||||
session_manager: SessionManager,
|
||||
persona_resolver: PersonaResolver,
|
||||
) -> None:
|
||||
"""Initialize with settings and session management.
|
||||
|
||||
Args:
|
||||
settings: Application settings with n8n webhook URL.
|
||||
session_manager: Manager for session persistence.
|
||||
persona_resolver: Resolver for email-based persona lookup.
|
||||
"""
|
||||
self._settings = settings
|
||||
self._session_manager = session_manager
|
||||
self._persona_resolver = persona_resolver
|
||||
|
||||
@override
|
||||
async def run(self, page: PageLike, context: ActionContext) -> ActionResult:
|
||||
"""Execute OTP request flow with session persistence.
|
||||
|
||||
Args:
|
||||
page: Browser page instance.
|
||||
context: Action context with params.
|
||||
|
||||
Returns:
|
||||
ActionResult with login status.
|
||||
|
||||
Raises:
|
||||
ActionExecutionError: If OTP request fails.
|
||||
AuthError: If login fails after receiving OTP URL.
|
||||
"""
|
||||
email = context.params.get("email")
|
||||
if not email or not isinstance(email, str):
|
||||
raise errors.ActionExecutionError(
|
||||
"'email' param is required for OTP request",
|
||||
details={"provided_params": list(context.params.keys())},
|
||||
)
|
||||
|
||||
force_fresh = context.params.get("force_fresh_login", False)
|
||||
|
||||
# Resolve persona from email for session management
|
||||
persona: DemoPersona | None = None
|
||||
try:
|
||||
persona = self._persona_resolver.resolve_by_email(email)
|
||||
except errors.PersonaError:
|
||||
_logger.debug("No persona found for email %s, session caching disabled", email)
|
||||
|
||||
# 0. Try to restore existing session (if not forcing fresh)
|
||||
if persona and not force_fresh:
|
||||
restored = await self._session_manager.restore_session(page, persona)
|
||||
if restored:
|
||||
return restored
|
||||
|
||||
# Continue with OTP flow if session restoration failed
|
||||
webhook_url = self._settings.n8n_webhook_url
|
||||
if not webhook_url:
|
||||
raise errors.ConfigError(
|
||||
"n8n_webhook_url not configured",
|
||||
details={
|
||||
"hint": "Set RAINDROP_DEMO_N8N_WEBHOOK_URL environment variable"
|
||||
},
|
||||
)
|
||||
|
||||
# Callback URL for n8n to send OTP back
|
||||
callback_base = context.params.get(
|
||||
"callback_base_url", self._settings.callback_base_url
|
||||
)
|
||||
callback_url = f"{callback_base}/auth/otp-callback"
|
||||
|
||||
correlation_id = context.correlation_id
|
||||
store = get_otp_callback_store()
|
||||
|
||||
# 1. Trigger OTP email on login page
|
||||
triggered = await trigger_otp_email(
|
||||
page,
|
||||
email,
|
||||
self._settings.raindrop_base_url,
|
||||
)
|
||||
if not triggered:
|
||||
raise errors.ActionExecutionError(
|
||||
f"Failed to trigger OTP email for {email}",
|
||||
details={"email": email},
|
||||
)
|
||||
|
||||
# 1.5. Wait for email to arrive before notifying n8n
|
||||
# This ensures n8n finds the fresh email, not an old one
|
||||
email_delay = self._settings.n8n_otp_email_delay
|
||||
_logger.info("Waiting %.1fs for OTP email to arrive...", email_delay)
|
||||
await asyncio.sleep(email_delay)
|
||||
|
||||
# 2. Send webhook to n8n and check for sync response
|
||||
timeout = self._settings.n8n_otp_callback_timeout
|
||||
credential = await send_otp_webhook(
|
||||
webhook_url,
|
||||
correlation_id,
|
||||
email,
|
||||
callback_url,
|
||||
timeout=float(timeout),
|
||||
)
|
||||
|
||||
# 3. If no credential in response, fall back to async callback
|
||||
if not credential:
|
||||
_logger.info(
|
||||
"No credential in webhook response, waiting for async callback"
|
||||
)
|
||||
_ = await store.register(correlation_id, email)
|
||||
try:
|
||||
otp_url = await store.wait_for_callback(correlation_id, timeout=timeout)
|
||||
credential = OtpCredential(type="url", value=otp_url)
|
||||
except TimeoutError as exc:
|
||||
raise errors.ActionExecutionError(
|
||||
f"Timeout waiting for OTP callback ({timeout}s)",
|
||||
details={"correlation_id": correlation_id, "email": email},
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise errors.ActionExecutionError(
|
||||
f"OTP callback error: {exc}",
|
||||
details={"correlation_id": correlation_id, "email": email},
|
||||
) from exc
|
||||
|
||||
# 4. Complete login based on credential type
|
||||
success = await self._complete_login(
|
||||
page=page,
|
||||
credential=credential,
|
||||
email=email,
|
||||
webhook_url=webhook_url,
|
||||
correlation_id=correlation_id,
|
||||
callback_url=callback_url,
|
||||
timeout=float(timeout),
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise errors.AuthError(
|
||||
f"Login failed for {email}",
|
||||
details={"email": email, "credential_type": credential.type},
|
||||
)
|
||||
|
||||
# 5. Save session after successful login
|
||||
if persona and self._session_manager.auto_persist:
|
||||
_ = await self._session_manager.save_session(
|
||||
page, persona, self._settings.raindrop_base_url
|
||||
)
|
||||
_logger.info("Saved session for persona %s", persona.id)
|
||||
|
||||
return ActionResult(
|
||||
details={
|
||||
"email": email,
|
||||
"status": "logged_in",
|
||||
"correlation_id": correlation_id,
|
||||
}
|
||||
)
|
||||
|
||||
async def _complete_login(
|
||||
self,
|
||||
page: PageLike,
|
||||
credential: OtpCredential,
|
||||
email: str,
|
||||
webhook_url: str,
|
||||
correlation_id: str,
|
||||
callback_url: str,
|
||||
timeout: float,
|
||||
) -> bool:
|
||||
"""Complete login with OTP credential (URL or verification code).
|
||||
|
||||
Handles two-phase flow:
|
||||
- If credential is URL: navigate and login, detect if verification code page appears
|
||||
- If credential is code: fill verification code directly
|
||||
- If URL leads to verification code page: call webhook again for code
|
||||
|
||||
Args:
|
||||
page: Browser page instance.
|
||||
credential: OTP credential (url or code).
|
||||
email: Email address for validation.
|
||||
webhook_url: n8n webhook URL for re-fetch if needed.
|
||||
correlation_id: Request correlation ID.
|
||||
callback_url: Callback URL for async mode.
|
||||
timeout: Timeout for webhook calls.
|
||||
|
||||
Returns:
|
||||
True if login successful, False otherwise.
|
||||
"""
|
||||
if credential.type == "code":
|
||||
# Direct verification code login
|
||||
_logger.info("Using verification code for login: %s", email)
|
||||
return await login_with_verification_code(page, credential.value, email)
|
||||
|
||||
# URL-based login
|
||||
_logger.info("Using OTP URL for login: %s", email)
|
||||
success = await login_with_otp_url(page, credential.value, email)
|
||||
|
||||
if success:
|
||||
return True
|
||||
|
||||
# Check if page is asking for verification code (two-phase flow)
|
||||
code_input = page.locator(LoginSelectors.VERIFICATION_CODE_INPUT)
|
||||
if await code_input.count() > 0:
|
||||
_logger.info(
|
||||
"OTP URL led to verification code page, fetching code from n8n..."
|
||||
)
|
||||
|
||||
# Call webhook again - n8n will return verification code this time
|
||||
code_credential = await send_otp_webhook(
|
||||
webhook_url,
|
||||
correlation_id,
|
||||
email,
|
||||
callback_url,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if code_credential and code_credential.type == "code":
|
||||
_logger.info("Received verification code from n8n, completing login")
|
||||
return await login_with_verification_code(
|
||||
page, code_credential.value, email
|
||||
)
|
||||
|
||||
_logger.error("Failed to get verification code from n8n webhook")
|
||||
return False
|
||||
|
||||
# Login failed for other reasons
|
||||
return False
|
||||
|
||||
|
||||
__all__ = ["RequestOtpAction", "trigger_otp_email", "send_otp_webhook"]
|
||||
@@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from inspect import Parameter, signature
|
||||
from typing import ClassVar, override, cast
|
||||
from typing import ClassVar, cast, override
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.browser.types import PageLike
|
||||
@@ -136,6 +136,103 @@ class CompositeAction(DemoAction):
|
||||
pass # Default: no processing
|
||||
|
||||
|
||||
class ConditionalCompositeAction(CompositeAction):
|
||||
"""Composite action with runtime conditional step execution.
|
||||
|
||||
Extends CompositeAction to support skipping steps based on runtime conditions.
|
||||
Override `should_execute_step()` to implement conditional logic.
|
||||
|
||||
Example:
|
||||
@register_action
|
||||
class ConditionalFlow(ConditionalCompositeAction):
|
||||
id = "conditional-flow"
|
||||
description = "Flow with conditional steps"
|
||||
category = "flows"
|
||||
child_actions = ("step-one", "step-two", "step-three")
|
||||
|
||||
@override
|
||||
async def should_execute_step(
|
||||
self, step_id: str, context: ActionContext
|
||||
) -> bool:
|
||||
if step_id == "step-two":
|
||||
return context.params.get("include_step_two", True)
|
||||
return True
|
||||
"""
|
||||
|
||||
context: ActionContext | None
|
||||
|
||||
async def should_execute_step(self, step_id: str, context: ActionContext) -> bool:
|
||||
"""Determine if a step should execute.
|
||||
|
||||
Override in subclasses to implement conditional logic based on
|
||||
context.params, context.shared_state, or other runtime conditions.
|
||||
|
||||
Args:
|
||||
step_id: The action ID of the step to potentially execute.
|
||||
context: The current action context.
|
||||
|
||||
Returns:
|
||||
True to execute the step, False to skip it.
|
||||
"""
|
||||
_ = step_id, context # Unused in base implementation
|
||||
return True
|
||||
|
||||
@override
|
||||
async def run(self, page: PageLike, context: ActionContext) -> ActionResult:
|
||||
"""Execute child actions conditionally.
|
||||
|
||||
Args:
|
||||
page: The Playwright page instance.
|
||||
context: The action context (shared across all steps).
|
||||
|
||||
Returns:
|
||||
ActionResult with combined status, details, and skipped steps.
|
||||
"""
|
||||
self.context = context
|
||||
details: dict[str, object] = {}
|
||||
skipped: list[str] = []
|
||||
|
||||
for step_id in self.child_actions:
|
||||
if not await self.should_execute_step(step_id, context):
|
||||
skipped.append(step_id)
|
||||
details[step_id] = {"skipped": True}
|
||||
continue
|
||||
|
||||
try:
|
||||
action = self.registry.get(step_id)
|
||||
result = await action.run(page, context)
|
||||
details[step_id] = result.details
|
||||
|
||||
if result.status == "error":
|
||||
return ActionResult(
|
||||
status="error",
|
||||
details={
|
||||
"failed_step": step_id,
|
||||
"steps": details,
|
||||
"skipped": skipped,
|
||||
},
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
await self.on_step_complete(step_id, result)
|
||||
|
||||
except Exception as exc:
|
||||
return ActionResult(
|
||||
status="error",
|
||||
details={
|
||||
"failed_step": step_id,
|
||||
"steps": details,
|
||||
"skipped": skipped,
|
||||
},
|
||||
error=f"Exception in step '{step_id}': {exc}",
|
||||
)
|
||||
|
||||
return ActionResult(
|
||||
status="ok",
|
||||
details={"steps": details, "skipped": skipped},
|
||||
)
|
||||
|
||||
|
||||
class ActionRegistry:
|
||||
"""Manages action instances and metadata.
|
||||
|
||||
@@ -308,6 +405,7 @@ class ActionRegistry:
|
||||
__all__ = [
|
||||
"DemoAction",
|
||||
"CompositeAction",
|
||||
"ConditionalCompositeAction",
|
||||
"ActionRegistry",
|
||||
"register_action",
|
||||
"get_registered_actions",
|
||||
|
||||
@@ -71,7 +71,7 @@ class FillContractFormAction(DemoAction):
|
||||
}})();
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
await page.evaluate(script)
|
||||
_ = await page.evaluate(script)
|
||||
await page.wait_for_timeout(100)
|
||||
|
||||
async def attempt(name: str, coro: Awaitable[dict[str, object] | None | bool], timeout: float = 0.6) -> None:
|
||||
@@ -90,7 +90,7 @@ class FillContractFormAction(DemoAction):
|
||||
base_selector = primary_selector(selector)
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
await page.wait_for_selector(base_selector, timeout=0.2)
|
||||
_ = await page.wait_for_selector(base_selector, timeout=0.2)
|
||||
await page.click(base_selector)
|
||||
await page.fill(base_selector, value)
|
||||
return True
|
||||
@@ -98,7 +98,7 @@ class FillContractFormAction(DemoAction):
|
||||
pass
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
await page.wait_for_selector(f"{base_selector} input, {base_selector} textarea", timeout=0.2)
|
||||
_ = await page.wait_for_selector(f"{base_selector} input, {base_selector} textarea", timeout=0.2)
|
||||
await page.click(f"{base_selector} input, {base_selector} textarea")
|
||||
await page.fill(f"{base_selector} input, {base_selector} textarea", value)
|
||||
return True
|
||||
@@ -128,7 +128,7 @@ class FillContractFormAction(DemoAction):
|
||||
|
||||
if not is_checked:
|
||||
# Click using mouse events for MUI compatibility
|
||||
await click_with_mouse_events(page, selector, focus_first=True)
|
||||
_ = await click_with_mouse_events(page, selector, focus_first=True)
|
||||
await page.wait_for_timeout(100)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -137,7 +137,7 @@ class FillContractFormAction(DemoAction):
|
||||
"""Blur/deselect a field to enable downstream fields."""
|
||||
sel = primary_selector(selector)
|
||||
field_selector_js = sel.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
await page.evaluate(
|
||||
_ = await page.evaluate(
|
||||
f"""
|
||||
(() => {{
|
||||
const root = document.querySelector('{field_selector_js}');
|
||||
@@ -224,7 +224,7 @@ class FillContractFormAction(DemoAction):
|
||||
return
|
||||
await scroll_into_view(sel)
|
||||
esc = value.replace("\\", "\\\\").replace("'", "\\'")
|
||||
await page.evaluate(
|
||||
_ = await page.evaluate(
|
||||
f"""
|
||||
(function(){{
|
||||
const root = document.querySelector('{sel}');
|
||||
@@ -238,7 +238,7 @@ class FillContractFormAction(DemoAction):
|
||||
}})();
|
||||
"""
|
||||
)
|
||||
await wait_for_input_value(sel, value, timeout_ms=1500)
|
||||
_ = await wait_for_input_value(sel, value, timeout_ms=1500)
|
||||
selections[name] = {"values": await read_values(sel)}
|
||||
|
||||
async def wait_for_field_enabled(selector: str, timeout_ms: int = 5000) -> bool:
|
||||
@@ -291,7 +291,7 @@ class FillContractFormAction(DemoAction):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
await page.evaluate(scroll_script)
|
||||
_ = await page.evaluate(scroll_script)
|
||||
await page.wait_for_timeout(200)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
"""Demo action to test Docling UI element extraction."""
|
||||
|
||||
from typing import ClassVar, override
|
||||
|
||||
from guide.app.actions.base import DemoAction, register_action
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import load_settings
|
||||
from guide.app.models.domain.models import ActionContext, ActionResult
|
||||
from guide.app import errors
|
||||
|
||||
@@ -10,14 +13,16 @@ from guide.app import errors
|
||||
class TestDoclingExtractionAction(DemoAction):
|
||||
"""Test action that fails after page loads to capture UI elements via Docling."""
|
||||
|
||||
id = "test-docling-extraction"
|
||||
description = "Fail after page loads to test Docling UI element extraction"
|
||||
category = "demo"
|
||||
id: ClassVar[str] = "test-docling-extraction"
|
||||
description: ClassVar[str] = "Fail after page loads to test Docling UI element extraction"
|
||||
category: ClassVar[str] = "demo"
|
||||
|
||||
@override
|
||||
async def run(self, page: PageLike, context: ActionContext) -> ActionResult:
|
||||
"""Navigate to page, then fail to trigger diagnostics capture."""
|
||||
settings = load_settings()
|
||||
# Page is open and loaded - HTML should be capturable
|
||||
await page.goto("https://stg.raindrop.com")
|
||||
_ = await page.goto(settings.raindrop_base_url)
|
||||
|
||||
# Intentionally fail with a GuideError while page is open
|
||||
# This triggers capture_all_diagnostics with page still available
|
||||
|
||||
205
src/guide/app/actions/diagnose/messaging_selectors.py
Normal file
205
src/guide/app/actions/diagnose/messaging_selectors.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Diagnostic script to validate messaging XPath selectors against live page.
|
||||
|
||||
Run via: python -m guide.app.actions.diagnose.messaging_selectors
|
||||
|
||||
Requires Chrome with Terminator Bridge extension connected to a page
|
||||
with the messaging UI (e.g., board view with chat panel).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from guide.app.browser.extension_client import ExtensionClient, ExtensionPage
|
||||
from guide.app.strings.registry import MessagingStrings
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# XPath selectors to validate
|
||||
MESSAGING_SELECTORS: dict[str, str] = {
|
||||
"notification_indicator": MessagingStrings.notification_indicator,
|
||||
"modal_wrapper": MessagingStrings.modal_wrapper,
|
||||
"modal_close_button": MessagingStrings.modal_close_button,
|
||||
"chat_messages_container": MessagingStrings.chat_messages_container,
|
||||
"chat_flyout_button": MessagingStrings.chat_flyout_button,
|
||||
"chat_conversations_tab": MessagingStrings.chat_conversations_tab,
|
||||
"chat_input": MessagingStrings.chat_input,
|
||||
"send_button": MessagingStrings.send_button,
|
||||
}
|
||||
|
||||
|
||||
async def validate_selector(
|
||||
page: ExtensionPage, name: str, selector: str
|
||||
) -> dict[str, object]:
|
||||
"""Validate a single selector against the page.
|
||||
|
||||
Args:
|
||||
page: ExtensionPage instance
|
||||
name: Friendly name for the selector
|
||||
selector: Playwright selector (xpath= or CSS)
|
||||
|
||||
Returns:
|
||||
Dict with validation results
|
||||
"""
|
||||
result: dict[str, object] = {
|
||||
"name": name,
|
||||
"selector": selector,
|
||||
"found": False,
|
||||
"count": 0,
|
||||
"visible": False,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# Handle XPath selectors (Playwright format: xpath=/...)
|
||||
if selector.startswith("xpath="):
|
||||
xpath = selector[6:] # Strip "xpath=" prefix
|
||||
js_code = f"""
|
||||
(() => {{
|
||||
try {{
|
||||
const result = document.evaluate(
|
||||
'{xpath}',
|
||||
document,
|
||||
null,
|
||||
XPathResult.ORDERED_NODE_SNAPSHOT_TYPE,
|
||||
null
|
||||
);
|
||||
const count = result.snapshotLength;
|
||||
if (count === 0) {{
|
||||
return {{ found: false, count: 0, visible: false }};
|
||||
}}
|
||||
const elem = result.snapshotItem(0);
|
||||
const rect = elem.getBoundingClientRect();
|
||||
const computed = window.getComputedStyle(elem);
|
||||
const visible = rect.width > 0 && rect.height > 0 && computed.display !== 'none';
|
||||
return {{
|
||||
found: true,
|
||||
count: count,
|
||||
visible: visible,
|
||||
tagName: elem.tagName,
|
||||
className: elem.className || '',
|
||||
id: elem.id || null,
|
||||
rect: {{
|
||||
top: Math.round(rect.top),
|
||||
left: Math.round(rect.left),
|
||||
width: Math.round(rect.width),
|
||||
height: Math.round(rect.height)
|
||||
}}
|
||||
}};
|
||||
}} catch (e) {{
|
||||
return {{ found: false, count: 0, visible: false, error: e.message }};
|
||||
}}
|
||||
}})();
|
||||
"""
|
||||
else:
|
||||
# CSS selector
|
||||
js_code = f"""
|
||||
(() => {{
|
||||
try {{
|
||||
const elements = document.querySelectorAll('{selector}');
|
||||
const count = elements.length;
|
||||
if (count === 0) {{
|
||||
return {{ found: false, count: 0, visible: false }};
|
||||
}}
|
||||
const elem = elements[0];
|
||||
const rect = elem.getBoundingClientRect();
|
||||
const computed = window.getComputedStyle(elem);
|
||||
const visible = rect.width > 0 && rect.height > 0 && computed.display !== 'none';
|
||||
return {{
|
||||
found: true,
|
||||
count: count,
|
||||
visible: visible,
|
||||
tagName: elem.tagName,
|
||||
className: elem.className || '',
|
||||
id: elem.id || null,
|
||||
rect: {{
|
||||
top: Math.round(rect.top),
|
||||
left: Math.round(rect.left),
|
||||
width: Math.round(rect.width),
|
||||
height: Math.round(rect.height)
|
||||
}}
|
||||
}};
|
||||
}} catch (e) {{
|
||||
return {{ found: false, count: 0, visible: false, error: e.message }};
|
||||
}}
|
||||
}})();
|
||||
"""
|
||||
|
||||
raw_result = await page.evaluate(js_code)
|
||||
if isinstance(raw_result, dict):
|
||||
result["found"] = raw_result.get("found", False)
|
||||
result["count"] = raw_result.get("count", 0)
|
||||
result["visible"] = raw_result.get("visible", False)
|
||||
result["error"] = raw_result.get("error")
|
||||
if raw_result.get("tagName"):
|
||||
result["tag"] = raw_result.get("tagName")
|
||||
if raw_result.get("className"):
|
||||
result["class"] = raw_result.get("className")
|
||||
if raw_result.get("rect"):
|
||||
result["rect"] = raw_result.get("rect")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def run_diagnostics() -> None:
|
||||
"""Run diagnostics for all messaging selectors."""
|
||||
_logger.info("Connecting to browser extension...")
|
||||
|
||||
async with ExtensionClient() as client:
|
||||
page = await client.get_page()
|
||||
_logger.info("Connected to browser")
|
||||
|
||||
# Get page info
|
||||
url = await page.evaluate("window.location.href")
|
||||
title = await page.evaluate("document.title")
|
||||
_logger.info(f"Page: {title}")
|
||||
_logger.info(f"URL: {url}")
|
||||
_logger.info("-" * 60)
|
||||
|
||||
results: list[dict[str, object]] = []
|
||||
for name, selector in MESSAGING_SELECTORS.items():
|
||||
result = await validate_selector(page, name, selector)
|
||||
results.append(result)
|
||||
|
||||
status = "✅" if result["found"] else "❌"
|
||||
visible_status = "(visible)" if result["visible"] else "(hidden)"
|
||||
count_str = f"[{result['count']}]" if result["count"] else ""
|
||||
|
||||
if result["found"]:
|
||||
_logger.info(f"{status} {name}: FOUND {count_str} {visible_status}")
|
||||
if result.get("tag"):
|
||||
class_str = str(result.get("class", ""))[:50]
|
||||
_logger.info(f" tag: {result['tag']}, class: {class_str}")
|
||||
else:
|
||||
_logger.warning(f"{status} {name}: NOT FOUND")
|
||||
if result.get("error"):
|
||||
_logger.warning(f" error: {result['error']}")
|
||||
|
||||
_logger.info("-" * 60)
|
||||
|
||||
# Summary
|
||||
found_count = sum(1 for r in results if r["found"])
|
||||
visible_count = sum(1 for r in results if r["visible"])
|
||||
_logger.info(
|
||||
f"Summary: {found_count}/{len(results)} found, {visible_count} visible"
|
||||
)
|
||||
|
||||
# Report missing critical selectors
|
||||
critical = [
|
||||
"chat_flyout_button",
|
||||
"chat_messages_container",
|
||||
"chat_input",
|
||||
"send_button",
|
||||
]
|
||||
missing_critical = [
|
||||
n
|
||||
for n in critical
|
||||
if not next((r for r in results if r["name"] == n and r["found"]), None)
|
||||
]
|
||||
if missing_critical:
|
||||
_logger.warning(
|
||||
f"Missing critical selectors: {', '.join(missing_critical)}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_diagnostics())
|
||||
@@ -13,6 +13,7 @@ import logging
|
||||
from typing import ClassVar, cast, override
|
||||
|
||||
from guide.app.actions.base import DemoAction, register_action
|
||||
from guide.app.core.config import load_settings
|
||||
from guide.app.browser.context_builder import (
|
||||
FormContext,
|
||||
build_form_context,
|
||||
@@ -155,9 +156,10 @@ class SmartFillAction(DemoAction):
|
||||
details={"error": "No bearer token found in page localStorage"},
|
||||
)
|
||||
|
||||
# Determine GraphQL URL (from params or default)
|
||||
# Determine GraphQL URL (from params or config default)
|
||||
settings = load_settings()
|
||||
graphql_url = str(
|
||||
params.get("graphql_url", "https://stg.raindrop.com/hasura/v1/graphql")
|
||||
params.get("graphql_url", settings.raindrop_graphql_url)
|
||||
)
|
||||
|
||||
# 1. Fetch board schema from GraphQL
|
||||
|
||||
5
src/guide/app/actions/messaging/__init__.py
Normal file
5
src/guide/app/actions/messaging/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Messaging actions for chat panel interactions."""
|
||||
|
||||
from guide.app.actions.messaging.respond import RespondToMessageAction
|
||||
|
||||
__all__ = ["RespondToMessageAction"]
|
||||
225
src/guide/app/actions/messaging/respond.py
Normal file
225
src/guide/app/actions/messaging/respond.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Message response action for chat panel interactions."""
|
||||
|
||||
from typing import ClassVar, override
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.actions.base import DemoAction, register_action
|
||||
from guide.app.browser.helpers import PageHelpers
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.models.domain import ActionContext, ActionResult
|
||||
from guide.app.strings.registry import MessagingStrings
|
||||
|
||||
|
||||
@register_action
|
||||
class RespondToMessageAction(DemoAction):
|
||||
"""Open chat panel and send a message.
|
||||
|
||||
Handles the full messaging flow including:
|
||||
1. Dismissing any blocking modals (common after email URL login)
|
||||
2. Expanding the chat flyout if not visible
|
||||
3. Switching to conversations tab
|
||||
4. Typing and sending the message
|
||||
|
||||
Visibility checks are performed between each step to ensure the chat
|
||||
panel remains accessible throughout the flow.
|
||||
|
||||
Request params:
|
||||
message: str - Message text to send (required)
|
||||
"""
|
||||
|
||||
id: ClassVar[str] = "messaging.respond"
|
||||
description: ClassVar[str] = "Open chat panel, type message, and send."
|
||||
category: ClassVar[str] = "messaging"
|
||||
|
||||
@override
|
||||
async def run(self, page: PageLike, context: ActionContext) -> ActionResult:
|
||||
"""Execute message response flow.
|
||||
|
||||
Args:
|
||||
page: The browser page instance.
|
||||
context: Action context with params.
|
||||
|
||||
Returns:
|
||||
ActionResult with message_sent status and details.
|
||||
|
||||
Raises:
|
||||
ActionExecutionError: If message param missing or chat elements not found.
|
||||
"""
|
||||
message = context.params.get("message")
|
||||
if not message or not isinstance(message, str):
|
||||
raise errors.ActionExecutionError(
|
||||
"'message' param is required",
|
||||
details={"provided_params": list(context.params.keys())},
|
||||
)
|
||||
|
||||
helpers = PageHelpers(page)
|
||||
|
||||
# 1. Wait for page stability
|
||||
await helpers.wait_for_stable()
|
||||
|
||||
# 2. Dismiss blocking modal if present (common after email URL login)
|
||||
modal_dismissed = await self._dismiss_modal_if_present(page, helpers)
|
||||
|
||||
# 3. Ensure chat panel is visible (expand flyout + switch to conversations)
|
||||
chat_expanded = await self._ensure_chat_visible(page, helpers)
|
||||
|
||||
# 4. Verify chat is still visible before typing
|
||||
await self._verify_chat_visible(page, helpers, step="before_typing")
|
||||
|
||||
# 5. Type message into chat input
|
||||
await page.fill(MessagingStrings.chat_input, message)
|
||||
|
||||
# 6. Verify chat is still visible before sending
|
||||
await self._verify_chat_visible(page, helpers, step="before_send")
|
||||
|
||||
# 7. Send message
|
||||
await page.click(MessagingStrings.send_button)
|
||||
|
||||
# 8. Wait for network activity to settle
|
||||
await helpers.wait_for_network_idle()
|
||||
|
||||
# 9. Final verification that chat is still visible
|
||||
await self._verify_chat_visible(page, helpers, step="after_send")
|
||||
|
||||
return ActionResult(
|
||||
details={
|
||||
"message_sent": True,
|
||||
"message_length": len(message),
|
||||
"modal_dismissed": modal_dismissed,
|
||||
"chat_expanded": chat_expanded,
|
||||
}
|
||||
)
|
||||
|
||||
async def _dismiss_modal_if_present(
|
||||
self, page: PageLike, helpers: PageHelpers
|
||||
) -> bool:
|
||||
"""Dismiss blocking modal if present.
|
||||
|
||||
After logging in via emailed URL, a modal may appear that blocks
|
||||
access to the page. This method detects and dismisses it.
|
||||
|
||||
Args:
|
||||
page: The browser page instance.
|
||||
helpers: PageHelpers instance for wait utilities.
|
||||
|
||||
Returns:
|
||||
True if modal was dismissed, False otherwise.
|
||||
"""
|
||||
modal = page.locator(MessagingStrings.modal_wrapper)
|
||||
modal_count = await modal.count()
|
||||
|
||||
if modal_count > 0:
|
||||
close_button = page.locator(MessagingStrings.modal_close_button)
|
||||
close_count = await close_button.count()
|
||||
|
||||
if close_count > 0:
|
||||
await close_button.first.click()
|
||||
# Wait for modal to close
|
||||
_ = await page.wait_for_selector(
|
||||
MessagingStrings.modal_wrapper,
|
||||
state="hidden",
|
||||
timeout=5000,
|
||||
)
|
||||
# Wait for page to stabilize after modal close
|
||||
await helpers.wait_for_stable()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _ensure_chat_visible(self, page: PageLike, helpers: PageHelpers) -> bool:
|
||||
"""Ensure chat panel is visible, expanding if needed.
|
||||
|
||||
Checks if chat messages container is visible. If not, clicks the
|
||||
flyout button and switches to conversations tab.
|
||||
|
||||
Args:
|
||||
page: The browser page instance.
|
||||
helpers: PageHelpers instance for wait utilities.
|
||||
|
||||
Returns:
|
||||
True if chat was expanded, False if already visible.
|
||||
"""
|
||||
chat_container = page.locator(MessagingStrings.chat_messages_container)
|
||||
container_count = await chat_container.count()
|
||||
|
||||
if container_count > 0:
|
||||
# Chat already visible
|
||||
return False
|
||||
|
||||
# Click flyout button to expand
|
||||
flyout_button = page.locator(MessagingStrings.chat_flyout_button)
|
||||
flyout_count = await flyout_button.count()
|
||||
|
||||
if flyout_count == 0:
|
||||
raise errors.ActionExecutionError(
|
||||
"Chat flyout button not found",
|
||||
details={"selector": MessagingStrings.chat_flyout_button},
|
||||
)
|
||||
|
||||
await flyout_button.first.click()
|
||||
await helpers.wait_for_stable()
|
||||
|
||||
# Verify flyout expanded before proceeding
|
||||
await self._verify_chat_visible(page, helpers, step="after_flyout_click")
|
||||
|
||||
# Switch to conversations tab
|
||||
conversations_tab = page.locator(MessagingStrings.chat_conversations_tab)
|
||||
tab_count = await conversations_tab.count()
|
||||
|
||||
if tab_count > 0:
|
||||
await conversations_tab.first.click()
|
||||
await helpers.wait_for_stable()
|
||||
|
||||
# Verify still visible after tab switch
|
||||
await self._verify_chat_visible(page, helpers, step="after_tab_switch")
|
||||
|
||||
return True
|
||||
|
||||
async def _verify_chat_visible(
|
||||
self, page: PageLike, helpers: PageHelpers, step: str
|
||||
) -> None:
|
||||
"""Verify chat panel is visible at a given step.
|
||||
|
||||
Checks that the chat messages container is present and visible.
|
||||
If not, attempts to re-expand the flyout once before failing.
|
||||
|
||||
Args:
|
||||
page: The browser page instance.
|
||||
helpers: PageHelpers instance for wait utilities.
|
||||
step: Name of the current step (for error reporting).
|
||||
|
||||
Raises:
|
||||
ActionExecutionError: If chat panel cannot be made visible.
|
||||
"""
|
||||
chat_container = page.locator(MessagingStrings.chat_messages_container)
|
||||
container_count = await chat_container.count()
|
||||
|
||||
if container_count > 0:
|
||||
# Chat is visible
|
||||
return
|
||||
|
||||
# Chat not visible - attempt recovery by clicking flyout button
|
||||
flyout_button = page.locator(MessagingStrings.chat_flyout_button)
|
||||
flyout_count = await flyout_button.count()
|
||||
|
||||
if flyout_count > 0:
|
||||
await flyout_button.first.click()
|
||||
await helpers.wait_for_stable()
|
||||
|
||||
# Check again after recovery attempt
|
||||
container_count = await chat_container.count()
|
||||
if container_count > 0:
|
||||
return
|
||||
|
||||
# Recovery failed - raise error with step context
|
||||
raise errors.ActionExecutionError(
|
||||
f"Chat panel not visible at step '{step}'",
|
||||
details={
|
||||
"step": step,
|
||||
"container_selector": MessagingStrings.chat_messages_container,
|
||||
"flyout_selector": MessagingStrings.chat_flyout_button,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["RespondToMessageAction"]
|
||||
@@ -11,6 +11,9 @@ Example flows:
|
||||
from typing import ClassVar, override
|
||||
|
||||
from guide.app.actions.base import CompositeAction, register_action
|
||||
from guide.app.actions.playbooks.email_notification import (
|
||||
EmailNotificationResponsePlaybook,
|
||||
)
|
||||
from guide.app.models.domain import ActionResult
|
||||
|
||||
|
||||
@@ -98,4 +101,8 @@ class FullDemoFlowAction(CompositeAction):
|
||||
self.context.shared_state["suppliers"] = suppliers
|
||||
|
||||
|
||||
__all__ = ["OnboardingFlowAction", "FullDemoFlowAction"]
|
||||
__all__ = [
|
||||
"OnboardingFlowAction",
|
||||
"FullDemoFlowAction",
|
||||
"EmailNotificationResponsePlaybook",
|
||||
]
|
||||
209
src/guide/app/actions/playbooks/email_notification.py
Normal file
209
src/guide/app/actions/playbooks/email_notification.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Email notification response playbook.
|
||||
|
||||
Receives n8n webhook payload and executes:
|
||||
1. Session awareness check (detect if already logged in as correct user)
|
||||
2. Conditional login (only if is_login=True and not already logged in)
|
||||
- If no OTP URL provided, triggers OTP request flow via webhook callback
|
||||
3. Conditional message response (only if is_message=True)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import ClassVar, override
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.actions.base import (
|
||||
ActionRegistry,
|
||||
ConditionalCompositeAction,
|
||||
register_action,
|
||||
)
|
||||
from guide.app.auth import detect_current_persona
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import AppSettings
|
||||
from guide.app.models.domain import ActionContext, ActionResult
|
||||
from guide.app.models.personas import PersonaResolver
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_action
|
||||
class EmailNotificationResponsePlaybook(ConditionalCompositeAction):
|
||||
"""Playbook for responding to email notifications via n8n.
|
||||
|
||||
Receives n8n payload and conditionally executes login and message steps.
|
||||
|
||||
Request params (from n8n):
|
||||
user_email: str - Email of user to authenticate as (required)
|
||||
is_login: bool - Whether to perform login step (default: True)
|
||||
is_message: bool - Whether to send message (default: True)
|
||||
access_url: str - OTP URL for authentication (optional - triggers OTP request if missing)
|
||||
message: str - Message text to send (required if is_message)
|
||||
callback_base_url: str - Base URL for OTP callback (optional)
|
||||
|
||||
Session awareness:
|
||||
- Detects if already logged in as correct user
|
||||
- Skips login if session matches user_email
|
||||
- If no OTP URL provided, triggers OTP request and waits for callback
|
||||
|
||||
Browser host:
|
||||
- Uses browserless-cdp (pass browser_host_id in request)
|
||||
"""
|
||||
|
||||
id: ClassVar[str] = "playbook.email_notification_response"
|
||||
description: ClassVar[str] = "Respond to email notification: login + message"
|
||||
category: ClassVar[str] = "playbooks"
|
||||
child_actions: ClassVar[tuple[str, ...]] = (
|
||||
"auth.login_as_persona",
|
||||
"messaging.respond",
|
||||
)
|
||||
|
||||
_persona_resolver: PersonaResolver
|
||||
_settings: AppSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry: ActionRegistry,
|
||||
persona_resolver: PersonaResolver,
|
||||
settings: AppSettings,
|
||||
) -> None:
|
||||
"""Initialize playbook with dependencies.
|
||||
|
||||
Args:
|
||||
registry: ActionRegistry for resolving child actions.
|
||||
persona_resolver: Resolver for email-based persona lookup.
|
||||
settings: Application settings for n8n webhook configuration.
|
||||
"""
|
||||
super().__init__(registry)
|
||||
self._persona_resolver = persona_resolver
|
||||
self._settings = settings
|
||||
|
||||
@override
|
||||
async def should_execute_step(
|
||||
self,
|
||||
step_id: str,
|
||||
context: ActionContext,
|
||||
) -> bool:
|
||||
"""Determine if step should execute based on n8n params.
|
||||
|
||||
Args:
|
||||
step_id: The action ID of the step.
|
||||
context: Current action context with params and shared_state.
|
||||
|
||||
Returns:
|
||||
True to execute the step, False to skip.
|
||||
"""
|
||||
params = context.params
|
||||
|
||||
if step_id == "auth.login_as_persona":
|
||||
is_login = params.get("is_login", True)
|
||||
already_logged_in = context.shared_state.get("session_reused", False)
|
||||
otp_flow_completed = context.shared_state.get("otp_flow_completed", False)
|
||||
# Skip if already logged in OR if OTP flow handled login
|
||||
return bool(is_login) and not already_logged_in and not otp_flow_completed
|
||||
|
||||
if step_id == "messaging.respond":
|
||||
return bool(params.get("is_message", True))
|
||||
|
||||
return True
|
||||
|
||||
@override
|
||||
async def run(self, page: PageLike, context: ActionContext) -> ActionResult:
|
||||
"""Execute email notification playbook.
|
||||
|
||||
Args:
|
||||
page: The browser page instance.
|
||||
context: Action context with n8n params.
|
||||
|
||||
Returns:
|
||||
ActionResult with combined status from all steps.
|
||||
|
||||
Raises:
|
||||
ActionExecutionError: If required params are missing.
|
||||
"""
|
||||
user_email = context.params.get("user_email")
|
||||
if not user_email or not isinstance(user_email, str):
|
||||
raise errors.ActionExecutionError(
|
||||
"'user_email' param is required",
|
||||
details={"provided_params": list(context.params.keys())},
|
||||
)
|
||||
|
||||
# Session awareness: check if already logged in as this user
|
||||
current_email = await detect_current_persona(page)
|
||||
if current_email and current_email.lower() == user_email.lower():
|
||||
_logger.info("Already logged in as: %s", user_email)
|
||||
context.shared_state["logged_in_email"] = current_email
|
||||
context.shared_state["session_reused"] = True
|
||||
else:
|
||||
# Not logged in - check if we need OTP request flow
|
||||
access_url = context.params.get("access_url")
|
||||
is_login = context.params.get("is_login", True)
|
||||
|
||||
if is_login and not access_url:
|
||||
# No OTP URL provided - trigger OTP request flow
|
||||
_logger.info("No OTP URL provided, triggering OTP request flow for: %s", user_email)
|
||||
otp_result = await self._trigger_otp_flow(page, context, user_email)
|
||||
if otp_result.status == "error":
|
||||
return otp_result
|
||||
context.shared_state["otp_flow_completed"] = True
|
||||
|
||||
# Forward params to login action (email-based resolution)
|
||||
context.params["email"] = user_email
|
||||
if access_url := context.params.get("access_url"):
|
||||
context.params["url"] = access_url
|
||||
|
||||
return await super().run(page, context)
|
||||
|
||||
async def _trigger_otp_flow(
|
||||
self,
|
||||
page: PageLike,
|
||||
context: ActionContext,
|
||||
email: str,
|
||||
) -> ActionResult:
|
||||
"""Trigger OTP request flow via auth.request_otp action.
|
||||
|
||||
Args:
|
||||
page: Browser page instance.
|
||||
context: Action context.
|
||||
email: Email to request OTP for.
|
||||
|
||||
Returns:
|
||||
ActionResult from OTP request action.
|
||||
"""
|
||||
try:
|
||||
otp_action = self.registry.get("auth.request_otp")
|
||||
except errors.ActionExecutionError:
|
||||
return ActionResult(
|
||||
status="error",
|
||||
error="auth.request_otp action not available",
|
||||
details={"hint": "Ensure auth.request_otp is registered"},
|
||||
)
|
||||
|
||||
# Forward callback_base_url if provided
|
||||
otp_context = ActionContext(
|
||||
action_id="auth.request_otp",
|
||||
persona_id=context.persona_id,
|
||||
browser_host_id=context.browser_host_id,
|
||||
params={
|
||||
"email": email,
|
||||
"callback_base_url": context.params.get("callback_base_url", self._settings.callback_base_url),
|
||||
},
|
||||
)
|
||||
|
||||
_logger.info("Executing OTP request flow for: %s", email)
|
||||
return await otp_action.run(page, otp_context)
|
||||
|
||||
@override
|
||||
async def on_step_complete(self, step_id: str, result: ActionResult) -> None:
|
||||
"""Process step results and update shared state.
|
||||
|
||||
Args:
|
||||
step_id: The completed step action ID.
|
||||
result: The step's ActionResult.
|
||||
"""
|
||||
assert self.context is not None
|
||||
|
||||
if step_id == "auth.login_as_persona" and result.details:
|
||||
if status := result.details.get("status"):
|
||||
self.context.shared_state["login_status"] = status
|
||||
|
||||
|
||||
__all__ = ["EmailNotificationResponsePlaybook"]
|
||||
@@ -7,7 +7,7 @@ from guide.app.actions import base
|
||||
from guide.app.actions.base import ActionRegistry, CompositeAction, DemoAction
|
||||
from guide.app.auth import SessionManager
|
||||
from guide.app.core.config import AppSettings
|
||||
from guide.app.models.personas import PersonaStore
|
||||
from guide.app.models.personas import PersonaResolver, PersonaStore
|
||||
|
||||
|
||||
def _discover_action_modules() -> None:
|
||||
@@ -59,8 +59,11 @@ def default_registry(
|
||||
"""
|
||||
_discover_action_modules()
|
||||
|
||||
persona_resolver = PersonaResolver(persona_store)
|
||||
|
||||
di_context: dict[str, object] = {
|
||||
"persona_store": persona_store,
|
||||
"persona_resolver": persona_resolver,
|
||||
"login_url": login_url,
|
||||
"settings": settings,
|
||||
"session_manager": session_manager,
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from guide.app.api.routes import actions, boards, config, diagnostics, health, sessions
|
||||
from guide.app.api.routes import (
|
||||
actions,
|
||||
boards,
|
||||
config,
|
||||
diagnostics,
|
||||
health,
|
||||
otp_callback,
|
||||
sessions,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(health.router)
|
||||
@@ -9,5 +17,6 @@ router.include_router(boards.router)
|
||||
router.include_router(config.router)
|
||||
router.include_router(diagnostics.router)
|
||||
router.include_router(sessions.router)
|
||||
router.include_router(otp_callback.router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -83,6 +83,18 @@ async def execute_action(
|
||||
)
|
||||
target_host_id = target_host_id or settings.default_browser_host_id
|
||||
|
||||
# Early validation: ensure browser host exists before proceeding
|
||||
if target_host_id not in settings.browser_hosts:
|
||||
available_hosts = list(settings.browser_hosts.keys())
|
||||
raise errors.ConfigError(
|
||||
f"Browser host '{target_host_id}' not found in configuration",
|
||||
details={
|
||||
"requested_host": target_host_id,
|
||||
"available_hosts": available_hosts,
|
||||
"source": "persona" if persona and persona.browser_host_id == target_host_id else "request",
|
||||
},
|
||||
)
|
||||
|
||||
context = ActionContext(
|
||||
action_id=action_id,
|
||||
persona_id=persona.id if persona else None,
|
||||
@@ -112,6 +124,17 @@ async def execute_action(
|
||||
page_like, persona, mfa_provider, login_url=settings.raindrop_base_url
|
||||
)
|
||||
result = await action.run(page_like, context)
|
||||
|
||||
# Check if action returned an error result (e.g., from CompositeAction)
|
||||
if result.status == "error":
|
||||
return ActionEnvelope(
|
||||
status=ActionStatus.ERROR,
|
||||
action_id=action_id,
|
||||
correlation_id=context.correlation_id,
|
||||
error_code="ACTION_EXECUTION_FAILED",
|
||||
message=result.error or "Action returned error status",
|
||||
details=result.details,
|
||||
)
|
||||
except errors.GuideError as exc:
|
||||
# Capture diagnostics for debugging (including UI elements via Docling if enabled)
|
||||
debug_info = None
|
||||
|
||||
@@ -8,13 +8,23 @@ Provides read-only inspection endpoints for:
|
||||
- Page structure
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from typing import Protocol, cast
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from guide.app.browser.client import BrowserClient
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import AppSettings
|
||||
|
||||
|
||||
class _AppStateProtocol(Protocol):
|
||||
"""Protocol for app.state with typed attributes."""
|
||||
|
||||
browser_client: BrowserClient
|
||||
settings: AppSettings
|
||||
|
||||
|
||||
from guide.app.browser.diagnostics import (
|
||||
analyze_field_issues,
|
||||
extract_ui_elements,
|
||||
@@ -136,8 +146,9 @@ async def check_connectivity(
|
||||
Returns:
|
||||
Connectivity status with page title and URL if connected
|
||||
"""
|
||||
browser_client: BrowserClient = request.app.state.browser_client
|
||||
settings = request.app.state.settings
|
||||
state = cast(_AppStateProtocol, request.app.state)
|
||||
browser_client = state.browser_client
|
||||
settings = state.settings
|
||||
|
||||
target_host = host_id or settings.default_browser_host_id
|
||||
|
||||
@@ -170,8 +181,9 @@ async def diagnose_field(
|
||||
Returns:
|
||||
Field diagnostics including visibility, input state, chips
|
||||
"""
|
||||
browser_client: BrowserClient = request.app.state.browser_client
|
||||
settings = request.app.state.settings
|
||||
state = cast(_AppStateProtocol, request.app.state)
|
||||
browser_client = state.browser_client
|
||||
settings = state.settings
|
||||
|
||||
target_host = host_id or settings.default_browser_host_id
|
||||
selector = resolve_selector(field)
|
||||
@@ -221,8 +233,9 @@ async def diagnose_dropdown(
|
||||
Returns:
|
||||
Dropdown diagnostics including options and component structure
|
||||
"""
|
||||
browser_client: BrowserClient = request.app.state.browser_client
|
||||
settings = request.app.state.settings
|
||||
state = cast(_AppStateProtocol, request.app.state)
|
||||
browser_client = state.browser_client
|
||||
settings = state.settings
|
||||
|
||||
target_host = host_id or settings.default_browser_host_id
|
||||
selector = resolve_selector(field)
|
||||
@@ -283,8 +296,9 @@ async def diagnose_form(
|
||||
Returns:
|
||||
Diagnostics for all known intake form fields
|
||||
"""
|
||||
browser_client: BrowserClient = request.app.state.browser_client
|
||||
settings = request.app.state.settings
|
||||
state = cast(_AppStateProtocol, request.app.state)
|
||||
browser_client = state.browser_client
|
||||
settings = state.settings
|
||||
|
||||
target_host = host_id or settings.default_browser_host_id
|
||||
|
||||
@@ -355,8 +369,9 @@ async def diagnose_page(
|
||||
Returns:
|
||||
Page structure including data-cy elements
|
||||
"""
|
||||
browser_client: BrowserClient = request.app.state.browser_client
|
||||
settings = request.app.state.settings
|
||||
state = cast(_AppStateProtocol, request.app.state)
|
||||
browser_client = state.browser_client
|
||||
settings = state.settings
|
||||
|
||||
target_host = host_id or settings.default_browser_host_id
|
||||
|
||||
@@ -414,8 +429,9 @@ async def extract_selectors(
|
||||
Returns:
|
||||
Extracted UI elements with their CSS selectors
|
||||
"""
|
||||
browser_client: BrowserClient = request.app.state.browser_client
|
||||
settings = request.app.state.settings
|
||||
state = cast(_AppStateProtocol, request.app.state)
|
||||
browser_client = state.browser_client
|
||||
settings = state.settings
|
||||
|
||||
target_host = host_id or settings.default_browser_host_id
|
||||
|
||||
|
||||
104
src/guide/app/api/routes/otp_callback.py
Normal file
104
src/guide/app/api/routes/otp_callback.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""OTP callback endpoint for n8n webhook responses.
|
||||
|
||||
Receives OTP URLs from n8n after it searches for and extracts
|
||||
the magic link from email.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from guide.app.auth.otp_callback import OtpCallbackStore, get_otp_callback_store
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
class OtpCallbackRequest(BaseModel):
|
||||
"""Request body for OTP callback from n8n."""
|
||||
|
||||
correlation_id: str
|
||||
"""Correlation ID from the original OTP request."""
|
||||
|
||||
otp_url: str | None = None
|
||||
"""The OTP magic link URL extracted from email."""
|
||||
|
||||
error: str | None = None
|
||||
"""Error message if OTP URL could not be retrieved."""
|
||||
|
||||
|
||||
class OtpCallbackResponse(BaseModel):
|
||||
"""Response for OTP callback."""
|
||||
|
||||
status: str
|
||||
"""Status of the callback: 'fulfilled' or 'not_found'."""
|
||||
|
||||
correlation_id: str
|
||||
"""The correlation ID that was processed."""
|
||||
|
||||
|
||||
def _get_store() -> OtpCallbackStore:
|
||||
"""Dependency to get the OTP callback store."""
|
||||
return get_otp_callback_store()
|
||||
|
||||
|
||||
StoreDep = Annotated[OtpCallbackStore, Depends(_get_store)]
|
||||
|
||||
|
||||
@router.post("/otp-callback", response_model=OtpCallbackResponse)
|
||||
async def otp_callback(
|
||||
payload: OtpCallbackRequest,
|
||||
store: StoreDep,
|
||||
) -> OtpCallbackResponse:
|
||||
"""Receive OTP URL callback from n8n.
|
||||
|
||||
Called by n8n after it:
|
||||
1. Receives webhook notification that OTP was requested
|
||||
2. Searches inbox for OTP email
|
||||
3. Extracts magic link URL
|
||||
|
||||
Args:
|
||||
payload: Callback request with correlation_id and otp_url or error.
|
||||
store: The OTP callback store.
|
||||
|
||||
Returns:
|
||||
Response indicating if the callback was fulfilled.
|
||||
|
||||
Raises:
|
||||
HTTPException: If correlation_id not found (404).
|
||||
"""
|
||||
_logger.info(
|
||||
"Received OTP callback: correlation_id=%s, has_url=%s, error=%s",
|
||||
payload.correlation_id,
|
||||
bool(payload.otp_url),
|
||||
payload.error,
|
||||
)
|
||||
|
||||
if not payload.otp_url and not payload.error:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either otp_url or error must be provided",
|
||||
)
|
||||
|
||||
fulfilled = await store.fulfill(
|
||||
correlation_id=payload.correlation_id,
|
||||
otp_url=payload.otp_url,
|
||||
error=payload.error,
|
||||
)
|
||||
|
||||
if not fulfilled:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No pending OTP request for correlation_id: {payload.correlation_id}",
|
||||
)
|
||||
|
||||
return OtpCallbackResponse(
|
||||
status="fulfilled",
|
||||
correlation_id=payload.correlation_id,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["router"]
|
||||
@@ -1,9 +1,15 @@
|
||||
from guide.app.auth.mfa import DummyMfaCodeProvider, MfaCodeProvider
|
||||
from guide.app.auth.otp_callback import (
|
||||
OtpCallbackStore,
|
||||
OtpRequest,
|
||||
get_otp_callback_store,
|
||||
)
|
||||
from guide.app.auth.session import (
|
||||
detect_current_persona,
|
||||
ensure_persona,
|
||||
login_with_mfa,
|
||||
login_with_otp_url,
|
||||
login_with_verification_code,
|
||||
logout,
|
||||
validate_persona_from_storage,
|
||||
)
|
||||
@@ -14,14 +20,18 @@ from guide.app.auth.session_storage import SessionStorage
|
||||
__all__ = [
|
||||
"DummyMfaCodeProvider",
|
||||
"MfaCodeProvider",
|
||||
"OtpCallbackStore",
|
||||
"OtpRequest",
|
||||
"SessionData",
|
||||
"SessionManager",
|
||||
"SessionStorage",
|
||||
"SessionValidationResult",
|
||||
"detect_current_persona",
|
||||
"ensure_persona",
|
||||
"get_otp_callback_store",
|
||||
"login_with_mfa",
|
||||
"login_with_otp_url",
|
||||
"login_with_verification_code",
|
||||
"logout",
|
||||
"validate_persona_from_storage",
|
||||
]
|
||||
|
||||
219
src/guide/app/auth/otp_callback.py
Normal file
219
src/guide/app/auth/otp_callback.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""OTP callback store for async correlation between request and webhook response.
|
||||
|
||||
Enables the two-phase OTP flow:
|
||||
1. Action triggers OTP email, sends webhook to n8n, waits for callback
|
||||
2. n8n finds OTP URL, calls callback endpoint with correlation_id
|
||||
3. Action receives OTP URL and completes login
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OtpRequest:
|
||||
"""Pending OTP request awaiting callback."""
|
||||
|
||||
correlation_id: str
|
||||
email: str
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
otp_url: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class OtpCallbackStore:
|
||||
"""Thread-safe store for OTP request/response correlation.
|
||||
|
||||
Usage:
|
||||
store = OtpCallbackStore()
|
||||
|
||||
# In action: register and wait
|
||||
request = store.register(correlation_id, email)
|
||||
otp_url = await store.wait_for_callback(correlation_id, timeout=120)
|
||||
|
||||
# In callback endpoint: fulfill
|
||||
store.fulfill(correlation_id, otp_url)
|
||||
"""
|
||||
|
||||
_requests: dict[str, OtpRequest]
|
||||
_lock: asyncio.Lock
|
||||
_default_timeout: float
|
||||
|
||||
def __init__(self, default_timeout: float = 120.0) -> None:
|
||||
"""Initialize the callback store.
|
||||
|
||||
Args:
|
||||
default_timeout: Default timeout in seconds for waiting.
|
||||
"""
|
||||
self._requests = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._default_timeout = default_timeout
|
||||
|
||||
async def register(self, correlation_id: str, email: str) -> OtpRequest:
|
||||
"""Register a new OTP request.
|
||||
|
||||
Args:
|
||||
correlation_id: Unique ID to correlate request with callback.
|
||||
email: Email address OTP was requested for.
|
||||
|
||||
Returns:
|
||||
The registered OtpRequest.
|
||||
"""
|
||||
async with self._lock:
|
||||
# Clean up expired requests periodically to prevent memory leaks
|
||||
# Do this opportunistically on register (every 10 requests worth)
|
||||
if len(self._requests) > 0 and len(self._requests) % 10 == 0:
|
||||
_ = await self._cleanup_expired_unlocked()
|
||||
|
||||
request = OtpRequest(correlation_id=correlation_id, email=email)
|
||||
self._requests[correlation_id] = request
|
||||
_logger.info(
|
||||
"Registered OTP request: %s for %s", correlation_id, email
|
||||
)
|
||||
return request
|
||||
|
||||
async def wait_for_callback(
|
||||
self,
|
||||
correlation_id: str,
|
||||
timeout: float | None = None,
|
||||
) -> str:
|
||||
"""Wait for OTP URL callback.
|
||||
|
||||
Args:
|
||||
correlation_id: The request correlation ID.
|
||||
timeout: Timeout in seconds (default: store default).
|
||||
|
||||
Returns:
|
||||
The OTP URL from callback.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If callback not received within timeout.
|
||||
ValueError: If correlation_id not found or callback had error.
|
||||
"""
|
||||
timeout = timeout or self._default_timeout
|
||||
|
||||
async with self._lock:
|
||||
request = self._requests.get(correlation_id)
|
||||
if not request:
|
||||
raise ValueError(f"No pending request for correlation_id: {correlation_id}")
|
||||
|
||||
_logger.info(
|
||||
"Waiting for OTP callback: %s (timeout=%ss)", correlation_id, timeout
|
||||
)
|
||||
|
||||
try:
|
||||
_ = await asyncio.wait_for(request.event.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
await self._cleanup(correlation_id)
|
||||
raise TimeoutError(
|
||||
f"OTP callback timeout after {timeout}s for {correlation_id}"
|
||||
) from None
|
||||
|
||||
if request.error:
|
||||
await self._cleanup(correlation_id)
|
||||
raise ValueError(f"OTP callback error: {request.error}")
|
||||
|
||||
if not request.otp_url:
|
||||
await self._cleanup(correlation_id)
|
||||
raise ValueError(f"OTP callback received but no URL for {correlation_id}")
|
||||
|
||||
await self._cleanup(correlation_id)
|
||||
_logger.info("OTP callback received for: %s", correlation_id)
|
||||
return request.otp_url
|
||||
|
||||
async def fulfill(
|
||||
self,
|
||||
correlation_id: str,
|
||||
otp_url: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> bool:
|
||||
"""Fulfill a pending OTP request with URL or error.
|
||||
|
||||
Args:
|
||||
correlation_id: The request correlation ID.
|
||||
otp_url: The OTP URL from n8n.
|
||||
error: Error message if OTP retrieval failed.
|
||||
|
||||
Returns:
|
||||
True if request was found and fulfilled, False otherwise.
|
||||
"""
|
||||
async with self._lock:
|
||||
request = self._requests.get(correlation_id)
|
||||
if not request:
|
||||
_logger.warning(
|
||||
"No pending request for callback: %s", correlation_id
|
||||
)
|
||||
return False
|
||||
|
||||
request.otp_url = otp_url
|
||||
request.error = error
|
||||
_ = request.event.set()
|
||||
|
||||
_logger.info(
|
||||
"Fulfilled OTP request: %s (url=%s, error=%s)",
|
||||
correlation_id,
|
||||
bool(otp_url),
|
||||
error,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _cleanup(self, correlation_id: str) -> None:
|
||||
"""Remove completed request from store."""
|
||||
async with self._lock:
|
||||
_ = self._requests.pop(correlation_id, None)
|
||||
|
||||
async def _cleanup_expired_unlocked(self, max_age_seconds: float = 300) -> int:
|
||||
"""Remove expired requests older than max_age (must hold lock).
|
||||
|
||||
Args:
|
||||
max_age_seconds: Maximum age in seconds.
|
||||
|
||||
Returns:
|
||||
Number of requests cleaned up.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
expired: list[str] = []
|
||||
|
||||
for cid, request in self._requests.items():
|
||||
age = (now - request.created_at).total_seconds()
|
||||
if age > max_age_seconds:
|
||||
expired.append(cid)
|
||||
|
||||
for cid in expired:
|
||||
_ = self._requests.pop(cid, None)
|
||||
|
||||
if expired:
|
||||
_logger.info("Cleaned up %d expired OTP requests", len(expired))
|
||||
return len(expired)
|
||||
|
||||
async def cleanup_expired(self, max_age_seconds: float = 300) -> int:
|
||||
"""Remove expired requests older than max_age.
|
||||
|
||||
Args:
|
||||
max_age_seconds: Maximum age in seconds.
|
||||
|
||||
Returns:
|
||||
Number of requests cleaned up.
|
||||
"""
|
||||
async with self._lock:
|
||||
return await self._cleanup_expired_unlocked(max_age_seconds)
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_store: OtpCallbackStore | None = None
|
||||
|
||||
|
||||
def get_otp_callback_store() -> OtpCallbackStore:
|
||||
"""Get the global OTP callback store instance."""
|
||||
global _store
|
||||
if _store is None:
|
||||
_store = OtpCallbackStore()
|
||||
return _store
|
||||
|
||||
|
||||
__all__ = ["OtpCallbackStore", "OtpRequest", "get_otp_callback_store"]
|
||||
@@ -9,10 +9,42 @@ from guide.app.browser.types import PageLike
|
||||
from guide.app.errors import PersonaError
|
||||
from guide.app.models.personas.models import DemoPersona
|
||||
from guide.app.strings.registry import app_strings
|
||||
from guide.app.strings.selectors.auth import Auth0ErrorSelectors
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _check_for_auth_errors(
|
||||
page: PageLike,
|
||||
selectors: tuple[str, ...],
|
||||
context: str,
|
||||
) -> str | None:
|
||||
"""Check page for authentication error messages.
|
||||
|
||||
Iterates through error selectors and returns the first matching error text.
|
||||
|
||||
Args:
|
||||
page: Browser page to check.
|
||||
selectors: Tuple of selectors to check for errors.
|
||||
context: Context string for logging (e.g., "OTP login", "verification code").
|
||||
|
||||
Returns:
|
||||
Error text if found, None otherwise.
|
||||
"""
|
||||
for selector in selectors:
|
||||
error_el = page.locator(selector)
|
||||
if await error_el.count() > 0:
|
||||
error_text = await error_el.first.text_content()
|
||||
_logger.warning(
|
||||
"%s error: %s (selector: %s)",
|
||||
context,
|
||||
error_text,
|
||||
selector,
|
||||
)
|
||||
return error_text or "Unknown error"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtractedToken:
|
||||
"""Bearer token extracted from browser session."""
|
||||
@@ -81,7 +113,11 @@ async def detect_current_persona(page: PageLike) -> str | None:
|
||||
that may contain user email information.
|
||||
"""
|
||||
tokens = await discover_auth_tokens(page)
|
||||
_logger.debug("Discovered %d auth-related localStorage keys: %s", len(tokens), list(tokens.keys()))
|
||||
_logger.debug(
|
||||
"Discovered %d auth-related localStorage keys: %s",
|
||||
len(tokens),
|
||||
list(tokens.keys()),
|
||||
)
|
||||
if not tokens:
|
||||
_logger.debug("No auth tokens found in localStorage")
|
||||
return None
|
||||
@@ -90,7 +126,9 @@ async def detect_current_persona(page: PageLike) -> str | None:
|
||||
if "email" in tokens:
|
||||
email_value = tokens["email"]
|
||||
if "@" in email_value:
|
||||
_logger.info("Detected persona from localStorage 'email' key: %s", email_value)
|
||||
_logger.info(
|
||||
"Detected persona from localStorage 'email' key: %s", email_value
|
||||
)
|
||||
return email_value
|
||||
|
||||
# Second priority: Auth0 SPA SDK user keys (@@user@@ keys contain profile info)
|
||||
@@ -128,7 +166,9 @@ async def login_with_mfa(
|
||||
del _response
|
||||
# Check again after navigation - user might already be logged in
|
||||
if await email_input.count() == 0:
|
||||
_logger.debug("No email input found after navigation - user already logged in")
|
||||
_logger.debug(
|
||||
"No email input found after navigation - user already logged in"
|
||||
)
|
||||
return
|
||||
else:
|
||||
_logger.debug("No login URL and no email input - user already logged in")
|
||||
@@ -154,6 +194,80 @@ async def logout(page: PageLike) -> None:
|
||||
_logger.debug("No logout button found - user may not be logged in")
|
||||
|
||||
|
||||
async def login_with_verification_code(
|
||||
page: PageLike,
|
||||
verification_code: str,
|
||||
expected_email: str | None = None,
|
||||
) -> bool:
|
||||
"""Authenticate via verification code input on Auth0 page.
|
||||
|
||||
Used when Auth0 shows "Enter your email code to log in" page
|
||||
instead of processing a magic link directly.
|
||||
|
||||
Args:
|
||||
page: Browser page already on the Auth0 verification code page.
|
||||
verification_code: 6-digit code from email.
|
||||
expected_email: If provided, validate logged-in user matches this email.
|
||||
|
||||
Returns:
|
||||
True if authentication successful, False otherwise.
|
||||
"""
|
||||
from guide.app.browser.wait import wait_for_stable_page
|
||||
from guide.app.strings.selectors.login import LoginSelectors
|
||||
|
||||
_logger.info("Starting verification code login flow")
|
||||
|
||||
# Wait for code input field
|
||||
code_input = page.locator(LoginSelectors.VERIFICATION_CODE_INPUT)
|
||||
try:
|
||||
await code_input.wait_for(state="visible", timeout=5000)
|
||||
except Exception as exc:
|
||||
_logger.error("Verification code input not found: %s", exc)
|
||||
return False
|
||||
|
||||
# Fill verification code
|
||||
_logger.info("Filling verification code")
|
||||
await code_input.fill(verification_code)
|
||||
|
||||
# Click submit button
|
||||
submit_btn = page.locator(LoginSelectors.VERIFICATION_CODE_SUBMIT)
|
||||
if await submit_btn.count() > 0:
|
||||
_logger.info("Clicking verification code submit button")
|
||||
await submit_btn.first.click()
|
||||
else:
|
||||
_logger.warning("No verification code submit button found")
|
||||
return False
|
||||
|
||||
# Wait for auth redirect
|
||||
await wait_for_stable_page(page, stability_check_ms=8000)
|
||||
|
||||
# Log current URL after submission
|
||||
_logger.info("URL after verification code submission: %s", page.url)
|
||||
|
||||
# Check for error messages
|
||||
if await _check_for_auth_errors(
|
||||
page, Auth0ErrorSelectors.VERIFICATION_CODE_ERRORS, "Verification code"
|
||||
):
|
||||
return False
|
||||
|
||||
# Optionally validate logged-in user
|
||||
if expected_email:
|
||||
current = await detect_current_persona(page)
|
||||
if current and current.lower() == expected_email.lower():
|
||||
_logger.info("Verification code login successful for: %s", expected_email)
|
||||
return True
|
||||
|
||||
_logger.warning(
|
||||
"Verification code login validation failed - expected: %s, detected: %s",
|
||||
expected_email,
|
||||
current,
|
||||
)
|
||||
return False
|
||||
|
||||
_logger.info("Verification code login completed (no email validation requested)")
|
||||
return True
|
||||
|
||||
|
||||
async def login_with_otp_url(
|
||||
page: PageLike,
|
||||
otp_url: str,
|
||||
@@ -175,71 +289,83 @@ async def login_with_otp_url(
|
||||
from guide.app.browser.wait import wait_for_stable_page
|
||||
|
||||
_logger.info("Starting OTP login flow")
|
||||
_logger.debug("OTP URL: %s", f"{otp_url[:100]}..." if len(otp_url) > 100 else otp_url)
|
||||
_logger.info("OTP URL: %s", otp_url)
|
||||
|
||||
# Navigate to OTP URL
|
||||
_ = await page.goto(otp_url)
|
||||
try:
|
||||
_ = await page.goto(otp_url)
|
||||
except Exception as exc:
|
||||
_logger.error("Failed to navigate to OTP URL: %s", exc)
|
||||
return False
|
||||
|
||||
# Wait for page to stabilize
|
||||
await wait_for_stable_page(page, stability_check_ms=2000)
|
||||
|
||||
# Check for error messages (e.g., expired OTP)
|
||||
error_selectors = [
|
||||
"text=verification code has expired",
|
||||
"text=unauthorized",
|
||||
"text=invalid",
|
||||
"text=try to login again",
|
||||
]
|
||||
for selector in error_selectors:
|
||||
error_el = page.locator(selector)
|
||||
if await error_el.count() > 0:
|
||||
error_text = await error_el.first.text_content()
|
||||
_logger.warning("OTP error detected: %s", error_text)
|
||||
return False
|
||||
if await _check_for_auth_errors(
|
||||
page, Auth0ErrorSelectors.OTP_URL_ERRORS, "OTP navigation"
|
||||
):
|
||||
return False
|
||||
|
||||
# Check for OTP confirmation page - Auth0 shows "Almost there" with a LOG IN button
|
||||
# Try multiple selectors for the login confirmation button
|
||||
login_button_selectors = [
|
||||
"button:has-text('LOG IN')",
|
||||
"button:has-text('Log in')",
|
||||
"button:has-text('Login')",
|
||||
"[data-action-button-primary='true']",
|
||||
"form button[type='submit']",
|
||||
]
|
||||
|
||||
button_clicked = False
|
||||
for selector in login_button_selectors:
|
||||
for selector in Auth0ErrorSelectors.LOGIN_BUTTON_SELECTORS:
|
||||
login_btn = page.locator(selector)
|
||||
if await login_btn.count() > 0:
|
||||
_logger.info("Clicking OTP confirmation button: %s", selector)
|
||||
await login_btn.first.click()
|
||||
button_clicked = True
|
||||
# Wait for auth redirect after clicking
|
||||
await wait_for_stable_page(page, stability_check_ms=3000)
|
||||
# Wait for auth redirect after clicking - Auth0 can be slow
|
||||
await wait_for_stable_page(page, stability_check_ms=8000)
|
||||
break
|
||||
|
||||
if not button_clicked:
|
||||
_logger.debug("No OTP confirmation button found - may have auto-redirected")
|
||||
_logger.info("No OTP confirmation button found - may have auto-redirected")
|
||||
|
||||
# Log current URL after click attempt to diagnose redirect issues
|
||||
_logger.info("URL after OTP login flow: %s", page.url)
|
||||
|
||||
# Check for errors after clicking (in case click triggered error)
|
||||
for selector in error_selectors:
|
||||
error_el = page.locator(selector)
|
||||
if await error_el.count() > 0:
|
||||
error_text = await error_el.first.text_content()
|
||||
_logger.warning("OTP error after click: %s", error_text)
|
||||
return False
|
||||
if await _check_for_auth_errors(
|
||||
page, Auth0ErrorSelectors.OTP_URL_ERRORS, "OTP after click"
|
||||
):
|
||||
return False
|
||||
|
||||
# Optionally validate logged-in user matches expected email
|
||||
if expected_email:
|
||||
# Debug: log current URL and localStorage keys
|
||||
try:
|
||||
current_url = page.url
|
||||
_logger.info("Post-login URL: %s", current_url)
|
||||
ls_keys = await page.evaluate("Object.keys(localStorage)")
|
||||
_logger.info("localStorage keys: %s", ls_keys)
|
||||
except Exception as debug_exc:
|
||||
_logger.warning("Debug info collection failed: %s", debug_exc)
|
||||
|
||||
current = await detect_current_persona(page)
|
||||
if current and current.lower() == expected_email.lower():
|
||||
_logger.info("OTP login successful for: %s", expected_email)
|
||||
return True
|
||||
_logger.warning(
|
||||
"OTP login validation failed - expected: %s, detected: %s",
|
||||
expected_email,
|
||||
current,
|
||||
)
|
||||
|
||||
# Capture page content on failure for debugging
|
||||
try:
|
||||
page_title = await page.evaluate("document.title || ''")
|
||||
body_text = await page.evaluate(
|
||||
"document.body?.innerText?.substring(0, 500) || ''"
|
||||
)
|
||||
_logger.warning(
|
||||
"OTP login validation failed - expected: %s, detected: %s, page_title: %s, body_preview: %s",
|
||||
expected_email,
|
||||
current,
|
||||
page_title,
|
||||
body_text,
|
||||
)
|
||||
except Exception:
|
||||
_logger.warning(
|
||||
"OTP login validation failed - expected: %s, detected: %s",
|
||||
expected_email,
|
||||
current,
|
||||
)
|
||||
return False
|
||||
|
||||
_logger.info("OTP login completed (no email validation requested)")
|
||||
@@ -422,20 +548,16 @@ async def discover_auth_tokens(page: PageLike) -> dict[str, str]:
|
||||
_logger.debug("localStorage scan returned non-dict result")
|
||||
return {}
|
||||
result_dict = cast(dict[str, object], result)
|
||||
return {key: str(value) for key, value in result_dict.items() if value is not None}
|
||||
return {
|
||||
key: str(value) for key, value in result_dict.items() if value is not None
|
||||
}
|
||||
except Exception as exc:
|
||||
# Handle cases where page.evaluate fails (restricted pages, closed pages)
|
||||
_logger.debug("localStorage access failed: %s", exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _is_jwt_format(value: str) -> bool:
|
||||
"""Check if value looks like a JWT (three base64 segments separated by dots)."""
|
||||
parts = value.split(".")
|
||||
if len(parts) != 3:
|
||||
return False
|
||||
# Each part should be non-empty and contain valid base64-ish characters
|
||||
return all(part and all(c.isalnum() or c in "-_=" for c in part) for part in parts)
|
||||
from guide.app.utils.jwt import is_jwt_format as _is_jwt_format
|
||||
|
||||
|
||||
def _extract_token_from_json(value: str) -> str | None:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Core session management service."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from guide.app.utils.jwt import parse_jwt_expiry as _parse_jwt_expiry
|
||||
|
||||
from playwright.async_api import BrowserContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,7 +14,9 @@ if TYPE_CHECKING:
|
||||
from guide.app.auth.session_models import SessionData, SessionValidationResult
|
||||
from guide.app.auth.session_storage import SessionStorage
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.models.domain import ActionResult
|
||||
from guide.app.models.personas.models import DemoPersona
|
||||
from guide.app.utils.urls import extract_base_url
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,39 +83,50 @@ class SessionManager:
|
||||
"""
|
||||
return self._storage.load(persona_id)
|
||||
|
||||
def validate_offline(self, session: SessionData) -> SessionValidationResult:
|
||||
def validate_offline(
|
||||
self,
|
||||
session: SessionData,
|
||||
ttl_buffer_seconds: int = 30,
|
||||
) -> SessionValidationResult:
|
||||
"""Validate session without browser access.
|
||||
|
||||
Checks TTL expiry and JWT token expiry if available.
|
||||
Applies a buffer to prevent race conditions where session
|
||||
expires during navigation/injection (typically 5-10 seconds).
|
||||
|
||||
Args:
|
||||
session: Session data to validate.
|
||||
ttl_buffer_seconds: Minimum remaining TTL required (default 30s).
|
||||
Sessions with less remaining time are considered invalid
|
||||
to prevent expiry during restoration.
|
||||
|
||||
Returns:
|
||||
Validation result with is_valid flag and reason.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Check TTL expiry
|
||||
if now >= session.session_ttl_expires_at:
|
||||
return SessionValidationResult(
|
||||
is_valid=False,
|
||||
reason="session_ttl_expired",
|
||||
remaining_seconds=0,
|
||||
)
|
||||
|
||||
# Check JWT token expiry if available
|
||||
if session.token_expires_at and now >= session.token_expires_at:
|
||||
return SessionValidationResult(
|
||||
is_valid=False,
|
||||
reason="token_expired",
|
||||
remaining_seconds=0,
|
||||
)
|
||||
|
||||
# Calculate remaining time
|
||||
# Check TTL expiry (with buffer to prevent race conditions)
|
||||
remaining = session.session_ttl_expires_at - now
|
||||
remaining_seconds = int(remaining.total_seconds())
|
||||
|
||||
if remaining_seconds <= ttl_buffer_seconds:
|
||||
return SessionValidationResult(
|
||||
is_valid=False,
|
||||
reason="session_ttl_expired" if remaining_seconds <= 0 else "session_ttl_near_expiry",
|
||||
remaining_seconds=max(0, remaining_seconds),
|
||||
)
|
||||
|
||||
# Check JWT token expiry if available (also with buffer)
|
||||
if session.token_expires_at:
|
||||
token_remaining = session.token_expires_at - now
|
||||
token_remaining_seconds = int(token_remaining.total_seconds())
|
||||
if token_remaining_seconds <= ttl_buffer_seconds:
|
||||
return SessionValidationResult(
|
||||
is_valid=False,
|
||||
reason="token_expired" if token_remaining_seconds <= 0 else "token_near_expiry",
|
||||
remaining_seconds=max(0, token_remaining_seconds),
|
||||
)
|
||||
|
||||
return SessionValidationResult(
|
||||
is_valid=True,
|
||||
reason=None,
|
||||
@@ -237,6 +250,86 @@ class SessionManager:
|
||||
"""
|
||||
return self._storage.delete(persona_id)
|
||||
|
||||
async def restore_session(
|
||||
self,
|
||||
page: PageLike,
|
||||
persona: DemoPersona,
|
||||
) -> ActionResult | None:
|
||||
"""Try to restore existing session from disk.
|
||||
|
||||
Loads session from storage, validates it, injects into browser,
|
||||
and verifies the login state. Automatically invalidates expired
|
||||
or invalid sessions.
|
||||
|
||||
Args:
|
||||
page: Browser page instance.
|
||||
persona: Persona to restore session for.
|
||||
|
||||
Returns:
|
||||
ActionResult if session was restored successfully, None otherwise.
|
||||
"""
|
||||
session = self.load_session(persona.id)
|
||||
if not session:
|
||||
_logger.debug("No saved session found for persona %s", persona.id)
|
||||
return None
|
||||
|
||||
validation = self.validate_offline(session)
|
||||
if not validation.is_valid:
|
||||
_logger.info(
|
||||
"Saved session for %s is invalid: %s",
|
||||
persona.id,
|
||||
validation.reason,
|
||||
)
|
||||
_ = self.invalidate(persona.id)
|
||||
return None
|
||||
|
||||
# Extract base URL from origin
|
||||
try:
|
||||
base_url = extract_base_url(session.origin_url)
|
||||
except ValueError as exc:
|
||||
_logger.warning(
|
||||
"Invalid origin URL in session for %s: %s",
|
||||
persona.id,
|
||||
exc,
|
||||
)
|
||||
_ = self.invalidate(persona.id)
|
||||
return None
|
||||
|
||||
# Navigate to base URL first (required for localStorage to work)
|
||||
_ = await page.goto(base_url)
|
||||
|
||||
# Inject session localStorage
|
||||
await self.inject_local_storage(page, session)
|
||||
|
||||
# Reload page to pick up injected session (using evaluate to avoid full navigation)
|
||||
# This triggers the app to read localStorage without clearing it like goto() would
|
||||
_ = await page.evaluate("window.location.reload()")
|
||||
|
||||
# Validate session is actually working
|
||||
from guide.app.auth.session import detect_current_persona
|
||||
|
||||
current = await detect_current_persona(page)
|
||||
if current and current.lower() == persona.email.lower():
|
||||
_logger.info("Restored session for persona %s", persona.id)
|
||||
return ActionResult(
|
||||
details={
|
||||
"persona_id": persona.id,
|
||||
"email": persona.email,
|
||||
"status": "session_restored",
|
||||
"remaining_seconds": validation.remaining_seconds,
|
||||
}
|
||||
)
|
||||
|
||||
# Session injection failed - invalidate it
|
||||
_logger.warning(
|
||||
"Session injection failed for %s (expected: %s, got: %s)",
|
||||
persona.id,
|
||||
persona.email,
|
||||
current,
|
||||
)
|
||||
_ = self.invalidate(persona.id)
|
||||
return None
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""List all stored session persona IDs.
|
||||
|
||||
@@ -257,35 +350,7 @@ class SessionManager:
|
||||
Returns:
|
||||
Expiration datetime if found, None otherwise.
|
||||
"""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
# Decode payload (second part)
|
||||
payload_b64 = parts[1]
|
||||
# Add padding if needed
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
|
||||
payload_bytes = base64.urlsafe_b64decode(payload_b64)
|
||||
# json.loads returns Any; cast to object for type narrowing
|
||||
decoded = cast(object, json.loads(payload_bytes.decode("utf-8")))
|
||||
|
||||
if not isinstance(decoded, dict):
|
||||
return None
|
||||
|
||||
# Cast narrowed dict to typed dict for proper member access
|
||||
payload = cast("dict[str, object]", decoded)
|
||||
exp_value = payload.get("exp")
|
||||
if isinstance(exp_value, int | float):
|
||||
return datetime.fromtimestamp(exp_value, tz=timezone.utc)
|
||||
|
||||
except Exception as exc:
|
||||
_logger.debug("Failed to parse JWT expiry: %s", exc)
|
||||
|
||||
return None
|
||||
return _parse_jwt_expiry(token)
|
||||
|
||||
async def _extract_local_storage(self, page: PageLike) -> dict[str, str]:
|
||||
"""Extract all localStorage items from page."""
|
||||
@@ -310,4 +375,9 @@ class SessionManager:
|
||||
cast("dict[str, str | int | float | bool | None]", dict(c))
|
||||
for c in cookies
|
||||
]
|
||||
# ExtensionPage doesn't support cookie extraction - session will be localStorage-only
|
||||
_logger.warning(
|
||||
"Cookie extraction not supported for %s - session will use localStorage only",
|
||||
type(page).__name__,
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -25,8 +25,8 @@ class SessionStorage:
|
||||
self._base_dir = base_dir
|
||||
|
||||
def _ensure_dir(self) -> None:
|
||||
"""Ensure storage directory exists."""
|
||||
self._base_dir.mkdir(parents=True, exist_ok=True)
|
||||
"""Ensure storage directory exists with secure permissions."""
|
||||
self._base_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
|
||||
|
||||
def _session_path(self, persona_id: str) -> Path:
|
||||
"""Get path to session file for persona."""
|
||||
@@ -34,7 +34,10 @@ class SessionStorage:
|
||||
return self._base_dir / f"{safe_id}.session.json"
|
||||
|
||||
def save(self, session: SessionData) -> Path:
|
||||
"""Save session data to disk.
|
||||
"""Save session data to disk atomically.
|
||||
|
||||
Uses temp file + rename pattern to prevent corruption from
|
||||
concurrent writes or crashes mid-write.
|
||||
|
||||
Args:
|
||||
session: Session data to persist.
|
||||
@@ -44,7 +47,14 @@ class SessionStorage:
|
||||
"""
|
||||
self._ensure_dir()
|
||||
path = self._session_path(session.persona_id)
|
||||
_ = path.write_text(session.model_dump_json(indent=2))
|
||||
temp_path = path.with_suffix(".tmp")
|
||||
|
||||
# Write to temp file first
|
||||
_ = temp_path.write_text(session.model_dump_json(indent=2))
|
||||
|
||||
# Atomic rename (POSIX-compliant)
|
||||
_ = temp_path.replace(path)
|
||||
|
||||
_logger.info("Saved session for persona '%s' to %s", session.persona_id, path)
|
||||
return path
|
||||
|
||||
@@ -64,7 +74,13 @@ class SessionStorage:
|
||||
try:
|
||||
return SessionData.model_validate_json(path.read_text())
|
||||
except (ValueError, OSError) as exc:
|
||||
_logger.warning("Invalid session file for '%s': %s", persona_id, exc)
|
||||
_logger.warning(
|
||||
"Invalid/corrupted session file for '%s': %s - deleting",
|
||||
persona_id,
|
||||
exc,
|
||||
)
|
||||
# Delete corrupted session file to prevent repeated failures
|
||||
path.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
def delete(self, persona_id: str) -> bool:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -13,6 +14,8 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
StorageState = dict[str, object] # Runtime fallback
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BrowserClient:
|
||||
"""Provides page access via a persistent browser pool with context isolation.
|
||||
@@ -58,19 +61,17 @@ class BrowserClient:
|
||||
ConfigError: If the host_id is invalid or not configured
|
||||
BrowserConnectionError: If the browser connection fails
|
||||
"""
|
||||
import logging
|
||||
_logger = logging.getLogger(__name__)
|
||||
_logger.info(f"[BrowserClient] open_page called for host_id: {host_id}")
|
||||
_logger.info("[BrowserClient] open_page called for host_id: %s", host_id)
|
||||
|
||||
context, page, should_close = await self.pool.allocate_context_and_page(
|
||||
host_id, storage_state=storage_state
|
||||
)
|
||||
|
||||
_logger.info(f"[BrowserClient] Got page from pool, should_close: {should_close}")
|
||||
_logger.info("[BrowserClient] Got page from pool, should_close: %s", should_close)
|
||||
try:
|
||||
yield page
|
||||
finally:
|
||||
_logger.info(f"[BrowserClient] Cleaning up, should_close: {should_close}")
|
||||
_logger.info("[BrowserClient] Cleaning up, should_close: %s", should_close)
|
||||
# Only close context for headless mode (not CDP/extension)
|
||||
if should_close and context is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
|
||||
@@ -16,7 +16,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import TypeGuard, TypedDict
|
||||
|
||||
from guide.app.browser.elements.mui import escape_selector
|
||||
from guide.app.browser.utils import escape_selector
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.models.domain.models import DebugInfo
|
||||
|
||||
@@ -807,7 +807,7 @@ async def inspect_dropdown(
|
||||
listbox_after = await inspect_listbox(page)
|
||||
|
||||
# Close dropdown
|
||||
await page.evaluate(_JS_CLOSE_DROPDOWN)
|
||||
_ = await page.evaluate(_JS_CLOSE_DROPDOWN)
|
||||
|
||||
# Get component structure
|
||||
escaped = escape_selector(selector)
|
||||
|
||||
@@ -7,8 +7,7 @@ import contextlib
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from guide.app.browser.extension_client import ExtensionPage
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.browser.types import PageLike, PageWithTrustedClick, supports_trusted_click
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,7 +22,7 @@ async def close_all_dropdowns(page: PageLike) -> None:
|
||||
_logger.debug("[Dropdown] Closing all open dropdowns")
|
||||
|
||||
# Mark ALL current listboxes as closed so subsequent queries don't find them
|
||||
await page.evaluate(
|
||||
_ = await page.evaluate(
|
||||
"""
|
||||
(() => {
|
||||
const listboxes = document.querySelectorAll('[role="listbox"]:not([data-dropdown-closed])');
|
||||
@@ -38,8 +37,8 @@ async def close_all_dropdowns(page: PageLike) -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
# Check if trusted_click is available (ExtensionPage only)
|
||||
if hasattr(page, "trusted_click"):
|
||||
# Check if trusted_click is available (extension mode)
|
||||
if supports_trusted_click(page):
|
||||
coords = await page.evaluate(
|
||||
"""
|
||||
(() => {
|
||||
@@ -58,15 +57,15 @@ async def close_all_dropdowns(page: PageLike) -> None:
|
||||
})();
|
||||
"""
|
||||
)
|
||||
if coords and isinstance(coords, dict):
|
||||
if isinstance(coords, dict):
|
||||
coords_dict = cast(dict[str, object], coords)
|
||||
x_val = coords_dict.get("x", 100)
|
||||
y_val = coords_dict.get("y", 100)
|
||||
x = float(x_val) if isinstance(x_val, (int, float)) else 100.0
|
||||
y = float(y_val) if isinstance(y_val, (int, float)) else 100.0
|
||||
with contextlib.suppress(Exception):
|
||||
if isinstance(page, ExtensionPage):
|
||||
await page.trusted_click(x, y)
|
||||
trusted_page: PageWithTrustedClick = page
|
||||
await trusted_page.trusted_click(x, y)
|
||||
await page.wait_for_timeout(150)
|
||||
else:
|
||||
# Fallback for Playwright - blur only
|
||||
|
||||
@@ -77,7 +77,7 @@ async def select_typeahead(
|
||||
_logger.info("[Typeahead] Typing search text: '%s'", search_text)
|
||||
|
||||
# Clear existing value first
|
||||
await page.evaluate(
|
||||
_ = await page.evaluate(
|
||||
f"""
|
||||
(() => {{
|
||||
const field = document.querySelector('{field_selector_js}');
|
||||
|
||||
@@ -73,8 +73,7 @@ async def infer_type_from_element(page: PageLike, selector: str) -> str:
|
||||
Returns:
|
||||
Inferred field type
|
||||
"""
|
||||
# Local import to avoid circular import
|
||||
from guide.app.browser.elements.mui import escape_selector
|
||||
from guide.app.browser.utils import escape_selector
|
||||
|
||||
escaped = escape_selector(selector)
|
||||
result = await page.evaluate(
|
||||
|
||||
@@ -17,7 +17,7 @@ from guide.app.browser.elements._type_guards import (
|
||||
is_dict_str_object,
|
||||
is_list_of_objects,
|
||||
)
|
||||
from guide.app.browser.elements.mui import escape_selector
|
||||
from guide.app.browser.utils import escape_selector
|
||||
from guide.app.browser.types import PageLike
|
||||
|
||||
|
||||
|
||||
86
src/guide/app/browser/elements/layout.py
Normal file
86
src/guide/app/browser/elements/layout.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Layout-oriented element helpers (accordions, panels, etc.)."""
|
||||
|
||||
import logging
|
||||
from typing import TypedDict
|
||||
|
||||
from playwright.async_api import Error as PlaywrightError
|
||||
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.browser.types import PageLike, PageLocator
|
||||
from guide.app.core.config import Timeouts
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccordionCollapseResult(TypedDict):
|
||||
"""Result from collapsing accordions."""
|
||||
|
||||
collapsed_count: int
|
||||
total_found: int
|
||||
failed_indices: list[int]
|
||||
|
||||
|
||||
class Accordion:
|
||||
"""Collapse/expand helpers for accordion buttons."""
|
||||
|
||||
page: PageLike
|
||||
timeouts: Timeouts
|
||||
expanded_icon_selector: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
page: PageLike,
|
||||
*,
|
||||
timeouts: Timeouts | None = None,
|
||||
expanded_icon_selector: str = 'svg[data-testid="KeyboardArrowUpOutlinedIcon"]',
|
||||
) -> None:
|
||||
self.page = page
|
||||
self.timeouts = timeouts or Timeouts()
|
||||
self.expanded_icon_selector = expanded_icon_selector
|
||||
|
||||
async def collapse_all(
|
||||
self,
|
||||
buttons_selector: str,
|
||||
timeout_ms: int | None = None,
|
||||
) -> AccordionCollapseResult:
|
||||
"""Collapse all expanded accordion buttons that match selector."""
|
||||
buttons = self.page.locator(buttons_selector)
|
||||
count = await buttons.count()
|
||||
|
||||
if count == 0:
|
||||
return {"collapsed_count": 0, "total_found": 0, "failed_indices": []}
|
||||
|
||||
collapsed_count = 0
|
||||
failed_indices: list[int] = []
|
||||
max_wait = timeout_ms if timeout_ms is not None else self.timeouts.element_default
|
||||
|
||||
for index in range(count):
|
||||
button: PageLocator = buttons.nth(index)
|
||||
try:
|
||||
icon = button.locator(self.expanded_icon_selector)
|
||||
if await icon.count() > 0:
|
||||
await button.click(timeout=max_wait)
|
||||
collapsed_count += 1
|
||||
except (PlaywrightTimeoutError, PlaywrightError) as exc:
|
||||
_logger.debug("Failed to collapse accordion %s: %s", index, exc)
|
||||
failed_indices.append(index)
|
||||
|
||||
if count > 0 and collapsed_count == 0:
|
||||
raise errors.ActionExecutionError(
|
||||
f"Failed to collapse any accordions (found {count}, all failed)",
|
||||
details={
|
||||
"selector": buttons_selector,
|
||||
"found_count": count,
|
||||
"failed_indices": ",".join(str(i) for i in failed_indices),
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"collapsed_count": collapsed_count,
|
||||
"total_found": count,
|
||||
"failed_indices": failed_indices,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["Accordion", "AccordionCollapseResult"]
|
||||
@@ -34,23 +34,10 @@ class DropdownResult(TypedDict):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Selector Utilities
|
||||
# Selector Utilities (imported from browser/utils.py to avoid circular imports)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def escape_selector(selector: str) -> str:
|
||||
"""Escape a CSS selector for safe use in JavaScript code.
|
||||
|
||||
Handles backslashes, single quotes, and double quotes.
|
||||
|
||||
Args:
|
||||
selector: Raw CSS selector string
|
||||
|
||||
Returns:
|
||||
Escaped selector safe for JS string interpolation
|
||||
"""
|
||||
return selector.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
|
||||
from guide.app.browser.utils import escape_selector, escape_js_string
|
||||
|
||||
# Backward compatibility alias
|
||||
_escape_selector = escape_selector
|
||||
@@ -462,6 +449,7 @@ __all__ = [
|
||||
"DropdownResult",
|
||||
# Selector utilities
|
||||
"escape_selector",
|
||||
"escape_js_string",
|
||||
"_escape_selector",
|
||||
# Keyboard
|
||||
"send_key",
|
||||
|
||||
@@ -5,6 +5,7 @@ executed via the browser extension, avoiding CDP page refresh issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
@@ -13,12 +14,15 @@ from typing import Protocol, cast
|
||||
from websockets.asyncio.server import Server, ServerConnection, serve
|
||||
|
||||
from guide.app.browser.types import PageLocator
|
||||
from guide.app.core.config import DEFAULT_EXTENSION_PORT
|
||||
from guide.app.errors import BrowserConnectionError
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# JSON-serializable values that can be returned from JavaScript evaluation
|
||||
type JSONValue = str | int | float | bool | None | dict[str, JSONValue] | list[JSONValue]
|
||||
type JSONValue = (
|
||||
str | int | float | bool | None | dict[str, JSONValue] | list[JSONValue]
|
||||
)
|
||||
|
||||
|
||||
class PageWithExtensions(Protocol):
|
||||
@@ -84,8 +88,7 @@ class ExtensionPage:
|
||||
|
||||
def _escape_selector(self, selector: str) -> str:
|
||||
"""Escape a CSS selector for use in JavaScript single-quoted strings."""
|
||||
# Local import to avoid circular import with elements.mui
|
||||
from guide.app.browser.elements.mui import escape_selector
|
||||
from guide.app.browser.utils import escape_selector
|
||||
|
||||
return escape_selector(selector)
|
||||
|
||||
@@ -132,7 +135,11 @@ class ExtensionPage:
|
||||
_ = self._pending.pop(request_id, None)
|
||||
raise BrowserConnectionError(
|
||||
f"Extension command timeout after {timeout}s: {action}",
|
||||
details={"action": action, "payload": payload, "timeout_seconds": timeout},
|
||||
details={
|
||||
"action": action,
|
||||
"payload": payload,
|
||||
"timeout_seconds": timeout,
|
||||
},
|
||||
) from e
|
||||
|
||||
async def click(self, selector: str) -> None:
|
||||
@@ -202,7 +209,9 @@ class ExtensionPage:
|
||||
Exception: If element not found or fill fails
|
||||
"""
|
||||
selector_escaped = self._escape_selector(selector)
|
||||
value_escaped = value.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
value_escaped = (
|
||||
value.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
)
|
||||
|
||||
js_code = f"""
|
||||
(() => {{
|
||||
@@ -272,7 +281,9 @@ class ExtensionPage:
|
||||
return await self.eval_js(code)
|
||||
return await self.eval_js(expression)
|
||||
|
||||
async def click_element_with_text(self, selector: str, text: str, timeout: int = 5000) -> None:
|
||||
async def click_element_with_text(
|
||||
self, selector: str, text: str, timeout: int = 5000
|
||||
) -> None:
|
||||
"""Click an element matching selector that contains specific text.
|
||||
|
||||
Uses content script's CLICK_TEXT action with element search.
|
||||
@@ -332,13 +343,11 @@ class ExtensionPage:
|
||||
This method exists only for API compatibility with Playwright's Page interface.
|
||||
|
||||
Args:
|
||||
url: URL to navigate to (ignored)
|
||||
timeout: Maximum navigation time in milliseconds (ignored)
|
||||
wait_until: When to consider navigation complete (ignored)
|
||||
referer: Referer header value (ignored)
|
||||
_url: URL to navigate to (ignored)
|
||||
_timeout: Maximum navigation time in milliseconds (ignored)
|
||||
_wait_until: When to consider navigation complete (ignored)
|
||||
_referer: Referer header value (ignored)
|
||||
"""
|
||||
# Do absolutely nothing - user has already navigated to the correct page
|
||||
pass
|
||||
|
||||
def locator(self, selector: str) -> "PageLocator":
|
||||
"""Create a locator for the given selector.
|
||||
@@ -349,7 +358,7 @@ class ExtensionPage:
|
||||
Returns:
|
||||
ExtensionLocator instance
|
||||
"""
|
||||
return cast(PageLocator, ExtensionLocator(self, selector))
|
||||
return cast(PageLocator, cast(object, ExtensionLocator(self, selector)))
|
||||
|
||||
async def wait_for_selector(
|
||||
self,
|
||||
@@ -584,11 +593,19 @@ class ExtensionLocator:
|
||||
query is scoped to the same parent element.
|
||||
"""
|
||||
|
||||
return cast(PageLocator, ExtensionLocator(self._page, selector, None, self))
|
||||
return cast(
|
||||
PageLocator,
|
||||
cast(object, ExtensionLocator(self._page, selector, None, self)),
|
||||
)
|
||||
|
||||
async def click(self) -> None:
|
||||
"""Click the located element."""
|
||||
async def click(self, *, timeout: int | float | None = None) -> None:
|
||||
"""Click the located element.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in milliseconds (ignored for extension mode)
|
||||
"""
|
||||
|
||||
_ = timeout # Unused but required for PageLocator protocol compatibility
|
||||
js_code = self._build_chain_eval(
|
||||
"target.click(); return true;",
|
||||
"throw new Error('Element not found');",
|
||||
@@ -598,7 +615,9 @@ class ExtensionLocator:
|
||||
async def fill(self, value: str) -> None:
|
||||
"""Fill the located element with a value."""
|
||||
|
||||
escaped_value = value.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
escaped_value = (
|
||||
value.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
)
|
||||
js_code = self._build_chain_eval(
|
||||
f"""
|
||||
const el = target;
|
||||
@@ -631,7 +650,9 @@ class ExtensionLocator:
|
||||
async def type(self, text: str) -> None:
|
||||
"""Type text into the located element."""
|
||||
|
||||
escaped_text = text.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
escaped_text = (
|
||||
text.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
)
|
||||
js_code = self._build_chain_eval(
|
||||
f"""
|
||||
const el = target;
|
||||
@@ -742,14 +763,22 @@ class ExtensionLocator:
|
||||
async def text_content(self) -> str | None:
|
||||
"""Get the text content of the located element."""
|
||||
|
||||
js_code = self._build_chain_eval("return target ? target.textContent : null;", "return null;")
|
||||
js_code = self._build_chain_eval(
|
||||
"return target ? target.textContent : null;", "return null;"
|
||||
)
|
||||
result = await self._page.eval_js(js_code)
|
||||
return str(result) if result is not None else None
|
||||
|
||||
def nth(self, index: int) -> "PageLocator":
|
||||
"""Get the nth matching element."""
|
||||
|
||||
return cast(PageLocator, ExtensionLocator(self._page, self._selector, index, self._parent))
|
||||
return cast(
|
||||
PageLocator,
|
||||
cast(
|
||||
object,
|
||||
ExtensionLocator(self._page, self._selector, index, self._parent),
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def first(self) -> "PageLocator":
|
||||
@@ -770,7 +799,9 @@ class ExtensionClient:
|
||||
await page.click("button")
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 17373) -> None:
|
||||
def __init__(
|
||||
self, host: str = "0.0.0.0", port: int = DEFAULT_EXTENSION_PORT
|
||||
) -> None:
|
||||
self._host: str = host
|
||||
self._port: int = port
|
||||
self._server: Server | None = None
|
||||
@@ -809,11 +840,12 @@ class ExtensionClient:
|
||||
try:
|
||||
_ = await asyncio.wait_for(self._connected.wait(), timeout=10.0)
|
||||
_logger.info("Extension connected successfully")
|
||||
except asyncio.TimeoutError as exc: # pragma: no cover - covered by unit test raising connection error
|
||||
except (
|
||||
asyncio.TimeoutError
|
||||
) as exc: # pragma: no cover - covered by unit test raising connection error
|
||||
await self.close()
|
||||
raise BrowserConnectionError(
|
||||
"Extension did not connect within 10 seconds. "
|
||||
"Make sure Chrome is running with the Terminator Bridge extension loaded.",
|
||||
"Extension did not connect within 10 seconds. Make sure Chrome is running with the Terminator Bridge extension loaded.",
|
||||
details={"host": self._host, "port": self._port, "timeout_seconds": 10},
|
||||
) from exc
|
||||
|
||||
@@ -853,22 +885,39 @@ class ExtensionClient:
|
||||
_logger.info("Extension connected")
|
||||
self._ws = websocket
|
||||
self._page = ExtensionPage(self)
|
||||
# Set connected flag only AFTER page is initialized to prevent race condition
|
||||
self._connected.set()
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
# Use timeout to detect hung connections
|
||||
while True:
|
||||
try:
|
||||
data = cast(dict[str, JSONValue], json.loads(message))
|
||||
if self._page:
|
||||
self._page.handle_response(data)
|
||||
except json.JSONDecodeError:
|
||||
_logger.warning(f"Invalid JSON from extension: {message}")
|
||||
except Exception as e:
|
||||
_logger.error(f"Error handling extension message: {e}")
|
||||
# 60s timeout - if no message in 60s, check connection health
|
||||
message = await asyncio.wait_for(websocket.recv(), timeout=60.0)
|
||||
try:
|
||||
data = cast(dict[str, JSONValue], json.loads(message))
|
||||
if self._page:
|
||||
self._page.handle_response(data)
|
||||
except json.JSONDecodeError:
|
||||
_logger.warning("Invalid JSON from extension: %s", message)
|
||||
except Exception as e:
|
||||
_logger.error("Error handling extension message: %s", e)
|
||||
except asyncio.TimeoutError:
|
||||
# No message in 60s - check if connection is still alive
|
||||
_logger.debug("No message from extension for 60s, sending ping")
|
||||
try:
|
||||
_ = await websocket.ping()
|
||||
except Exception:
|
||||
_logger.warning("Extension ping failed, connection may be dead")
|
||||
break
|
||||
except Exception as e:
|
||||
_logger.error(f"Extension connection error: {e}")
|
||||
_logger.error("Extension connection error: %s", e)
|
||||
finally:
|
||||
_logger.info("Extension disconnected")
|
||||
# Ensure WebSocket is properly closed before clearing reference
|
||||
with contextlib.suppress(Exception):
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
self._connected.clear()
|
||||
|
||||
|
||||
@@ -1,395 +1,30 @@
|
||||
"""High-level page interaction helpers for demo actions.
|
||||
|
||||
Provides a stateful wrapper around Playwright Page with:
|
||||
- Integrated wait utilities for page conditions
|
||||
- Diagnostics capture for debugging
|
||||
- Accordion collapse and other UI patterns
|
||||
- Fluent API for common interaction sequences
|
||||
Composes small, focused mixins so each responsibility stays isolated while
|
||||
keeping the public API stable for actions.
|
||||
"""
|
||||
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import TypedDict
|
||||
|
||||
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.browser.elements import dropdown
|
||||
from guide.app.browser.types import PageLike, PageLocator
|
||||
from guide.app.browser.wait import (
|
||||
wait_for_network_idle,
|
||||
wait_for_navigation,
|
||||
wait_for_selector,
|
||||
wait_for_stable_page,
|
||||
)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
from guide.app.browser.elements.layout import Accordion, AccordionCollapseResult
|
||||
from guide.app.browser.mixins import DiagnosticsMixin, InteractionMixin, WaitMixin
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import Timeouts
|
||||
|
||||
|
||||
class AccordionCollapseResult(TypedDict):
|
||||
"""Result from collapse_accordions operation."""
|
||||
class PageHelpers(WaitMixin, DiagnosticsMixin, InteractionMixin):
|
||||
"""High-level page interaction wrapper for demo actions."""
|
||||
|
||||
collapsed_count: int
|
||||
"""Number of successfully collapsed buttons."""
|
||||
|
||||
total_found: int
|
||||
"""Total number of buttons found."""
|
||||
|
||||
failed_indices: list[int]
|
||||
"""List of button indices that failed to click."""
|
||||
|
||||
|
||||
class DropdownOption(TypedDict):
|
||||
"""Dropdown option from DOM query."""
|
||||
|
||||
index: int
|
||||
"""Zero-based index in the dropdown list."""
|
||||
|
||||
text: str
|
||||
"""Visible text content of the option."""
|
||||
|
||||
|
||||
class PageHelpers:
|
||||
"""High-level page interaction wrapper for demo actions.
|
||||
|
||||
Wraps a Playwright Page instance with convenient methods for:
|
||||
- Waiting on page conditions (selector, network idle, stability)
|
||||
- Capturing diagnostics (screenshot, HTML, logs)
|
||||
- Common UI patterns (fill and advance, click and wait, search and select)
|
||||
- Accordion collapse and similar UI operations
|
||||
|
||||
Usage:
|
||||
async with browser_client.open_page(host_id) as page:
|
||||
helpers = PageHelpers(page)
|
||||
await helpers.fill_and_advance(
|
||||
selector=app_strings.intake.description_field,
|
||||
value="My request",
|
||||
next_selector=app_strings.intake.next_button,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, page: PageLike) -> None:
|
||||
"""Initialize helpers with a Playwright page.
|
||||
|
||||
Args:
|
||||
page: The Playwright page instance to wrap
|
||||
"""
|
||||
def __init__(self, page: PageLike, *, timeouts: Timeouts | None = None) -> None:
|
||||
self.page: PageLike = page
|
||||
|
||||
# --- Wait utilities (wrapped for convenience) ---
|
||||
|
||||
async def wait_for_selector(
|
||||
self,
|
||||
selector: str,
|
||||
timeout_ms: int = 5000,
|
||||
) -> None:
|
||||
"""Wait for selector to appear in the DOM.
|
||||
|
||||
Args:
|
||||
selector: CSS or Playwright selector string
|
||||
timeout_ms: Maximum time to wait in milliseconds (default: 5000)
|
||||
|
||||
Raises:
|
||||
GuideError: If selector not found within timeout
|
||||
"""
|
||||
await wait_for_selector(self.page, selector, timeout_ms)
|
||||
|
||||
async def wait_for_network_idle(self, timeout_ms: int = 5000) -> None:
|
||||
"""Wait for network to become idle (no active requests).
|
||||
|
||||
Args:
|
||||
timeout_ms: Maximum time to wait in milliseconds (default: 5000)
|
||||
|
||||
Raises:
|
||||
GuideError: If network does not idle within timeout
|
||||
"""
|
||||
await wait_for_network_idle(self.page, timeout_ms)
|
||||
|
||||
async def wait_for_navigation(self, timeout_ms: int = 5000) -> None:
|
||||
"""Wait for page navigation to complete.
|
||||
|
||||
Args:
|
||||
timeout_ms: Maximum time to wait in milliseconds (default: 5000)
|
||||
|
||||
Raises:
|
||||
GuideError: If navigation does not complete within timeout
|
||||
"""
|
||||
await wait_for_navigation(self.page, timeout_ms)
|
||||
|
||||
async def wait_for_stable(
|
||||
self,
|
||||
stability_check_ms: int = 500,
|
||||
samples: int = 3,
|
||||
) -> None:
|
||||
"""Wait for page to become visually stable (DOM not changing).
|
||||
|
||||
Args:
|
||||
stability_check_ms: Delay between stability checks in ms (default: 500)
|
||||
samples: Number of stable samples required (default: 3)
|
||||
|
||||
Raises:
|
||||
GuideError: If page does not stabilize after retries
|
||||
"""
|
||||
await wait_for_stable_page(self.page, stability_check_ms, samples)
|
||||
|
||||
# --- Diagnostics ---
|
||||
|
||||
async def capture_diagnostics(self):
|
||||
"""Capture all diagnostic information (screenshot, HTML, logs).
|
||||
|
||||
Returns:
|
||||
DebugInfo with screenshot, HTML content, and console logs
|
||||
"""
|
||||
# Lazy import to avoid circular dependencies with Pydantic models
|
||||
from guide.app.browser.diagnostics import capture_all_diagnostics
|
||||
|
||||
return await capture_all_diagnostics(self.page)
|
||||
|
||||
# --- High-level UI operations ---
|
||||
|
||||
async def fill_and_advance(
|
||||
self,
|
||||
selector: str,
|
||||
value: str,
|
||||
next_selector: str,
|
||||
wait_for_idle: bool = True,
|
||||
) -> None:
|
||||
"""Fill a field and click next button (common pattern).
|
||||
|
||||
Fills an input field with a value, then clicks a next/continue button.
|
||||
Optionally waits for network idle after the click.
|
||||
|
||||
Args:
|
||||
selector: Field selector to fill
|
||||
value: Value to enter in the field
|
||||
next_selector: Button selector to click after filling
|
||||
wait_for_idle: Wait for network idle after click (default: True)
|
||||
"""
|
||||
await self.page.fill(selector, value)
|
||||
await self.page.click(next_selector)
|
||||
if wait_for_idle:
|
||||
await self.wait_for_network_idle()
|
||||
|
||||
async def search_and_select(
|
||||
self,
|
||||
search_input: str,
|
||||
query: str,
|
||||
result_selector: str,
|
||||
index: int = 0,
|
||||
) -> None:
|
||||
"""Type in search box and select result (common pattern).
|
||||
|
||||
Fills a search input, waits for network idle, then clicks a result item.
|
||||
|
||||
Args:
|
||||
search_input: Search input field selector
|
||||
query: Search query text to type
|
||||
result_selector: Result item selector
|
||||
index: Which result to click (default: 0 for first)
|
||||
"""
|
||||
await self.page.fill(search_input, query)
|
||||
await self.wait_for_network_idle()
|
||||
results = self.page.locator(result_selector)
|
||||
await results.nth(index).click()
|
||||
|
||||
async def click_and_wait(
|
||||
self,
|
||||
selector: str,
|
||||
wait_for_idle: bool = True,
|
||||
wait_for_stable: bool = False,
|
||||
) -> None:
|
||||
"""Click element and optionally wait for page state.
|
||||
|
||||
Args:
|
||||
selector: Element to click
|
||||
wait_for_idle: Wait for network idle after click (default: True)
|
||||
wait_for_stable: Wait for page stability after click (default: False)
|
||||
"""
|
||||
await self.page.click(selector)
|
||||
if wait_for_idle:
|
||||
await self.wait_for_network_idle()
|
||||
if wait_for_stable:
|
||||
await self.wait_for_stable()
|
||||
|
||||
# --- Dropdown operations ---
|
||||
|
||||
async def select_dropdown_options(
|
||||
self,
|
||||
field_selector: str,
|
||||
target_values: list[str],
|
||||
close_after: bool = True, # Reserved for future implementation
|
||||
) -> dict[str, list[str]]:
|
||||
_ = close_after # Reserved for future implementation
|
||||
# Delegate to dropdown helper for consistency
|
||||
result = await dropdown.select_multi(self.page, field_selector, target_values)
|
||||
return {
|
||||
"selected": result["selected"],
|
||||
"not_found": result["not_found"],
|
||||
"available_options": result.get("available", []),
|
||||
}
|
||||
|
||||
async def _send_key(self, key: str) -> None:
|
||||
"""Send keyboard event to active element.
|
||||
|
||||
Internal helper for dropdown navigation.
|
||||
|
||||
Args:
|
||||
key: Key name (ArrowDown, ArrowUp, Enter, etc.)
|
||||
"""
|
||||
keycode_map = {
|
||||
"ArrowDown": 40,
|
||||
"ArrowUp": 38,
|
||||
"Enter": 13,
|
||||
"Escape": 27,
|
||||
"Tab": 9,
|
||||
}
|
||||
keycode = keycode_map.get(key, 0)
|
||||
|
||||
_ = await self.page.evaluate(
|
||||
f"""
|
||||
(() => {{
|
||||
const event = new KeyboardEvent('keydown', {{
|
||||
key: '{key}',
|
||||
code: '{key}',
|
||||
keyCode: {keycode},
|
||||
which: {keycode},
|
||||
bubbles: true,
|
||||
cancelable: true,
|
||||
composed: true
|
||||
}});
|
||||
if (document.activeElement) {{
|
||||
document.activeElement.dispatchEvent(event);
|
||||
}}
|
||||
return true;
|
||||
}})()
|
||||
"""
|
||||
)
|
||||
|
||||
async def _collapse_open_listboxes(self) -> None:
|
||||
"""Hide any visible role=listbox elements without dismissing modals."""
|
||||
with contextlib.suppress(Exception):
|
||||
_ = await self.page.evaluate(
|
||||
"""
|
||||
(() => {
|
||||
const boxes = Array.from(document.querySelectorAll('[role="listbox"]'));
|
||||
boxes.forEach(box => {
|
||||
if (box && box.style.display !== 'none') {
|
||||
box.style.setProperty('display', 'none', 'important');
|
||||
}
|
||||
});
|
||||
return true;
|
||||
})();
|
||||
"""
|
||||
)
|
||||
|
||||
async def _wait_for_role_option(self, timeout_ms: int = 3000) -> None:
|
||||
"""Wait for any element with role="option" using JS polling.
|
||||
|
||||
Works in extension mode where Playwright's wait_for_selector isn't available.
|
||||
"""
|
||||
_ = await self.page.evaluate(
|
||||
f"""
|
||||
(() => {{
|
||||
const timeout = {timeout_ms};
|
||||
const start = Date.now();
|
||||
return new Promise((resolve, reject) => {{
|
||||
const check = () => {{
|
||||
if (document.querySelector('[role="option"]')) {{
|
||||
resolve(true);
|
||||
return;
|
||||
}}
|
||||
if (Date.now() - start > timeout) {{
|
||||
reject(new Error('Timeout waiting for role=option'));
|
||||
return;
|
||||
}}
|
||||
setTimeout(check, 100);
|
||||
}};
|
||||
check();
|
||||
}});
|
||||
}})()
|
||||
"""
|
||||
)
|
||||
|
||||
# --- Accordion operations ---
|
||||
self.timeouts: Timeouts = timeouts or Timeouts()
|
||||
|
||||
async def collapse_accordions(
|
||||
self,
|
||||
selector: str,
|
||||
_timeout_ms: int = 5000,
|
||||
timeout_ms: int | None = None,
|
||||
) -> AccordionCollapseResult:
|
||||
"""Collapse all expanded accordion buttons matching selector.
|
||||
|
||||
Detects expanded state by checking for SVG with
|
||||
data-testid="KeyboardArrowUpOutlinedIcon" (Material-UI chevron up icon).
|
||||
Clicks expanded buttons to collapse them.
|
||||
|
||||
Args:
|
||||
selector: CSS selector for accordion buttons
|
||||
_timeout_ms: Reserved for future timeout implementation (not currently used)
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
- collapsed_count: Number of successfully collapsed buttons
|
||||
- total_found: Total number of buttons found
|
||||
- failed_indices: List of indices that failed to click
|
||||
|
||||
Raises:
|
||||
ActionExecutionError: If no buttons found or all clicks failed
|
||||
"""
|
||||
# Find all buttons matching the selector
|
||||
buttons = self.page.locator(selector)
|
||||
count = await buttons.count()
|
||||
|
||||
if count == 0:
|
||||
# No buttons found - return success with zero count
|
||||
return {
|
||||
"collapsed_count": 0,
|
||||
"total_found": 0,
|
||||
"failed_indices": [],
|
||||
}
|
||||
|
||||
collapsed_count = 0
|
||||
failed_indices: list[int] = []
|
||||
|
||||
# Click each button that contains the "up" chevron icon (expanded state)
|
||||
# Material-UI uses KeyboardArrowUpOutlinedIcon when expanded
|
||||
for i in range(count):
|
||||
try:
|
||||
button: PageLocator = buttons.nth(i)
|
||||
# Check if this specific button contains the up icon
|
||||
up_icon: PageLocator = button.locator(
|
||||
'svg[data-testid="KeyboardArrowUpOutlinedIcon"]'
|
||||
)
|
||||
# Fast check: if icon exists, button is expanded
|
||||
if await up_icon.count() > 0:
|
||||
await button.click()
|
||||
collapsed_count += 1
|
||||
except PlaywrightTimeoutError:
|
||||
# Timeout on click - track failure but continue
|
||||
failed_indices.append(i)
|
||||
except Exception:
|
||||
# Other errors (element gone, stale reference, etc.)
|
||||
failed_indices.append(i)
|
||||
|
||||
# If total failure (found buttons but couldn't collapse any),
|
||||
# raise error with details
|
||||
if count > 0 and collapsed_count == 0:
|
||||
failed_indices_str = ",".join(str(i) for i in failed_indices)
|
||||
raise errors.ActionExecutionError(
|
||||
f"Failed to collapse any accordions (found {count}, all failed)",
|
||||
details={
|
||||
"selector": selector,
|
||||
"found_count": count,
|
||||
"failed_indices": failed_indices_str,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"collapsed_count": collapsed_count,
|
||||
"total_found": count,
|
||||
"failed_indices": failed_indices,
|
||||
}
|
||||
"""Collapse all expanded accordion buttons matching selector."""
|
||||
accordion = Accordion(self.page, timeouts=self.timeouts)
|
||||
return await accordion.collapse_all(selector, timeout_ms=timeout_ms)
|
||||
|
||||
|
||||
__all__ = ["PageHelpers", "AccordionCollapseResult"]
|
||||
|
||||
7
src/guide/app/browser/mixins/__init__.py
Normal file
7
src/guide/app/browser/mixins/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Composable mixins used by PageHelpers to keep responsibilities focused."""
|
||||
|
||||
from guide.app.browser.mixins.diagnostics import DiagnosticsMixin
|
||||
from guide.app.browser.mixins.interaction import InteractionMixin
|
||||
from guide.app.browser.mixins.wait import WaitMixin
|
||||
|
||||
__all__ = ["WaitMixin", "DiagnosticsMixin", "InteractionMixin"]
|
||||
33
src/guide/app/browser/mixins/diagnostics.py
Normal file
33
src/guide/app/browser/mixins/diagnostics.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Diagnostics helpers mixed into PageHelpers."""
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guide.app.browser.types import PageLike
|
||||
|
||||
|
||||
class _DiagnosticsProtocol(Protocol):
|
||||
"""Protocol for classes that mix in DiagnosticsMixin."""
|
||||
|
||||
@property
|
||||
def page(self) -> "PageLike": ...
|
||||
|
||||
|
||||
class DiagnosticsMixin:
|
||||
"""Capture diagnostics (HTML, screenshot, console logs).
|
||||
|
||||
This mixin expects `page` to be set by the class that uses it.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Verify subclass provides required attributes."""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
async def capture_diagnostics(self: _DiagnosticsProtocol):
|
||||
"""Capture all diagnostic information (screenshot, HTML, logs)."""
|
||||
from guide.app.browser.diagnostics import capture_all_diagnostics
|
||||
|
||||
return await capture_all_diagnostics(self.page)
|
||||
|
||||
|
||||
__all__ = ["DiagnosticsMixin"]
|
||||
100
src/guide/app/browser/mixins/interaction.py
Normal file
100
src/guide/app/browser/mixins/interaction.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""High-level interaction helpers mixed into PageHelpers."""
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from guide.app.browser.elements import dropdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guide.app.browser.types import PageLike
|
||||
|
||||
|
||||
class _InteractionMixinProtocol(Protocol):
|
||||
"""Protocol for classes that mix in InteractionMixin."""
|
||||
|
||||
@property
|
||||
def page(self) -> "PageLike": ...
|
||||
|
||||
async def wait_for_network_idle(self, timeout_ms: int | None = None) -> None:
|
||||
"""Wait for network idle state."""
|
||||
...
|
||||
|
||||
async def wait_for_stable(
|
||||
self,
|
||||
stability_check_ms: int | None = None,
|
||||
samples: int = 3,
|
||||
) -> None:
|
||||
"""Wait for page stability."""
|
||||
...
|
||||
|
||||
|
||||
class InteractionMixin:
|
||||
"""Shared interaction patterns built on top of PageLike primitives."""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Verify subclass provides required attributes."""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
async def wait_for_network_idle(self, timeout_ms: int | None = None) -> None:
|
||||
"""Abstract - implemented by WaitMixin."""
|
||||
raise NotImplementedError("Requires WaitMixin")
|
||||
|
||||
async def wait_for_stable(
|
||||
self,
|
||||
stability_check_ms: int | None = None,
|
||||
samples: int = 3,
|
||||
) -> None:
|
||||
"""Abstract - implemented by WaitMixin."""
|
||||
raise NotImplementedError("Requires WaitMixin")
|
||||
|
||||
async def fill_and_advance(
|
||||
self: _InteractionMixinProtocol,
|
||||
selector: str,
|
||||
value: str,
|
||||
next_selector: str,
|
||||
wait_for_idle: bool = True,
|
||||
) -> None:
|
||||
await self.page.fill(selector, value)
|
||||
await self.page.click(next_selector)
|
||||
if wait_for_idle:
|
||||
await self.wait_for_network_idle()
|
||||
|
||||
async def search_and_select(
|
||||
self: _InteractionMixinProtocol,
|
||||
search_input: str,
|
||||
query: str,
|
||||
result_selector: str,
|
||||
index: int = 0,
|
||||
) -> None:
|
||||
await self.page.fill(search_input, query)
|
||||
await self.wait_for_network_idle()
|
||||
results = self.page.locator(result_selector)
|
||||
await results.nth(index).click()
|
||||
|
||||
async def click_and_wait(
|
||||
self: _InteractionMixinProtocol,
|
||||
selector: str,
|
||||
wait_for_idle: bool = True,
|
||||
wait_for_stable: bool = False,
|
||||
) -> None:
|
||||
await self.page.click(selector)
|
||||
if wait_for_idle:
|
||||
await self.wait_for_network_idle()
|
||||
if wait_for_stable:
|
||||
await self.wait_for_stable()
|
||||
|
||||
async def select_dropdown_options(
|
||||
self: _InteractionMixinProtocol,
|
||||
field_selector: str,
|
||||
target_values: list[str],
|
||||
close_after: bool = True,
|
||||
) -> dict[str, list[str]]:
|
||||
_ = close_after # Reserved for future implementation
|
||||
result = await dropdown.select_multi(self.page, field_selector, target_values)
|
||||
return {
|
||||
"selected": result["selected"],
|
||||
"not_found": result["not_found"],
|
||||
"available_options": result.get("available", []),
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["InteractionMixin"]
|
||||
75
src/guide/app/browser/mixins/wait.py
Normal file
75
src/guide/app/browser/mixins/wait.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Wait-related helpers mixed into PageHelpers."""
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from guide.app.browser.wait import (
|
||||
wait_for_network_idle,
|
||||
wait_for_navigation,
|
||||
wait_for_selector,
|
||||
wait_for_stable_page,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import Timeouts
|
||||
|
||||
|
||||
class _WaitMixinProtocol(Protocol):
|
||||
"""Protocol for classes that mix in WaitMixin."""
|
||||
|
||||
@property
|
||||
def page(self) -> "PageLike": ...
|
||||
|
||||
@property
|
||||
def timeouts(self) -> "Timeouts": ...
|
||||
|
||||
|
||||
class WaitMixin:
|
||||
"""Provide wait utilities that reuse shared timeout configuration."""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Verify subclass provides required attributes."""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
async def wait_for_selector(
|
||||
self: _WaitMixinProtocol, selector: str, timeout_ms: int | None = None
|
||||
) -> None:
|
||||
await wait_for_selector(
|
||||
self.page,
|
||||
selector,
|
||||
timeout_ms=timeout_ms,
|
||||
timeouts=self.timeouts,
|
||||
)
|
||||
|
||||
async def wait_for_network_idle(
|
||||
self: _WaitMixinProtocol, timeout_ms: int | None = None
|
||||
) -> None:
|
||||
await wait_for_network_idle(
|
||||
self.page,
|
||||
timeout_ms=timeout_ms,
|
||||
timeouts=self.timeouts,
|
||||
)
|
||||
|
||||
async def wait_for_navigation(
|
||||
self: _WaitMixinProtocol, timeout_ms: int | None = None
|
||||
) -> None:
|
||||
await wait_for_navigation(
|
||||
self.page,
|
||||
timeout_ms=timeout_ms,
|
||||
timeouts=self.timeouts,
|
||||
)
|
||||
|
||||
async def wait_for_stable(
|
||||
self: _WaitMixinProtocol,
|
||||
stability_check_ms: int | None = None,
|
||||
samples: int = 3,
|
||||
) -> None:
|
||||
await wait_for_stable_page(
|
||||
self.page,
|
||||
stability_check_ms=stability_check_ms,
|
||||
samples=samples,
|
||||
timeouts=self.timeouts,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["WaitMixin"]
|
||||
@@ -10,6 +10,7 @@ Architecture:
|
||||
- No page/context pooling: Each action gets a clean slate
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@@ -22,13 +23,19 @@ from playwright.async_api import (
|
||||
Playwright,
|
||||
async_playwright,
|
||||
)
|
||||
from playwright._impl._errors import TargetClosedError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from playwright.async_api import StorageState
|
||||
else:
|
||||
StorageState = dict[str, object] # Runtime fallback
|
||||
|
||||
from guide.app.core.config import AppSettings, BrowserHostConfig, HostKind
|
||||
from guide.app.core.config import (
|
||||
AppSettings,
|
||||
BrowserHostConfig,
|
||||
DEFAULT_EXTENSION_PORT,
|
||||
HostKind,
|
||||
)
|
||||
from guide.app import errors
|
||||
from guide.app.browser.extension_client import ExtensionClient, ExtensionPage
|
||||
|
||||
@@ -124,9 +131,32 @@ class BrowserInstance:
|
||||
|
||||
async def _allocate_cdp_page(self) -> tuple[BrowserContext, PageLike, bool]:
|
||||
_logger.info(f"[CDP-{self.host_id}] allocate_context_and_page called")
|
||||
|
||||
# Check if cached page is still valid
|
||||
if self._cdp_page is not None and self._cdp_page.is_closed():
|
||||
_logger.warning(
|
||||
f"[CDP-{self.host_id}] Cached page is closed, clearing cache"
|
||||
)
|
||||
self._cdp_context = None
|
||||
self._cdp_page = None
|
||||
|
||||
# Check if browser connection is still alive
|
||||
if self.browser is not None and not self.browser.is_connected():
|
||||
_logger.warning(
|
||||
f"[CDP-{self.host_id}] Browser disconnected, clearing cache"
|
||||
)
|
||||
self._cdp_context = None
|
||||
self._cdp_page = None
|
||||
raise errors.BrowserConnectionError(
|
||||
f"CDP browser disconnected for host {self.host_id}",
|
||||
details={"host_id": self.host_id},
|
||||
)
|
||||
|
||||
if self._cdp_context is None or self._cdp_page is None:
|
||||
browser = self._require_browser()
|
||||
_logger.info(f"[CDP-{self.host_id}] First access - querying browser.contexts")
|
||||
_logger.info(
|
||||
f"[CDP-{self.host_id}] First access - querying browser.contexts"
|
||||
)
|
||||
contexts = browser.contexts
|
||||
_logger.info(f"[CDP-{self.host_id}] Got {len(contexts)} contexts")
|
||||
if not contexts:
|
||||
@@ -137,14 +167,18 @@ class BrowserInstance:
|
||||
|
||||
context = contexts[0]
|
||||
pages = context.pages
|
||||
_logger.info(f"[CDP-{self.host_id}] Got {len(pages)} pages: {[p.url for p in pages]}")
|
||||
_logger.info(
|
||||
f"[CDP-{self.host_id}] Got {len(pages)} pages: {[p.url for p in pages]}"
|
||||
)
|
||||
if not pages:
|
||||
raise errors.BrowserConnectionError(
|
||||
f"No pages available in CDP browser context for host {self.host_id}",
|
||||
details={"host_id": self.host_id},
|
||||
)
|
||||
|
||||
non_devtools_pages = [p for p in pages if not p.url.startswith("devtools://")]
|
||||
non_devtools_pages = [
|
||||
p for p in pages if not p.url.startswith("devtools://")
|
||||
]
|
||||
if not non_devtools_pages:
|
||||
raise errors.BrowserConnectionError(
|
||||
"No application pages found in CDP browser (only devtools pages)",
|
||||
@@ -155,9 +189,14 @@ class BrowserInstance:
|
||||
self._cdp_page = non_devtools_pages[-1]
|
||||
_logger.info(f"[CDP-{self.host_id}] Cached page: {self._cdp_page.url}")
|
||||
else:
|
||||
_logger.info(f"[CDP-{self.host_id}] Using cached page: {self._cdp_page.url}")
|
||||
_logger.info(
|
||||
f"[CDP-{self.host_id}] Using cached page: {self._cdp_page.url}"
|
||||
)
|
||||
|
||||
# Assert non-None for type narrowing (values were just assigned above)
|
||||
assert self._cdp_context is not None, "CDP context should be initialized"
|
||||
assert self._cdp_page is not None, "CDP page should be initialized"
|
||||
|
||||
assert self._cdp_context is not None and self._cdp_page is not None
|
||||
return self._cdp_context, self._cdp_page, False
|
||||
|
||||
async def _allocate_headless_page(
|
||||
@@ -206,6 +245,8 @@ class BrowserPool:
|
||||
self._instances: dict[str, BrowserInstance] = {}
|
||||
self._playwright: Playwright | None = None
|
||||
self._closed: bool = False
|
||||
# Per-host locks to prevent race conditions during concurrent allocation
|
||||
self._locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the browser pool.
|
||||
@@ -215,6 +256,13 @@ class BrowserPool:
|
||||
"""
|
||||
if self._playwright is not None:
|
||||
return
|
||||
|
||||
# Warn if no browser hosts configured
|
||||
if not self.settings.browser_hosts:
|
||||
_logger.warning(
|
||||
"No browser hosts configured. Actions requiring browser access will fail."
|
||||
)
|
||||
|
||||
self._playwright = await async_playwright().start()
|
||||
|
||||
# Eagerly connect to all CDP hosts to avoid refresh on first use
|
||||
@@ -225,7 +273,9 @@ class BrowserPool:
|
||||
self._instances[host_id] = instance
|
||||
# Eagerly cache the page reference to avoid querying on first request
|
||||
_ = await instance.allocate_context_and_page()
|
||||
_logger.info(f"Eagerly connected to CDP host '{host_id}' and cached page")
|
||||
_logger.info(
|
||||
f"Eagerly connected to CDP host '{host_id}' and cached page"
|
||||
)
|
||||
except Exception as exc:
|
||||
_logger.warning(
|
||||
f"Failed to eagerly connect to CDP host '{host_id}': {exc}"
|
||||
@@ -258,6 +308,7 @@ class BrowserPool:
|
||||
"""Allocate a fresh context and page for the specified host.
|
||||
|
||||
Lazily creates browser connections on first request per host.
|
||||
Automatically reconnects if the connection has gone stale.
|
||||
|
||||
Args:
|
||||
host_id: The host identifier, or None for the default host
|
||||
@@ -288,14 +339,39 @@ class BrowserPool:
|
||||
"Browser pool not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
# Get or create the browser instance for this host
|
||||
if resolved_id not in self._instances:
|
||||
instance = await self._create_instance(resolved_id, host_config)
|
||||
self._instances[resolved_id] = instance
|
||||
# Get or create per-host lock to prevent race conditions
|
||||
if resolved_id not in self._locks:
|
||||
self._locks[resolved_id] = asyncio.Lock()
|
||||
|
||||
return await self._instances[resolved_id].allocate_context_and_page(
|
||||
storage_state=storage_state
|
||||
)
|
||||
async with self._locks[resolved_id]:
|
||||
# Get or create the browser instance for this host
|
||||
if resolved_id not in self._instances:
|
||||
instance = await self._create_instance(resolved_id, host_config)
|
||||
self._instances[resolved_id] = instance
|
||||
|
||||
try:
|
||||
return await self._instances[resolved_id].allocate_context_and_page(
|
||||
storage_state=storage_state
|
||||
)
|
||||
except (TargetClosedError, errors.BrowserConnectionError) as exc:
|
||||
# Connection is stale - evict and reconnect once
|
||||
_logger.warning(
|
||||
f"Stale connection detected for host '{resolved_id}', reconnecting: {exc}"
|
||||
)
|
||||
await self._evict_instance(resolved_id)
|
||||
instance = await self._create_instance(resolved_id, host_config)
|
||||
self._instances[resolved_id] = instance
|
||||
return await instance.allocate_context_and_page(
|
||||
storage_state=storage_state
|
||||
)
|
||||
|
||||
async def _evict_instance(self, host_id: str) -> None:
|
||||
"""Evict and close a stale browser instance."""
|
||||
if host_id in self._instances:
|
||||
instance = self._instances.pop(host_id)
|
||||
with contextlib.suppress(Exception):
|
||||
await instance.close()
|
||||
_logger.info(f"Evicted stale browser instance for host '{host_id}'")
|
||||
|
||||
async def _create_instance(
|
||||
self, host_id: str, host_config: BrowserHostConfig
|
||||
@@ -303,7 +379,7 @@ class BrowserPool:
|
||||
"""Create a new browser instance for the given host."""
|
||||
if host_config.kind == HostKind.EXTENSION:
|
||||
# Create and start extension client
|
||||
port = host_config.port or 17373
|
||||
port = host_config.port or DEFAULT_EXTENSION_PORT
|
||||
extension_client = ExtensionClient(port=port)
|
||||
try:
|
||||
await extension_client.start()
|
||||
@@ -319,7 +395,9 @@ class BrowserPool:
|
||||
**exc.details,
|
||||
},
|
||||
) from exc
|
||||
except Exception as exc: # pragma: no cover - safety net for unexpected failures
|
||||
except (
|
||||
Exception
|
||||
) as exc: # pragma: no cover - safety net for unexpected failures
|
||||
await extension_client.close()
|
||||
raise errors.BrowserConnectionError(
|
||||
f"Cannot start extension host '{host_id}'",
|
||||
@@ -334,7 +412,10 @@ class BrowserPool:
|
||||
)
|
||||
return instance
|
||||
|
||||
assert self._playwright is not None
|
||||
if self._playwright is None:
|
||||
raise errors.ConfigError(
|
||||
"Browser pool not initialized. Call initialize() first."
|
||||
)
|
||||
if host_config.kind == HostKind.CDP:
|
||||
browser = await self._connect_cdp(host_config)
|
||||
else:
|
||||
@@ -347,26 +428,46 @@ class BrowserPool:
|
||||
return instance
|
||||
|
||||
async def _connect_cdp(self, host_config: BrowserHostConfig) -> Browser:
|
||||
"""Connect to a CDP host."""
|
||||
assert self._playwright is not None
|
||||
"""Connect to a CDP host.
|
||||
|
||||
if not host_config.host or host_config.port is None:
|
||||
raise errors.ConfigError("CDP host requires 'host' and 'port' fields.")
|
||||
Supports either a host/port pair (Chrome-style /json/version) or an
|
||||
explicit websocket/CDP URL for gateways like Browserless that return a
|
||||
non-routable `webSocketDebuggerUrl`.
|
||||
"""
|
||||
|
||||
if self._playwright is None:
|
||||
raise errors.ConfigError(
|
||||
"Browser pool not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
target_url = host_config.cdp_url
|
||||
details: dict[str, object] = {}
|
||||
|
||||
if not target_url:
|
||||
if not host_config.host or host_config.port is None:
|
||||
raise errors.ConfigError(
|
||||
"CDP host requires 'host' and 'port' fields when cdp_url is not provided."
|
||||
)
|
||||
target_url = f"http://{host_config.host}:{host_config.port}"
|
||||
details.update({"host": host_config.host, "port": host_config.port})
|
||||
else:
|
||||
details["cdp_url"] = target_url
|
||||
|
||||
cdp_url = f"http://{host_config.host}:{host_config.port}"
|
||||
try:
|
||||
browser = await self._playwright.chromium.connect_over_cdp(cdp_url)
|
||||
_logger.info(f"Connected to CDP endpoint: {cdp_url}")
|
||||
browser = await self._playwright.chromium.connect_over_cdp(target_url)
|
||||
_logger.info(f"Connected to CDP endpoint: {target_url}")
|
||||
return browser
|
||||
except Exception as exc:
|
||||
raise errors.BrowserConnectionError(
|
||||
f"Cannot connect to CDP endpoint {cdp_url}",
|
||||
details={"host": host_config.host, "port": host_config.port},
|
||||
f"Cannot connect to CDP endpoint {target_url}", details=details
|
||||
) from exc
|
||||
|
||||
async def _launch_headless(self, host_config: BrowserHostConfig) -> Browser:
|
||||
"""Launch a headless browser."""
|
||||
assert self._playwright is not None
|
||||
if self._playwright is None:
|
||||
raise errors.ConfigError(
|
||||
"Browser pool not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
browser_type = self._resolve_browser_type(host_config.browser)
|
||||
try:
|
||||
@@ -383,7 +484,10 @@ class BrowserPool:
|
||||
|
||||
def _resolve_browser_type(self, browser: str | None):
|
||||
"""Resolve browser type from configuration."""
|
||||
assert self._playwright is not None
|
||||
if self._playwright is None:
|
||||
raise errors.ConfigError(
|
||||
"Browser pool not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
desired = (browser or "chromium").lower()
|
||||
if desired == "chromium":
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Type definitions for browser automation interfaces."""
|
||||
|
||||
from typing import Literal, Protocol
|
||||
from typing import Literal, Protocol, TypeGuard, runtime_checkable
|
||||
|
||||
|
||||
class PageLocator(Protocol):
|
||||
@@ -27,11 +27,25 @@ class PageLocator(Protocol):
|
||||
"""Get the text content of the element."""
|
||||
...
|
||||
|
||||
async def click(self) -> None:
|
||||
async def click(self, *, timeout: int | float | None = None) -> None:
|
||||
"""Click the element."""
|
||||
...
|
||||
|
||||
async def fill(self, value: str, *, timeout: int | float | None = None) -> None:
|
||||
"""Fill the element with a value."""
|
||||
...
|
||||
|
||||
async def wait_for(
|
||||
self,
|
||||
*,
|
||||
state: Literal["attached", "detached", "hidden", "visible"] | None = None,
|
||||
timeout: int | float | None = None,
|
||||
) -> None:
|
||||
"""Wait for the element to match the given state."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PageLike(Protocol):
|
||||
"""Protocol for page-like objects that support common browser automation operations.
|
||||
|
||||
@@ -57,7 +71,8 @@ class PageLike(Protocol):
|
||||
url: str,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
wait_until: Literal["commit", "domcontentloaded", "load", "networkidle"] | None = None,
|
||||
wait_until: Literal["commit", "domcontentloaded", "load", "networkidle"]
|
||||
| None = None,
|
||||
referer: str | None = None,
|
||||
) -> object | None:
|
||||
"""Navigate to a URL."""
|
||||
@@ -104,5 +119,18 @@ class PageLike(Protocol):
|
||||
...
|
||||
|
||||
|
||||
__all__ = ["PageLike", "PageLocator"]
|
||||
@runtime_checkable
|
||||
class PageWithTrustedClick(PageLike, Protocol):
|
||||
"""Page-like object that supports trusted_click (extension-only)."""
|
||||
|
||||
async def trusted_click(self, x: float, y: float) -> None:
|
||||
"""Perform a trusted click at specific coordinates."""
|
||||
...
|
||||
|
||||
|
||||
def supports_trusted_click(page: PageLike) -> TypeGuard["PageWithTrustedClick"]:
|
||||
"""Type guard for objects that implement trusted_click."""
|
||||
return hasattr(page, "trusted_click")
|
||||
|
||||
|
||||
__all__ = ["PageLike", "PageLocator", "PageWithTrustedClick", "supports_trusted_click"]
|
||||
|
||||
34
src/guide/app/browser/utils.py
Normal file
34
src/guide/app/browser/utils.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Browser utility functions.
|
||||
|
||||
Shared utilities used across browser automation modules.
|
||||
These are placed in a neutral location to avoid circular imports.
|
||||
"""
|
||||
|
||||
|
||||
def escape_selector(selector: str) -> str:
|
||||
"""Escape a CSS selector for safe use in JavaScript code.
|
||||
|
||||
Handles backslashes, single quotes, and double quotes.
|
||||
|
||||
Args:
|
||||
selector: Raw CSS selector string
|
||||
|
||||
Returns:
|
||||
Escaped selector safe for JS string interpolation
|
||||
"""
|
||||
return selector.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
|
||||
|
||||
def escape_js_string(value: str) -> str:
|
||||
"""Escape a string for safe use in JavaScript single-quoted strings.
|
||||
|
||||
Args:
|
||||
value: Raw string value
|
||||
|
||||
Returns:
|
||||
Escaped string safe for JS interpolation
|
||||
"""
|
||||
return value.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
||||
|
||||
|
||||
__all__ = ["escape_selector", "escape_js_string"]
|
||||
@@ -10,75 +10,98 @@ from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import Timeouts
|
||||
from guide.app.utils.retry import retry_async
|
||||
|
||||
|
||||
_DEFAULT_TIMEOUTS = Timeouts()
|
||||
|
||||
|
||||
async def wait_for_selector(
|
||||
page: PageLike,
|
||||
selector: str,
|
||||
timeout_ms: int = 5000,
|
||||
timeout_ms: int | None = None,
|
||||
*,
|
||||
timeouts: Timeouts | None = None,
|
||||
) -> None:
|
||||
"""Wait for an element matching selector to be present in DOM.
|
||||
|
||||
Args:
|
||||
page: The Playwright page instance
|
||||
selector: CSS or Playwright selector string
|
||||
timeout_ms: Maximum time to wait in milliseconds (default: 5000)
|
||||
timeout_ms: Maximum time to wait in milliseconds
|
||||
timeouts: Optional Timeouts configuration to use for defaults
|
||||
|
||||
Raises:
|
||||
GuideError: If selector not found within timeout
|
||||
"""
|
||||
effective_timeouts = timeouts or _DEFAULT_TIMEOUTS
|
||||
max_wait = (
|
||||
timeout_ms if timeout_ms is not None else effective_timeouts.element_default
|
||||
)
|
||||
try:
|
||||
_ = await page.wait_for_selector(selector, timeout=timeout_ms)
|
||||
_ = await page.wait_for_selector(selector, timeout=max_wait)
|
||||
except PlaywrightTimeoutError as exc:
|
||||
msg = f"Selector '{selector}' not found within {timeout_ms}ms"
|
||||
msg = f"Selector '{selector}' not found within {max_wait}ms"
|
||||
raise errors.GuideError(msg) from exc
|
||||
|
||||
|
||||
async def wait_for_navigation(
|
||||
page: PageLike,
|
||||
timeout_ms: int = 5000,
|
||||
timeout_ms: int | None = None,
|
||||
*,
|
||||
timeouts: Timeouts | None = None,
|
||||
) -> None:
|
||||
"""Wait for page navigation to complete.
|
||||
|
||||
Args:
|
||||
page: The Playwright page instance
|
||||
timeout_ms: Maximum time to wait in milliseconds (default: 5000)
|
||||
timeout_ms: Maximum time to wait in milliseconds
|
||||
timeouts: Optional Timeouts configuration to use for defaults
|
||||
|
||||
Raises:
|
||||
GuideError: If navigation does not complete within timeout
|
||||
"""
|
||||
effective_timeouts = timeouts or _DEFAULT_TIMEOUTS
|
||||
max_wait = timeout_ms if timeout_ms is not None else effective_timeouts.network_idle
|
||||
try:
|
||||
await page.wait_for_load_state("networkidle", timeout=timeout_ms)
|
||||
await page.wait_for_load_state("networkidle", timeout=max_wait)
|
||||
except PlaywrightTimeoutError as exc:
|
||||
msg = f"Page navigation did not complete within {timeout_ms}ms"
|
||||
msg = f"Page navigation did not complete within {max_wait}ms"
|
||||
raise errors.GuideError(msg) from exc
|
||||
|
||||
|
||||
async def wait_for_network_idle(
|
||||
page: PageLike,
|
||||
timeout_ms: int = 5000,
|
||||
timeout_ms: int | None = None,
|
||||
*,
|
||||
timeouts: Timeouts | None = None,
|
||||
) -> None:
|
||||
"""Wait for network to become idle (no active requests).
|
||||
|
||||
Args:
|
||||
page: The Playwright page instance
|
||||
timeout_ms: Maximum time to wait in milliseconds (default: 5000)
|
||||
timeout_ms: Maximum time to wait in milliseconds
|
||||
timeouts: Optional Timeouts configuration to use for defaults
|
||||
|
||||
Raises:
|
||||
GuideError: If network does not idle within timeout
|
||||
"""
|
||||
effective_timeouts = timeouts or _DEFAULT_TIMEOUTS
|
||||
max_wait = timeout_ms if timeout_ms is not None else effective_timeouts.network_idle
|
||||
try:
|
||||
await page.wait_for_load_state("networkidle", timeout=timeout_ms)
|
||||
await page.wait_for_load_state("networkidle", timeout=max_wait)
|
||||
except PlaywrightTimeoutError as exc:
|
||||
msg = f"Network did not idle within {timeout_ms}ms"
|
||||
msg = f"Network did not idle within {max_wait}ms"
|
||||
raise errors.GuideError(msg) from exc
|
||||
|
||||
|
||||
async def is_page_stable(
|
||||
page: PageLike,
|
||||
stability_check_ms: int = 500,
|
||||
stability_check_ms: int | None = None,
|
||||
samples: int = 3,
|
||||
*,
|
||||
timeouts: Timeouts | None = None,
|
||||
) -> bool:
|
||||
"""Check if page is visually stable (DOM not changing).
|
||||
|
||||
@@ -87,12 +110,19 @@ async def is_page_stable(
|
||||
|
||||
Args:
|
||||
page: The Playwright page instance
|
||||
stability_check_ms: Delay between samples in milliseconds (default: 500)
|
||||
stability_check_ms: Delay between samples in milliseconds
|
||||
samples: Number of stable samples required (default: 3)
|
||||
timeouts: Optional Timeouts configuration to use for defaults
|
||||
|
||||
Returns:
|
||||
True if page is stable, False otherwise
|
||||
"""
|
||||
effective_timeouts = timeouts or _DEFAULT_TIMEOUTS
|
||||
interval_ms = (
|
||||
stability_check_ms
|
||||
if stability_check_ms is not None
|
||||
else effective_timeouts.stability_check
|
||||
)
|
||||
try:
|
||||
previous_content: str | None = None
|
||||
|
||||
@@ -103,19 +133,24 @@ async def is_page_stable(
|
||||
return False
|
||||
|
||||
previous_content = current_content
|
||||
await asyncio.sleep(stability_check_ms / 1000)
|
||||
await asyncio.sleep(interval_ms / 1000)
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
# If we can't check stability, assume page is stable
|
||||
return True
|
||||
except Exception as exc:
|
||||
# If we can't check stability, log warning and return False to be safe
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).warning(f"Page stability check failed: {exc}")
|
||||
return False
|
||||
|
||||
|
||||
@retry_async(retries=3, delay_seconds=0.2)
|
||||
async def wait_for_stable_page(
|
||||
page: PageLike,
|
||||
stability_check_ms: int = 500,
|
||||
stability_check_ms: int | None = None,
|
||||
samples: int = 3,
|
||||
*,
|
||||
timeouts: Timeouts | None = None,
|
||||
) -> None:
|
||||
"""Wait for page to become visually stable, with retries.
|
||||
|
||||
@@ -124,13 +159,19 @@ async def wait_for_stable_page(
|
||||
|
||||
Args:
|
||||
page: The Playwright page instance
|
||||
stability_check_ms: Delay between samples in milliseconds (default: 500)
|
||||
stability_check_ms: Delay between samples in milliseconds
|
||||
samples: Number of stable samples required (default: 3)
|
||||
timeouts: Optional Timeouts configuration to use for defaults
|
||||
|
||||
Raises:
|
||||
GuideError: If page does not stabilize after retries
|
||||
"""
|
||||
stable = await is_page_stable(page, stability_check_ms, samples)
|
||||
stable = await is_page_stable(
|
||||
page,
|
||||
stability_check_ms,
|
||||
samples,
|
||||
timeouts=timeouts,
|
||||
)
|
||||
if not stable:
|
||||
msg = "Page did not stabilize after retries"
|
||||
raise errors.GuideError(msg)
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from collections.abc import Mapping
|
||||
from typing import ClassVar, Protocol, TypeAlias, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import AliasChoices, BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from guide.app.models.boards.models import BoardConfig
|
||||
@@ -45,16 +45,76 @@ class HostKind(str, Enum):
|
||||
EXTENSION = "extension"
|
||||
|
||||
|
||||
# Default port for extension WebSocket server
|
||||
DEFAULT_EXTENSION_PORT = 17373
|
||||
|
||||
|
||||
class BrowserHostConfig(BaseModel):
|
||||
"""Configuration for a browser host (CDP or headless)."""
|
||||
|
||||
id: str
|
||||
kind: HostKind
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
port: int | None = Field(default=None, ge=1, le=65535)
|
||||
cdp_url: str | None = None # explicit CDP endpoint override
|
||||
browser: str | None = None # chromium/firefox/webkit for headless
|
||||
|
||||
|
||||
class Timeouts(BaseModel):
|
||||
"""Centralized timeout configuration for browser interactions.
|
||||
|
||||
All timeout values are in milliseconds unless otherwise noted.
|
||||
"""
|
||||
|
||||
# Browser element operations (milliseconds)
|
||||
element_default: int = 5000
|
||||
"""Default timeout for element wait operations."""
|
||||
network_idle: int = 5000
|
||||
"""Timeout for network idle detection."""
|
||||
animation: int = 300
|
||||
"""Animation transition duration."""
|
||||
stability_check: int = 500
|
||||
"""Interval for page stability checks."""
|
||||
|
||||
# UI component timeouts (milliseconds)
|
||||
listbox_wait: int = 600
|
||||
"""Timeout for dropdown listbox appearance."""
|
||||
dropdown_field: int = 1000
|
||||
"""Timeout for dropdown field interactions."""
|
||||
combobox_listbox: int = 2500
|
||||
"""Extended timeout for combobox listbox population."""
|
||||
|
||||
# Extension client timeouts (seconds for consistency with asyncio)
|
||||
extension_connection_s: float = 10.0
|
||||
"""Timeout waiting for extension to connect."""
|
||||
extension_command_s: float = 30.0
|
||||
"""Timeout for extension command/eval execution."""
|
||||
extension_trusted_click_s: float = 10.0
|
||||
"""Timeout for trusted click operations."""
|
||||
websocket_receive_s: float = 60.0
|
||||
"""Timeout for WebSocket message receive (triggers ping)."""
|
||||
|
||||
# GraphQL/HTTP timeouts (seconds)
|
||||
graphql_base_s: float = 10.0
|
||||
"""Base HTTP client timeout for GraphQL."""
|
||||
graphql_operation_s: float = 30.0
|
||||
"""Extended timeout for GraphQL operations."""
|
||||
|
||||
# Debounce delays (milliseconds)
|
||||
debounce_network: int = 100
|
||||
"""Delay for network request stabilization."""
|
||||
debounce_typeahead: int = 150
|
||||
"""Delay for typeahead input debouncing."""
|
||||
scroll_wait: int = 50
|
||||
"""Brief pause for smooth scrolling animations."""
|
||||
|
||||
# Auth flow timeouts (seconds)
|
||||
auth_stability_s: float = 8.0
|
||||
"""Page stability check after Auth0 redirects."""
|
||||
otp_callback_default_s: float = 120.0
|
||||
"""Default timeout for OTP callback wait."""
|
||||
|
||||
|
||||
class AppSettings(BaseSettings):
|
||||
"""Application settings loaded from YAML files + environment variables.
|
||||
|
||||
@@ -69,28 +129,62 @@ class AppSettings(BaseSettings):
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_prefix="RAINDROP_DEMO_",
|
||||
case_sensitive=False,
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
raindrop_base_url: str = "https://stg.raindrop.com"
|
||||
raindrop_graphql_url: str = "https://raindrop-staging.hasura.app/v1/graphql"
|
||||
default_browser_host_id: str = "demo-cdp"
|
||||
# Raindrop URLs (use RAINDROP_STAGING_* env vars)
|
||||
raindrop_base_url: str = Field(
|
||||
default="https://stg.raindrop.com",
|
||||
validation_alias=AliasChoices("RAINDROP_STAGING_BASE_URL", "raindrop_base_url"),
|
||||
)
|
||||
"""Base URL for Raindrop app (login page, etc.)"""
|
||||
raindrop_graphql_url: str = Field(
|
||||
default="https://raindrop-staging.hasura.app/v1/graphql",
|
||||
validation_alias=AliasChoices(
|
||||
"RAINDROP_STAGING_GRAPHQL_URL", "raindrop_graphql_url"
|
||||
),
|
||||
)
|
||||
"""GraphQL API endpoint."""
|
||||
|
||||
# Browser configuration
|
||||
default_browser_host_id: str = "browserless-cdp"
|
||||
browser_hosts: dict[str, BrowserHostConfig] = Field(default_factory=dict)
|
||||
personas: dict[str, DemoPersona] = Field(default_factory=dict)
|
||||
boards: dict[str, BoardConfig] = Field(default_factory=dict)
|
||||
docling_base_url: str = "http://192.168.50.185:50011"
|
||||
"""Base URL for Docling document conversion API"""
|
||||
docling_api_key: str | None = None
|
||||
"""Optional API key for Docling authentication (X-API-Key header)"""
|
||||
docling_enabled: bool = True
|
||||
"""Enable/disable Docling UI element extraction on action errors"""
|
||||
timeouts: Timeouts = Field(default_factory=Timeouts)
|
||||
|
||||
# Docling (use RAINDROP_DEMO_DOCLING_* env vars)
|
||||
docling_base_url: str = Field(
|
||||
default="http://localhost:50011",
|
||||
validation_alias=AliasChoices(
|
||||
"RAINDROP_DEMO_DOCLING_BASE_URL", "docling_base_url"
|
||||
),
|
||||
)
|
||||
"""Base URL for Docling document conversion API."""
|
||||
docling_api_key: str | None = Field(default=None)
|
||||
"""Optional API key for Docling authentication (X-API-Key header)."""
|
||||
docling_enabled: bool = Field(default=False)
|
||||
"""Enable/disable Docling UI element extraction on action errors."""
|
||||
|
||||
# Session Management
|
||||
session_storage_dir: Path = Field(default=Path(".sessions"))
|
||||
"""Directory to store session files (relative to project root)"""
|
||||
"""Directory to store session files (relative to project root)."""
|
||||
session_ttl_minutes: int = Field(default=60)
|
||||
"""Session time-to-live in minutes before requiring re-authentication"""
|
||||
"""Session time-to-live in minutes before requiring re-authentication."""
|
||||
session_auto_persist: bool = Field(default=True)
|
||||
"""Automatically save sessions after successful login"""
|
||||
"""Automatically save sessions after successful login."""
|
||||
|
||||
# n8n Integration (use RAINDROP_DEMO_N8N_* env vars)
|
||||
n8n_webhook_url: str | None = Field(default=None)
|
||||
"""Webhook URL for n8n OTP request notifications."""
|
||||
n8n_otp_callback_timeout: int = Field(default=120)
|
||||
"""Timeout in seconds waiting for OTP callback from n8n."""
|
||||
n8n_otp_email_delay: float = Field(default=5.0)
|
||||
"""Delay in seconds after triggering OTP email before notifying n8n."""
|
||||
callback_base_url: str = Field(default="http://localhost:8765")
|
||||
"""Base URL for callback endpoints (used by n8n to call back)."""
|
||||
|
||||
|
||||
def _load_yaml_file(path: Path) -> dict[str, object]:
|
||||
@@ -275,7 +369,7 @@ def load_settings() -> AppSettings:
|
||||
# Load JSON overrides from environment
|
||||
if browser_hosts_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"):
|
||||
try:
|
||||
_settings_extraction_from_json(
|
||||
_load_json_override_records(
|
||||
browser_hosts_json,
|
||||
"RAINDROP_DEMO_BROWSER_HOSTS_JSON must be a JSON array",
|
||||
BrowserHostConfig,
|
||||
@@ -286,7 +380,7 @@ def load_settings() -> AppSettings:
|
||||
|
||||
if personas_json := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"):
|
||||
try:
|
||||
_settings_extraction_from_json(
|
||||
_load_json_override_records(
|
||||
personas_json,
|
||||
"RAINDROP_DEMO_PERSONAS_JSON must be a JSON array",
|
||||
DemoPersona,
|
||||
@@ -311,15 +405,14 @@ def load_settings() -> AppSettings:
|
||||
return settings
|
||||
|
||||
|
||||
# TODO Rename this here and in `load_settings`
|
||||
def _settings_extraction_from_json(
|
||||
def _load_json_override_records(
|
||||
json_str: str,
|
||||
error_message: str,
|
||||
model_class: type[_T],
|
||||
target_dict: dict[str, _T],
|
||||
) -> None:
|
||||
"""Extract and validate records from JSON string into target dictionary.
|
||||
|
||||
|
||||
Args:
|
||||
json_str: JSON string containing an array of records
|
||||
error_message: Error message to raise if JSON is not an array
|
||||
|
||||
@@ -16,6 +16,7 @@ from guide.app.api import router as api_router
|
||||
from guide.app import errors
|
||||
from guide.app.models.boards import BoardStore
|
||||
from guide.app.models.personas import PersonaStore
|
||||
from guide.app.raindrop import GraphQLClient
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -44,11 +45,15 @@ def create_app() -> FastAPI:
|
||||
browser_pool = BrowserPool(settings)
|
||||
browser_client = BrowserClient(browser_pool)
|
||||
|
||||
# Create GraphQL client with connection pooling
|
||||
graphql_client = GraphQLClient(settings)
|
||||
|
||||
app = FastAPI(title="Raindrop Demo Automation", version="0.1.0")
|
||||
app.state.settings = settings
|
||||
app.state.action_registry = registry
|
||||
app.state.browser_client = browser_client
|
||||
app.state.browser_pool = browser_pool
|
||||
app.state.graphql_client = graphql_client
|
||||
app.state.persona_store = persona_store
|
||||
app.state.board_store = board_store
|
||||
app.state.session_manager = session_manager
|
||||
@@ -59,10 +64,18 @@ def create_app() -> FastAPI:
|
||||
# Startup/shutdown lifecycle using modern lifespan context manager
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
"""Manage browser pool lifecycle."""
|
||||
await browser_pool.initialize()
|
||||
yield
|
||||
await browser_pool.close()
|
||||
"""Manage browser pool and GraphQL client lifecycle."""
|
||||
try:
|
||||
await browser_pool.initialize()
|
||||
except Exception:
|
||||
# Ensure GraphQL client is closed if pool init fails
|
||||
await graphql_client.aclose()
|
||||
raise
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await graphql_client.aclose()
|
||||
await browser_pool.close()
|
||||
|
||||
app.router.lifespan_context = lifespan
|
||||
app.include_router(api_router)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from guide.app.models.personas.models import DemoPersona, LoginMethod, PersonaRole
|
||||
from guide.app.models.personas.resolver import PersonaResolver
|
||||
from guide.app.models.personas.store import PersonaStore
|
||||
|
||||
__all__ = ["DemoPersona", "LoginMethod", "PersonaRole", "PersonaStore"]
|
||||
__all__ = ["DemoPersona", "LoginMethod", "PersonaRole", "PersonaResolver", "PersonaStore"]
|
||||
|
||||
82
src/guide/app/models/personas/resolver.py
Normal file
82
src/guide/app/models/personas/resolver.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Persona resolver for flexible email-based persona resolution.
|
||||
|
||||
Wraps PersonaStore to support both config-based and runtime ephemeral personas.
|
||||
"""
|
||||
|
||||
from guide.app.models.personas.models import DemoPersona, LoginMethod, PersonaRole
|
||||
from guide.app.models.personas.store import PersonaStore
|
||||
|
||||
|
||||
class PersonaResolver:
|
||||
"""Resolve personas from config or create runtime ephemeral personas.
|
||||
|
||||
Priority:
|
||||
1. Lookup by persona_id in PersonaStore
|
||||
2. Lookup by email in PersonaStore (scan all personas)
|
||||
3. Create ephemeral persona from email (no config required)
|
||||
"""
|
||||
|
||||
_store: PersonaStore
|
||||
|
||||
def __init__(self, persona_store: PersonaStore) -> None:
|
||||
"""Initialize resolver with persona store.
|
||||
|
||||
Args:
|
||||
persona_store: Store for looking up configured personas.
|
||||
"""
|
||||
self._store = persona_store
|
||||
|
||||
def resolve_by_id(self, persona_id: str) -> DemoPersona:
|
||||
"""Resolve persona by ID (delegates to PersonaStore).
|
||||
|
||||
Args:
|
||||
persona_id: Configured persona ID.
|
||||
|
||||
Returns:
|
||||
Matching DemoPersona.
|
||||
|
||||
Raises:
|
||||
PersonaError: If persona not found.
|
||||
"""
|
||||
return self._store.get(persona_id)
|
||||
|
||||
def resolve_by_email(self, email: str) -> DemoPersona:
|
||||
"""Resolve persona by email, falling back to ephemeral creation.
|
||||
|
||||
Searches existing personas by email (case-insensitive).
|
||||
If not found, creates a temporary ephemeral persona.
|
||||
|
||||
Args:
|
||||
email: User email address.
|
||||
|
||||
Returns:
|
||||
Existing or newly created DemoPersona.
|
||||
"""
|
||||
normalized_email = email.lower()
|
||||
for persona in self._store.list():
|
||||
if persona.email.lower() == normalized_email:
|
||||
return persona
|
||||
return self._create_ephemeral(email)
|
||||
|
||||
def _create_ephemeral(self, email: str) -> DemoPersona:
|
||||
"""Create ephemeral persona from email only.
|
||||
|
||||
Generates deterministic ID from email for session caching compatibility.
|
||||
|
||||
Args:
|
||||
email: User email address.
|
||||
|
||||
Returns:
|
||||
New ephemeral DemoPersona with default settings.
|
||||
"""
|
||||
email_slug = email.split("@")[0].lower().replace(".", "-")
|
||||
return DemoPersona(
|
||||
id=f"ephemeral-{email_slug}",
|
||||
email=email,
|
||||
role=PersonaRole.BUYER,
|
||||
login_method=LoginMethod.MFA_EMAIL,
|
||||
browser_host_id=None,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["PersonaResolver"]
|
||||
@@ -1,18 +1,50 @@
|
||||
import httpx
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.auth.session import extract_bearer_token
|
||||
from guide.app.browser.types import PageLike
|
||||
from guide.app.core.config import AppSettings
|
||||
from guide.app.models.personas.models import DemoPersona
|
||||
from guide.app.models.types import JSONValue
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphQLClient:
|
||||
"""GraphQL client with connection pooling.
|
||||
|
||||
Uses a persistent httpx.AsyncClient for efficient connection reuse.
|
||||
Call aclose() to clean up resources when done.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: AppSettings) -> None:
|
||||
self._settings: AppSettings = settings
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create the persistent HTTP client."""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
self._settings.timeouts.graphql_base_s,
|
||||
read=self._settings.timeouts.graphql_operation_s,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=5,
|
||||
max_connections=10,
|
||||
),
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the persistent HTTP client and release connections."""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
@@ -23,30 +55,53 @@ class GraphQLClient:
|
||||
operation_name: str | None = None,
|
||||
bearer_token: str | None = None,
|
||||
page: PageLike | None = None,
|
||||
) -> dict[str, object]:
|
||||
) -> dict[str, JSONValue]:
|
||||
"""Execute a GraphQL query with connection pooling.
|
||||
|
||||
Args:
|
||||
query: GraphQL query string.
|
||||
variables: Query variables.
|
||||
persona: Optional persona for auth context.
|
||||
operation_name: Optional GraphQL operation name.
|
||||
bearer_token: Optional explicit bearer token.
|
||||
page: Optional page to auto-extract token from localStorage.
|
||||
|
||||
Returns:
|
||||
Parsed GraphQL data node as dict.
|
||||
|
||||
Raises:
|
||||
GraphQLTransportError: On HTTP/network errors.
|
||||
GraphQLOperationError: On GraphQL-level errors.
|
||||
"""
|
||||
# Auto-discover token from page if not explicitly provided
|
||||
if bearer_token is None and page is not None:
|
||||
extracted = await extract_bearer_token(page)
|
||||
if extracted:
|
||||
bearer_token = extracted.value
|
||||
_logger.debug("Auto-extracted bearer token from page localStorage")
|
||||
else:
|
||||
_logger.warning(
|
||||
"No bearer token found in page localStorage - GraphQL request will be unauthenticated"
|
||||
)
|
||||
|
||||
url = self._settings.raindrop_graphql_url
|
||||
headers = self._build_headers(persona, bearer_token=bearer_token)
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
url,
|
||||
json={
|
||||
"query": query,
|
||||
"variables": variables or {},
|
||||
"operationName": operation_name,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
raise errors.GraphQLTransportError(
|
||||
f"Transport error calling GraphQL: {exc}"
|
||||
) from exc
|
||||
client = self._get_client()
|
||||
|
||||
try:
|
||||
resp = await client.post(
|
||||
url,
|
||||
json={
|
||||
"query": query,
|
||||
"variables": variables or {},
|
||||
"operationName": operation_name,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
raise errors.GraphQLTransportError(
|
||||
f"Transport error calling GraphQL: {exc}"
|
||||
) from exc
|
||||
|
||||
if resp.status_code >= 400:
|
||||
raise errors.GraphQLTransportError(
|
||||
@@ -54,14 +109,7 @@ class GraphQLClient:
|
||||
details={"status_code": resp.status_code, "body": resp.text},
|
||||
)
|
||||
|
||||
data = cast(dict[str, object], resp.json())
|
||||
if errors_list := data.get("errors"):
|
||||
details: dict[str, object] = {"errors": errors_list}
|
||||
raise errors.GraphQLOperationError(
|
||||
"GraphQL operation failed", details=details
|
||||
)
|
||||
payload = data.get("data", {})
|
||||
return cast(dict[str, object], payload) if isinstance(payload, dict) else {}
|
||||
return _extract_data_root(resp)
|
||||
|
||||
def _build_headers(
|
||||
self,
|
||||
@@ -75,3 +123,69 @@ class GraphQLClient:
|
||||
# Reserved for future persona-specific auth
|
||||
_ = persona
|
||||
return headers
|
||||
|
||||
|
||||
def _extract_data_root(response: httpx.Response) -> dict[str, JSONValue]:
|
||||
"""Extract and validate the top-level data node from a GraphQL response."""
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError as exc:
|
||||
raise errors.GraphQLTransportError(
|
||||
"Invalid JSON returned from GraphQL",
|
||||
details={"status_code": response.status_code},
|
||||
) from exc
|
||||
|
||||
if not isinstance(payload, Mapping):
|
||||
raise errors.GraphQLTransportError(
|
||||
"Unexpected GraphQL response shape",
|
||||
details={"received_type": type(payload).__name__},
|
||||
)
|
||||
|
||||
# Cast payload to help type checker understand structure
|
||||
typed_payload = cast(Mapping[str, JSONValue], payload)
|
||||
|
||||
# Check for errors node
|
||||
errors_node_raw: JSONValue | None = typed_payload.get("errors")
|
||||
if isinstance(errors_node_raw, list) and len(errors_node_raw) > 0:
|
||||
errors_node = cast(list[dict[str, JSONValue]], errors_node_raw)
|
||||
raise errors.GraphQLOperationError(
|
||||
"GraphQL operation failed", details={"errors": errors_node}
|
||||
)
|
||||
|
||||
# Extract and validate data node
|
||||
data_node_raw: JSONValue | None = typed_payload.get("data")
|
||||
if not isinstance(data_node_raw, Mapping):
|
||||
raise errors.GraphQLTransportError(
|
||||
"GraphQL response missing data",
|
||||
details={"received_type": type(data_node_raw).__name__ if data_node_raw is not None else "None"},
|
||||
)
|
||||
|
||||
# Convert Mapping to dict with proper typing
|
||||
data_node = cast(Mapping[str, JSONValue], data_node_raw)
|
||||
return {str(key): value for key, value in data_node.items()}
|
||||
|
||||
|
||||
def parse_graphql_response(
|
||||
response: httpx.Response,
|
||||
key: str,
|
||||
) -> dict[str, JSONValue] | list[JSONValue]:
|
||||
"""Parse a GraphQL response and return the requested data field."""
|
||||
data_node = _extract_data_root(response)
|
||||
if key not in data_node:
|
||||
raise errors.GraphQLOperationError(
|
||||
"GraphQL response missing expected key", details={"key": key}
|
||||
)
|
||||
|
||||
value = data_node[key]
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
|
||||
raise errors.GraphQLOperationError(
|
||||
"GraphQL data node must be an object or list",
|
||||
details={"key": key, "received_type": type(value).__name__},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["GraphQLClient", "parse_graphql_response"]
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
"""Board item (intake request) operations via GraphQL API."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
|
||||
from guide.app import errors
|
||||
from guide.app.models.boards import BoardConfig
|
||||
from guide.app.raindrop.graphql import parse_graphql_response
|
||||
from guide.app.raindrop.operations.archive import (
|
||||
ArchiveResult,
|
||||
EntityType,
|
||||
@@ -109,35 +112,15 @@ async def create_board_item(
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
response_data = cast(dict[str, object], resp.json())
|
||||
item_node = parse_graphql_response(resp, "insert_board_item_one")
|
||||
|
||||
# Check for errors
|
||||
if errors := response_data.get("errors"):
|
||||
error_list = cast(list[dict[str, object]], errors)
|
||||
first_error = error_list[0] if error_list else {}
|
||||
raise ValueError(f"GraphQL error: {first_error.get('message', 'Unknown error')}")
|
||||
if not isinstance(item_node, Mapping):
|
||||
raise errors.GraphQLOperationError(
|
||||
"Unexpected response format for created item",
|
||||
details={"received_type": type(item_node).__name__},
|
||||
)
|
||||
|
||||
# Extract created item
|
||||
data_section = response_data.get("data")
|
||||
if not isinstance(data_section, dict):
|
||||
raise ValueError("Unexpected response format: missing data section")
|
||||
|
||||
item = cast(dict[str, object], data_section).get("insert_board_item_one")
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Unexpected response format: missing created item")
|
||||
|
||||
item_dict = cast(dict[str, object], item)
|
||||
return BoardItemResult(
|
||||
uuid=str(item_dict["uuid"]),
|
||||
id=str(item_dict["id"]),
|
||||
board_id=int(cast(int, item_dict["board_id"])),
|
||||
board_name=str(item_dict["board_name"]),
|
||||
instance_id=int(cast(int, item_dict["instance_id"])),
|
||||
data=cast(dict[str, object], item_dict.get("data", {})),
|
||||
requested_by=str(item_dict["requested_by"]) if item_dict.get("requested_by") else None,
|
||||
created_at=str(item_dict["created_at"]),
|
||||
is_archived=bool(item_dict.get("is_archived", False)),
|
||||
)
|
||||
return _build_board_item_result(item_node)
|
||||
|
||||
|
||||
async def archive_board_item(
|
||||
@@ -199,38 +182,22 @@ async def get_board_item(
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
response_data = cast(dict[str, object], resp.json())
|
||||
items_node = parse_graphql_response(resp, "board_item")
|
||||
|
||||
# Check for errors
|
||||
if errors := response_data.get("errors"):
|
||||
error_list = cast(list[dict[str, object]], errors)
|
||||
first_error = error_list[0] if error_list else {}
|
||||
raise ValueError(f"GraphQL error: {first_error.get('message', 'Unknown error')}")
|
||||
if not isinstance(items_node, list) or len(items_node) == 0:
|
||||
raise errors.GraphQLOperationError(
|
||||
f"Board item with uuid {uuid} not found",
|
||||
details={"key": "board_item"},
|
||||
)
|
||||
|
||||
# Extract item
|
||||
data_section = response_data.get("data")
|
||||
if not isinstance(data_section, dict):
|
||||
raise ValueError("Unexpected response format: missing data section")
|
||||
first_item = items_node[0]
|
||||
if not isinstance(first_item, Mapping):
|
||||
raise errors.GraphQLOperationError(
|
||||
"Unexpected item format in GraphQL response",
|
||||
details={"received_type": type(first_item).__name__},
|
||||
)
|
||||
|
||||
items = cast(dict[str, object], data_section).get("board_item")
|
||||
if not isinstance(items, list):
|
||||
raise ValueError(f"Board item with uuid {uuid} not found")
|
||||
items_list = cast(list[dict[str, object]], items)
|
||||
if len(items_list) == 0:
|
||||
raise ValueError(f"Board item with uuid {uuid} not found")
|
||||
|
||||
item = items_list[0]
|
||||
return BoardItemResult(
|
||||
uuid=str(item["uuid"]),
|
||||
id=str(item["id"]),
|
||||
board_id=int(cast(int, item["board_id"])),
|
||||
board_name=str(item["board_name"]),
|
||||
instance_id=int(cast(int, item["instance_id"])),
|
||||
data=cast(dict[str, object], item.get("data", {})),
|
||||
requested_by=str(item["requested_by"]) if item.get("requested_by") else None,
|
||||
created_at=str(item["created_at"]),
|
||||
is_archived=bool(item.get("is_archived", False)),
|
||||
)
|
||||
return _build_board_item_result(first_item)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -297,3 +264,80 @@ async def create_and_archive_board_item(
|
||||
archived=archived,
|
||||
verified=verified,
|
||||
)
|
||||
|
||||
|
||||
def _require_str(mapping: Mapping[str, object], key: str) -> str:
|
||||
value = mapping.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
raise errors.GraphQLOperationError(
|
||||
f"Expected string for '{key}'", details={"received_type": type(value).__name__}
|
||||
)
|
||||
|
||||
|
||||
def _optional_str(mapping: Mapping[str, object], key: str) -> str | None:
|
||||
"""Get an optional string value from mapping."""
|
||||
value = mapping.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
raise errors.GraphQLOperationError(
|
||||
f"Expected string or None for '{key}'", details={"received_type": type(value).__name__}
|
||||
)
|
||||
|
||||
|
||||
def _require_int(mapping: Mapping[str, object], key: str) -> int:
|
||||
value = mapping.get(key)
|
||||
if isinstance(value, bool):
|
||||
raise errors.GraphQLOperationError(
|
||||
f"Expected integer for '{key}', got bool", details={"key": key}
|
||||
)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, float) and value.is_integer():
|
||||
return int(value)
|
||||
raise errors.GraphQLOperationError(
|
||||
f"Expected integer for '{key}'", details={"received_type": type(value).__name__}
|
||||
)
|
||||
|
||||
|
||||
def _require_bool(mapping: Mapping[str, object], key: str, *, default: bool = False) -> bool:
|
||||
value = mapping.get(key, default)
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
raise errors.GraphQLOperationError(
|
||||
f"Expected boolean for '{key}'", details={"received_type": type(value).__name__}
|
||||
)
|
||||
|
||||
|
||||
def _require_dict(mapping: Mapping[str, object], key: str) -> dict[str, object]:
|
||||
value = mapping.get(key, {})
|
||||
if isinstance(value, dict):
|
||||
# Cast to help type checker understand dict.items() types
|
||||
typed_dict = cast(dict[str, object], value)
|
||||
return {str(k): v for k, v in typed_dict.items()}
|
||||
if isinstance(value, Mapping):
|
||||
# For Mapping, convert to dict with explicit typing
|
||||
typed_mapping = cast(Mapping[str, object], value)
|
||||
return {str(k): v for k, v in typed_mapping.items()}
|
||||
raise errors.GraphQLOperationError(
|
||||
f"Expected object for '{key}'", details={"received_type": type(value).__name__}
|
||||
)
|
||||
|
||||
|
||||
def _build_board_item_result(item: Mapping[str, object]) -> BoardItemResult:
|
||||
"""Convert a raw mapping into a strongly typed BoardItemResult."""
|
||||
return BoardItemResult(
|
||||
uuid=_require_str(item, "uuid"),
|
||||
id=_require_str(item, "id"),
|
||||
board_id=_require_int(item, "board_id"),
|
||||
board_name=_require_str(item, "board_name"),
|
||||
instance_id=_require_int(item, "instance_id"),
|
||||
data=_require_dict(item, "data"),
|
||||
requested_by=_optional_str(item, "requested_by"),
|
||||
created_at=_require_str(item, "created_at"),
|
||||
is_archived=_require_bool(item, "is_archived", default=False),
|
||||
)
|
||||
|
||||
@@ -23,25 +23,27 @@ Usage:
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
from guide.app.strings.demo_texts.intake import IntakeTexts
|
||||
|
||||
from guide.app.strings.demo_texts.contract import ContractTexts
|
||||
from guide.app.strings.demo_texts.intake import IntakeTexts
|
||||
from guide.app.strings.demo_texts.suppliers import SupplierTexts
|
||||
from guide.app.strings.labels.auth import AuthLabels
|
||||
from guide.app.strings.labels.intake import IntakeLabels
|
||||
from guide.app.strings.labels.sourcing import SourcingLabels
|
||||
from guide.app.strings.labels.contract import (
|
||||
ContractBoardLabels,
|
||||
ContractFilterLabels,
|
||||
ContractFormLabels,
|
||||
)
|
||||
from guide.app.strings.labels.intake import IntakeLabels
|
||||
from guide.app.strings.labels.sourcing import SourcingLabels
|
||||
from guide.app.strings.selectors.auth import AuthSelectors
|
||||
from guide.app.strings.selectors.common import CommonSelectors
|
||||
from guide.app.strings.selectors.intake import IntakeSelectors
|
||||
from guide.app.strings.selectors.contract import (
|
||||
ContractBoardSelectors,
|
||||
ContractDashboardFilters,
|
||||
ContractFormSelectors,
|
||||
)
|
||||
from guide.app.strings.selectors.intake import IntakeSelectors
|
||||
from guide.app.strings.selectors.messaging import MessagingSelectors
|
||||
from guide.app.strings.selectors.navigation import NavigationSelectors
|
||||
from guide.app.strings.selectors.sourcing import SourcingSelectors
|
||||
|
||||
@@ -71,10 +73,18 @@ class IntakeStrings:
|
||||
reseller_textarea: ClassVar[str] = IntakeSelectors.RESELLER_TEXTAREA
|
||||
target_date_field: ClassVar[str] = IntakeSelectors.TARGET_DATE_FIELD
|
||||
entity_field: ClassVar[str] = IntakeSelectors.ENTITY_FIELD
|
||||
desired_supplier_name_field: ClassVar[str] = IntakeSelectors.DESIRED_SUPPLIER_NAME_FIELD
|
||||
desired_supplier_name_textarea: ClassVar[str] = IntakeSelectors.DESIRED_SUPPLIER_NAME_TEXTAREA
|
||||
desired_supplier_contact_field: ClassVar[str] = IntakeSelectors.DESIRED_SUPPLIER_CONTACT_FIELD
|
||||
desired_supplier_contact_textarea: ClassVar[str] = IntakeSelectors.DESIRED_SUPPLIER_CONTACT_TEXTAREA
|
||||
desired_supplier_name_field: ClassVar[str] = (
|
||||
IntakeSelectors.DESIRED_SUPPLIER_NAME_FIELD
|
||||
)
|
||||
desired_supplier_name_textarea: ClassVar[str] = (
|
||||
IntakeSelectors.DESIRED_SUPPLIER_NAME_TEXTAREA
|
||||
)
|
||||
desired_supplier_contact_field: ClassVar[str] = (
|
||||
IntakeSelectors.DESIRED_SUPPLIER_CONTACT_FIELD
|
||||
)
|
||||
desired_supplier_contact_textarea: ClassVar[str] = (
|
||||
IntakeSelectors.DESIRED_SUPPLIER_CONTACT_TEXTAREA
|
||||
)
|
||||
select_document_field: ClassVar[str] = IntakeSelectors.SELECT_DOCUMENT_FIELD
|
||||
drop_zone_container: ClassVar[str] = IntakeSelectors.DROP_ZONE_CONTAINER
|
||||
drop_zone_input: ClassVar[str] = IntakeSelectors.DROP_ZONE_INPUT
|
||||
@@ -87,7 +97,9 @@ class IntakeStrings:
|
||||
|
||||
# Labels - Legacy
|
||||
next_button_label: ClassVar[str] = IntakeLabels.NEXT_BUTTON
|
||||
legacy_description_placeholder: ClassVar[str] = IntakeLabels.LEGACY_DESCRIPTION_PLACEHOLDER
|
||||
legacy_description_placeholder: ClassVar[str] = (
|
||||
IntakeLabels.LEGACY_DESCRIPTION_PLACEHOLDER
|
||||
)
|
||||
|
||||
# Labels - Sourcing Request Form
|
||||
requester_label: ClassVar[str] = IntakeLabels.REQUESTER
|
||||
@@ -103,7 +115,9 @@ class IntakeStrings:
|
||||
reseller_label: ClassVar[str] = IntakeLabels.RESELLER
|
||||
entity_label: ClassVar[str] = IntakeLabels.ENTITY
|
||||
desired_supplier_name_label: ClassVar[str] = IntakeLabels.DESIRED_SUPPLIER_NAME
|
||||
desired_supplier_contact_label: ClassVar[str] = IntakeLabels.DESIRED_SUPPLIER_CONTACT
|
||||
desired_supplier_contact_label: ClassVar[str] = (
|
||||
IntakeLabels.DESIRED_SUPPLIER_CONTACT
|
||||
)
|
||||
select_document_label: ClassVar[str] = IntakeLabels.SELECT_DOCUMENT
|
||||
back_button_label: ClassVar[str] = IntakeLabels.BACK_BUTTON
|
||||
submit_button_label: ClassVar[str] = IntakeLabels.SUBMIT_BUTTON
|
||||
@@ -120,8 +134,12 @@ class IntakeStrings:
|
||||
opex_capex_request: ClassVar[str] = IntakeTexts.OPEX_CAPEX_REQUEST
|
||||
description_request: ClassVar[str] = IntakeTexts.DESCRIPTION_REQUEST
|
||||
target_date_request: ClassVar[str] = IntakeTexts.TARGET_DATE_REQUEST
|
||||
desired_supplier_name_request: ClassVar[str] = IntakeTexts.DESIRED_SUPPLIER_NAME_REQUEST
|
||||
desired_supplier_contact_request: ClassVar[str] = IntakeTexts.DESIRED_SUPPLIER_CONTACT_REQUEST
|
||||
desired_supplier_name_request: ClassVar[str] = (
|
||||
IntakeTexts.DESIRED_SUPPLIER_NAME_REQUEST
|
||||
)
|
||||
desired_supplier_contact_request: ClassVar[str] = (
|
||||
IntakeTexts.DESIRED_SUPPLIER_CONTACT_REQUEST
|
||||
)
|
||||
reseller_request: ClassVar[str] = IntakeTexts.RESELLER_REQUEST
|
||||
entity_request: ClassVar[str] = IntakeTexts.ENTITY_REQUEST
|
||||
|
||||
@@ -137,13 +155,19 @@ class ContractStrings:
|
||||
drop_zone_input: ClassVar[str] = ContractFormSelectors.DROP_ZONE_INPUT
|
||||
|
||||
contract_type_field: ClassVar[str] = ContractFormSelectors.CONTRACT_TYPE_FIELD
|
||||
contract_commodities_field: ClassVar[str] = ContractFormSelectors.CONTRACT_COMMODITIES_FIELD
|
||||
contract_commodities_field: ClassVar[str] = (
|
||||
ContractFormSelectors.CONTRACT_COMMODITIES_FIELD
|
||||
)
|
||||
supplier_contact_field: ClassVar[str] = ContractFormSelectors.SUPPLIER_CONTACT_FIELD
|
||||
classification_field: ClassVar[str] = ContractFormSelectors.CLASSIFICATION_FIELD
|
||||
entity_and_regions_field: ClassVar[str] = ContractFormSelectors.ENTITY_AND_REGIONS_FIELD
|
||||
entity_and_regions_field: ClassVar[str] = (
|
||||
ContractFormSelectors.ENTITY_AND_REGIONS_FIELD
|
||||
)
|
||||
renewal_type_field: ClassVar[str] = ContractFormSelectors.RENEWAL_TYPE_FIELD
|
||||
renewal_increase_field: ClassVar[str] = ContractFormSelectors.RENEWAL_INCREASE_FIELD
|
||||
renewal_alert_days_field: ClassVar[str] = ContractFormSelectors.RENEWAL_ALERT_DAYS_FIELD
|
||||
renewal_alert_days_field: ClassVar[str] = (
|
||||
ContractFormSelectors.RENEWAL_ALERT_DAYS_FIELD
|
||||
)
|
||||
effective_date_field: ClassVar[str] = ContractFormSelectors.EFFECTIVE_DATE_FIELD
|
||||
end_date_field: ClassVar[str] = ContractFormSelectors.END_DATE_FIELD
|
||||
required_notice_for_nonrenewal_field: ClassVar[str] = (
|
||||
@@ -165,20 +189,32 @@ class ContractStrings:
|
||||
payment_terms_field: ClassVar[str] = ContractFormSelectors.PAYMENT_TERMS_FIELD
|
||||
payment_schedule_field: ClassVar[str] = ContractFormSelectors.PAYMENT_SCHEDULE_FIELD
|
||||
business_contact_field: ClassVar[str] = ContractFormSelectors.BUSINESS_CONTACT_FIELD
|
||||
managing_department_field: ClassVar[str] = ContractFormSelectors.MANAGING_DEPARTMENT_FIELD
|
||||
funding_department_field: ClassVar[str] = ContractFormSelectors.FUNDING_DEPARTMENT_FIELD
|
||||
managing_department_field: ClassVar[str] = (
|
||||
ContractFormSelectors.MANAGING_DEPARTMENT_FIELD
|
||||
)
|
||||
funding_department_field: ClassVar[str] = (
|
||||
ContractFormSelectors.FUNDING_DEPARTMENT_FIELD
|
||||
)
|
||||
reseller_field: ClassVar[str] = ContractFormSelectors.RESELLER_FIELD
|
||||
project_name_field: ClassVar[str] = ContractFormSelectors.PROJECT_NAME_FIELD
|
||||
master_project_name_field: ClassVar[str] = ContractFormSelectors.MASTER_PROJECT_NAME_FIELD
|
||||
business_continuity_field: ClassVar[str] = ContractFormSelectors.BUSINESS_CONTINUITY_FIELD
|
||||
master_project_name_field: ClassVar[str] = (
|
||||
ContractFormSelectors.MASTER_PROJECT_NAME_FIELD
|
||||
)
|
||||
business_continuity_field: ClassVar[str] = (
|
||||
ContractFormSelectors.BUSINESS_CONTINUITY_FIELD
|
||||
)
|
||||
customer_data_field: ClassVar[str] = ContractFormSelectors.CUSTOMER_DATA_FIELD
|
||||
breach_notification_field: ClassVar[str] = ContractFormSelectors.BREACH_NOTIFICATION_FIELD
|
||||
breach_notification_field: ClassVar[str] = (
|
||||
ContractFormSelectors.BREACH_NOTIFICATION_FIELD
|
||||
)
|
||||
|
||||
# Board/grid selectors
|
||||
grid_view_button: ClassVar[str] = ContractBoardSelectors.GRID_VIEW_BUTTON
|
||||
chart_view_button: ClassVar[str] = ContractBoardSelectors.CHART_VIEW_BUTTON
|
||||
active_filter_button: ClassVar[str] = ContractBoardSelectors.ACTIVE_FILTER_BUTTON
|
||||
archived_filter_button: ClassVar[str] = ContractBoardSelectors.ARCHIVED_FILTER_BUTTON
|
||||
archived_filter_button: ClassVar[str] = (
|
||||
ContractBoardSelectors.ARCHIVED_FILTER_BUTTON
|
||||
)
|
||||
quick_filter: ClassVar[str] = ContractBoardSelectors.QUICK_FILTER
|
||||
add_row_button: ClassVar[str] = ContractBoardSelectors.ADD_ROW_BUTTON
|
||||
|
||||
@@ -194,11 +230,19 @@ class ContractStrings:
|
||||
filter_order_form: ClassVar[str] = ContractDashboardFilters.FILTER_ORDER_FORM
|
||||
filter_sow: ClassVar[str] = ContractDashboardFilters.FILTER_SOW
|
||||
filter_draft: ClassVar[str] = ContractDashboardFilters.FILTER_DRAFT
|
||||
filter_pending_approval: ClassVar[str] = ContractDashboardFilters.FILTER_PENDING_APPROVAL
|
||||
filter_ready_for_signature: ClassVar[str] = ContractDashboardFilters.FILTER_READY_FOR_SIGNATURE
|
||||
filter_pending_signature: ClassVar[str] = ContractDashboardFilters.FILTER_PENDING_SIGNATURE
|
||||
filter_pending_approval: ClassVar[str] = (
|
||||
ContractDashboardFilters.FILTER_PENDING_APPROVAL
|
||||
)
|
||||
filter_ready_for_signature: ClassVar[str] = (
|
||||
ContractDashboardFilters.FILTER_READY_FOR_SIGNATURE
|
||||
)
|
||||
filter_pending_signature: ClassVar[str] = (
|
||||
ContractDashboardFilters.FILTER_PENDING_SIGNATURE
|
||||
)
|
||||
filter_direct: ClassVar[str] = ContractDashboardFilters.FILTER_DIRECT
|
||||
filter_minority_supplier: ClassVar[str] = ContractDashboardFilters.FILTER_MINORITY_SUPPLIER
|
||||
filter_minority_supplier: ClassVar[str] = (
|
||||
ContractDashboardFilters.FILTER_MINORITY_SUPPLIER
|
||||
)
|
||||
filter_preferred: ClassVar[str] = ContractDashboardFilters.FILTER_PREFERRED
|
||||
filter_sell_side: ClassVar[str] = ContractDashboardFilters.FILTER_SELL_SIDE
|
||||
|
||||
@@ -230,7 +274,9 @@ class ContractStrings:
|
||||
required_notice_for_nonrenewal_label: ClassVar[str] = (
|
||||
ContractFormLabels.REQUIRED_NOTICE_FOR_NONRENEWAL
|
||||
)
|
||||
terminate_for_convenience_label: ClassVar[str] = ContractFormLabels.TERMINATE_FOR_CONVENIENCE
|
||||
terminate_for_convenience_label: ClassVar[str] = (
|
||||
ContractFormLabels.TERMINATE_FOR_CONVENIENCE
|
||||
)
|
||||
required_notice_for_termination_label: ClassVar[str] = (
|
||||
ContractFormLabels.REQUIRED_NOTICE_FOR_TERMINATION
|
||||
)
|
||||
@@ -277,7 +323,9 @@ class ContractStrings:
|
||||
|
||||
# Demo text
|
||||
contract_type_request: ClassVar[str] = ContractTexts.CONTRACT_TYPE
|
||||
contract_commodities_request: ClassVar[tuple[str, ...]] = ContractTexts.CONTRACT_COMMODITIES
|
||||
contract_commodities_request: ClassVar[tuple[str, ...]] = (
|
||||
ContractTexts.CONTRACT_COMMODITIES
|
||||
)
|
||||
supplier_contact_request: ClassVar[str] = ContractTexts.SUPPLIER_CONTACT
|
||||
classification_request: ClassVar[str] = ContractTexts.CLASSIFICATION
|
||||
entity_and_regions_request: ClassVar[str] = ContractTexts.ENTITY_AND_REGIONS
|
||||
@@ -289,7 +337,9 @@ class ContractStrings:
|
||||
required_notice_for_nonrenewal_request: ClassVar[str] = (
|
||||
ContractTexts.REQUIRED_NOTICE_FOR_NONRENEWAL
|
||||
)
|
||||
terminate_for_convenience_request: ClassVar[bool] = ContractTexts.TERMINATE_FOR_CONVENIENCE
|
||||
terminate_for_convenience_request: ClassVar[bool] = (
|
||||
ContractTexts.TERMINATE_FOR_CONVENIENCE
|
||||
)
|
||||
required_notice_for_termination_request: ClassVar[str] = (
|
||||
ContractTexts.REQUIRED_NOTICE_FOR_TERMINATION
|
||||
)
|
||||
@@ -375,6 +425,24 @@ class CommonStrings:
|
||||
page_header_accordion: ClassVar[str] = CommonSelectors.PAGE_HEADER_ACCORDION
|
||||
|
||||
|
||||
class MessagingStrings:
|
||||
"""Messaging flow strings: selectors for chat panel interactions."""
|
||||
|
||||
# New XPath-based selectors (primary)
|
||||
notification_indicator: ClassVar[str] = MessagingSelectors.NOTIFICATION_INDICATOR
|
||||
modal_wrapper: ClassVar[str] = MessagingSelectors.MODAL_WRAPPER
|
||||
modal_close_button: ClassVar[str] = MessagingSelectors.MODAL_CLOSE_BUTTON
|
||||
chat_messages_container: ClassVar[str] = MessagingSelectors.CHAT_MESSAGES_CONTAINER
|
||||
chat_flyout_button: ClassVar[str] = MessagingSelectors.CHAT_FLYOUT_BUTTON
|
||||
chat_conversations_tab: ClassVar[str] = MessagingSelectors.CHAT_CONVERSATIONS_TAB
|
||||
|
||||
# Legacy selectors (fallback)
|
||||
chat_button: ClassVar[str] = MessagingSelectors.CHAT_BUTTON
|
||||
chat_panel: ClassVar[str] = MessagingSelectors.CHAT_PANEL
|
||||
chat_input: ClassVar[str] = MessagingSelectors.CHAT_INPUT
|
||||
send_button: ClassVar[str] = MessagingSelectors.SEND_BUTTON
|
||||
|
||||
|
||||
class AppStrings:
|
||||
"""Root registry for all application strings.
|
||||
|
||||
@@ -391,6 +459,7 @@ class AppStrings:
|
||||
navigation: ClassVar[type[NavigationStrings]] = NavigationStrings
|
||||
auth: ClassVar[type[AuthStrings]] = AuthStrings
|
||||
common: ClassVar[type[CommonStrings]] = CommonStrings
|
||||
messaging: ClassVar[type[MessagingStrings]] = MessagingStrings
|
||||
|
||||
|
||||
# Module-level instance for convenience
|
||||
@@ -405,4 +474,5 @@ __all__ = [
|
||||
"NavigationStrings",
|
||||
"AuthStrings",
|
||||
"CommonStrings",
|
||||
"MessagingStrings",
|
||||
]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from typing import ClassVar
|
||||
|
||||
from guide.app.strings.selectors.auth import AuthSelectors
|
||||
from guide.app.strings.selectors.auth import Auth0ErrorSelectors, AuthSelectors
|
||||
from guide.app.strings.selectors.contract import (
|
||||
ContractBoardSelectors,
|
||||
ContractDashboardFilters,
|
||||
ContractFormSelectors,
|
||||
)
|
||||
from guide.app.strings.selectors.intake import IntakeSelectors
|
||||
from guide.app.strings.selectors.login import LoginSelectors
|
||||
from guide.app.strings.selectors.navigation import (
|
||||
DrawerNavigationSelectors,
|
||||
MenuNavigationSelectors,
|
||||
@@ -18,6 +19,7 @@ from guide.app.strings.selectors.sourcing import SourcingSelectors
|
||||
class Selectors:
|
||||
AUTH: ClassVar[type[AuthSelectors]] = AuthSelectors
|
||||
INTAKE: ClassVar[type[IntakeSelectors]] = IntakeSelectors
|
||||
LOGIN: ClassVar[type[LoginSelectors]] = LoginSelectors
|
||||
SOURCING: ClassVar[type[SourcingSelectors]] = SourcingSelectors
|
||||
NAVIGATION: ClassVar[type[NavigationSelectors]] = NavigationSelectors
|
||||
DRAWER_NAVIGATION: ClassVar[type[DrawerNavigationSelectors]] = DrawerNavigationSelectors
|
||||
@@ -33,7 +35,9 @@ __all__ = [
|
||||
"Selectors",
|
||||
"selectors",
|
||||
"AuthSelectors",
|
||||
"Auth0ErrorSelectors",
|
||||
"IntakeSelectors",
|
||||
"LoginSelectors",
|
||||
"SourcingSelectors",
|
||||
"NavigationSelectors",
|
||||
"DrawerNavigationSelectors",
|
||||
|
||||
@@ -1,10 +1,56 @@
|
||||
"""Auth selectors and Auth0 error detection patterns."""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class AuthSelectors:
|
||||
"""Selectors for auth UI elements."""
|
||||
|
||||
EMAIL_INPUT: ClassVar[str] = '[data-test="auth-email-input"]'
|
||||
SEND_CODE_BUTTON: ClassVar[str] = '[data-test="auth-send-code"]'
|
||||
CODE_INPUT: ClassVar[str] = '[data-test="auth-code-input"]'
|
||||
SUBMIT_BUTTON: ClassVar[str] = '[data-test="auth-submit"]'
|
||||
LOGOUT_BUTTON: ClassVar[str] = '[data-test="auth-logout"]'
|
||||
CURRENT_USER_DISPLAY: ClassVar[str] = '[data-test="user-display"]'
|
||||
|
||||
|
||||
class Auth0ErrorSelectors:
|
||||
"""Auth0 error detection selectors.
|
||||
|
||||
Used to detect authentication errors on Auth0 pages.
|
||||
These are text-based selectors that match error messages.
|
||||
"""
|
||||
|
||||
# Verification code errors
|
||||
VERIFICATION_CODE_ERRORS: ClassVar[tuple[str, ...]] = (
|
||||
"text=code is invalid",
|
||||
"text=code has expired",
|
||||
"text=too many attempts",
|
||||
"text=try again",
|
||||
".error-message",
|
||||
"[data-error='true']",
|
||||
)
|
||||
|
||||
# OTP URL / magic link errors
|
||||
OTP_URL_ERRORS: ClassVar[tuple[str, ...]] = (
|
||||
"text=verification code has expired",
|
||||
"text=link has expired",
|
||||
"text=link is invalid",
|
||||
"text=link has already been used",
|
||||
"text=Access denied",
|
||||
"text=try to login again",
|
||||
".error-message",
|
||||
"[data-error='true']",
|
||||
)
|
||||
|
||||
# Login button selectors for OTP confirmation page
|
||||
LOGIN_BUTTON_SELECTORS: ClassVar[tuple[str, ...]] = (
|
||||
"button:has-text('LOG IN')",
|
||||
"button:has-text('Log in')",
|
||||
"button:has-text('Login')",
|
||||
"[data-action-button-primary='true']",
|
||||
"form button[type='submit']",
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AuthSelectors", "Auth0ErrorSelectors"]
|
||||
|
||||
64
src/guide/app/strings/selectors/login.py
Normal file
64
src/guide/app/strings/selectors/login.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Login page UI selectors.
|
||||
|
||||
Provides selectors for Raindrop login page interaction.
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class LoginSelectors:
|
||||
"""Selectors for login page UI elements.
|
||||
|
||||
Supports two modes for email field:
|
||||
- Dropdown (MuiSelect): When recent sessions exist
|
||||
- Text input: When no recent sessions
|
||||
"""
|
||||
|
||||
# Login page container
|
||||
LOGIN_CONTAINER: ClassVar[str] = "#login"
|
||||
"""Main login form container."""
|
||||
|
||||
# Instance selector (always a dropdown)
|
||||
INSTANCE_SELECTOR: ClassVar[str] = '[data-cy="login-instance-selector"]'
|
||||
"""Instance dropdown selector."""
|
||||
|
||||
# Email field - can be dropdown OR text input
|
||||
EMAIL_FIELD: ClassVar[str] = '[data-cy="login-email-or-connection-input"]'
|
||||
"""Email field wrapper (dropdown or text input)."""
|
||||
|
||||
EMAIL_TEXT_INPUT: ClassVar[str] = 'input[name="emailOrConnection"]'
|
||||
"""Email text input (when no recent sessions)."""
|
||||
|
||||
EMAIL_DROPDOWN: ClassVar[str] = (
|
||||
'[data-cy="login-email-or-connection-input"] .MuiSelect-select'
|
||||
)
|
||||
"""Email dropdown display (when recent sessions exist)."""
|
||||
|
||||
# Clear button to switch from dropdown to text input
|
||||
EMAIL_CLEAR_BUTTON: ClassVar[str] = (
|
||||
'[data-cy="login-toggle-email-edit-enable-button"]'
|
||||
)
|
||||
"""Button to clear selected email and enable text input."""
|
||||
|
||||
# Login submit button
|
||||
LOGIN_BUTTON: ClassVar[str] = '[data-cy="login-submit-button"]'
|
||||
"""Submit button to trigger OTP email."""
|
||||
|
||||
# Error message display
|
||||
ERROR_MESSAGE: ClassVar[str] = "#login h6.MuiTypography-root"
|
||||
"""Error message display area."""
|
||||
|
||||
# Auth0 Verification code page selectors
|
||||
VERIFICATION_CODE_INPUT: ClassVar[str] = "input#code"
|
||||
"""Verification code input field on Auth0 page."""
|
||||
|
||||
VERIFICATION_CODE_SUBMIT: ClassVar[str] = (
|
||||
'button[data-action-button-primary="true"]'
|
||||
)
|
||||
"""Submit button on Auth0 verification code page."""
|
||||
|
||||
VERIFICATION_CODE_LABEL: ClassVar[str] = "label#code-label"
|
||||
"""Label for verification code input (used for page detection)."""
|
||||
|
||||
|
||||
__all__ = ["LoginSelectors"]
|
||||
70
src/guide/app/strings/selectors/messaging.py
Normal file
70
src/guide/app/strings/selectors/messaging.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Chat and messaging UI selectors.
|
||||
|
||||
Provides selectors for chat panel interaction in board view.
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class MessagingSelectors:
|
||||
"""Selectors for chat/messaging UI elements.
|
||||
|
||||
XPath locators discovered from target application for:
|
||||
- Notification indicators
|
||||
- Modal handling (dismiss blocking modals after email URL login)
|
||||
- Chat panel expansion and navigation
|
||||
- Message input and send
|
||||
"""
|
||||
|
||||
# Notification indicator - confirms new notifications are available
|
||||
NOTIFICATION_INDICATOR: ClassVar[str] = (
|
||||
"xpath=/html/body/div/div/div/div/div/div[2]/div[1]/div[2]/div[2]/span"
|
||||
)
|
||||
"""Notification badge/indicator showing new messages available."""
|
||||
|
||||
# Modal handling - modal may appear after logging in via emailed URL
|
||||
MODAL_WRAPPER: ClassVar[str] = 'xpath=//*[@id="modal-wrapper"]'
|
||||
"""Modal container that may block page access after email URL login."""
|
||||
|
||||
MODAL_CLOSE_BUTTON: ClassVar[str] = (
|
||||
"xpath=/html/body/div[4]/div[3]/div/div/div/form/div[1]/div/div/div/div[4]/button"
|
||||
)
|
||||
"""Close button to dismiss blocking modal."""
|
||||
|
||||
# Chat panel elements
|
||||
CHAT_MESSAGES_CONTAINER: ClassVar[str] = 'xpath=//*[@id="chat-messages-container"]'
|
||||
"""Chat messages container - if visible, chat panel is already open."""
|
||||
|
||||
CHAT_FLYOUT_BUTTON: ClassVar[str] = "xpath=/html/body/div/div/button"
|
||||
"""Button to expand the chat flyout panel."""
|
||||
|
||||
CHAT_CONVERSATIONS_TAB: ClassVar[str] = 'xpath=//*[@id="right-drawer-tab-open"]'
|
||||
"""Tab button to switch to conversations view."""
|
||||
|
||||
# Legacy selectors (kept for compatibility, may need update)
|
||||
CHAT_BUTTON: ClassVar[str] = (
|
||||
'[data-testid="chat-button"], '
|
||||
'button[aria-label="Open chat"], '
|
||||
'[data-cy="chat-toggle"]'
|
||||
)
|
||||
"""Button to open/toggle the chat panel (legacy fallback)."""
|
||||
|
||||
CHAT_PANEL: ClassVar[str] = (
|
||||
'[data-testid="chat-panel"], .chat-panel, [data-cy="chat-panel"]'
|
||||
)
|
||||
"""Chat panel container element (legacy fallback)."""
|
||||
|
||||
CHAT_INPUT: ClassVar[str] = (
|
||||
'[data-testid="chat-input"], textarea.chat-input, [data-cy="message-input"]'
|
||||
)
|
||||
"""Text input field for composing messages."""
|
||||
|
||||
SEND_BUTTON: ClassVar[str] = (
|
||||
'[data-testid="send-message"], '
|
||||
'button[aria-label="Send"], '
|
||||
'[data-cy="send-button"]'
|
||||
)
|
||||
"""Button to send the composed message."""
|
||||
|
||||
|
||||
__all__ = ["MessagingSelectors"]
|
||||
@@ -1,3 +1,3 @@
|
||||
from guide.app.utils import env, ids, retry, timing
|
||||
from guide.app.utils import env, ids, retry, timing, urls
|
||||
|
||||
__all__ = ["env", "ids", "retry", "timing"]
|
||||
__all__ = ["env", "ids", "retry", "timing", "urls"]
|
||||
|
||||
109
src/guide/app/utils/jwt.py
Normal file
109
src/guide/app/utils/jwt.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""JWT utility functions.
|
||||
|
||||
Shared utilities for JWT parsing and validation without cryptographic verification.
|
||||
These are used for extracting claims and basic format validation only.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_jwt_format(value: str) -> bool:
|
||||
"""Check if value looks like a JWT (three base64 segments separated by dots).
|
||||
|
||||
Performs basic structural validation without cryptographic verification.
|
||||
|
||||
Args:
|
||||
value: String to check.
|
||||
|
||||
Returns:
|
||||
True if the value appears to be a JWT, False otherwise.
|
||||
"""
|
||||
parts = value.split(".")
|
||||
if len(parts) != 3:
|
||||
return False
|
||||
# Each part should be non-empty and contain valid base64-ish characters
|
||||
return all(part and all(c.isalnum() or c in "-_=" for c in part) for part in parts)
|
||||
|
||||
|
||||
def parse_jwt_expiry(token: str) -> datetime | None:
|
||||
"""Parse exp claim from JWT payload.
|
||||
|
||||
Decodes the JWT payload (base64) without verification to extract
|
||||
the expiration timestamp.
|
||||
|
||||
Args:
|
||||
token: JWT token string.
|
||||
|
||||
Returns:
|
||||
Expiration datetime (UTC) if found, None otherwise.
|
||||
"""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
# Decode payload (second part)
|
||||
payload_b64 = parts[1]
|
||||
# Add padding if needed
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
|
||||
payload_bytes = base64.urlsafe_b64decode(payload_b64)
|
||||
# json.loads returns Any; cast to object for type narrowing
|
||||
decoded = cast(object, json.loads(payload_bytes.decode("utf-8")))
|
||||
|
||||
if not isinstance(decoded, dict):
|
||||
return None
|
||||
|
||||
# Cast narrowed dict to typed dict for proper member access
|
||||
payload = cast("dict[str, object]", decoded)
|
||||
exp_value = payload.get("exp")
|
||||
if isinstance(exp_value, int | float):
|
||||
return datetime.fromtimestamp(exp_value, tz=timezone.utc)
|
||||
|
||||
except Exception as exc:
|
||||
_logger.debug("Failed to parse JWT expiry: %s", exc)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def decode_jwt_payload(token: str) -> dict[str, object] | None:
|
||||
"""Decode JWT payload without verification.
|
||||
|
||||
Args:
|
||||
token: JWT token string.
|
||||
|
||||
Returns:
|
||||
Decoded payload as dict if successful, None otherwise.
|
||||
"""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
payload_b64 = parts[1]
|
||||
# Add padding if needed
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
|
||||
payload_bytes = base64.urlsafe_b64decode(payload_b64)
|
||||
decoded = cast(object, json.loads(payload_bytes.decode("utf-8")))
|
||||
|
||||
if isinstance(decoded, dict):
|
||||
return cast("dict[str, object]", decoded)
|
||||
|
||||
except Exception as exc:
|
||||
_logger.debug("Failed to decode JWT payload: %s", exc)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["is_jwt_format", "parse_jwt_expiry", "decode_jwt_payload"]
|
||||
@@ -14,6 +14,7 @@ def retry(
|
||||
delay_seconds: float = 0.5,
|
||||
backoff_factor: float = 2.0,
|
||||
on_error: Callable[[Exception, int], None] | None = None,
|
||||
retry_exceptions: tuple[type[BaseException], ...] = (Exception,),
|
||||
) -> T:
|
||||
"""Retry a synchronous function with exponential backoff.
|
||||
|
||||
@@ -23,6 +24,7 @@ def retry(
|
||||
delay_seconds: Initial delay in seconds (default: 0.5)
|
||||
backoff_factor: Multiplier for delay after each retry (default: 2.0)
|
||||
on_error: Optional callback for retry errors
|
||||
retry_exceptions: Tuple of exception types to retry on
|
||||
|
||||
Returns:
|
||||
The result of the function call
|
||||
@@ -35,11 +37,11 @@ def retry(
|
||||
while True:
|
||||
try:
|
||||
return fn()
|
||||
except Exception as exc: # noqa: PERF203
|
||||
except retry_exceptions as exc:
|
||||
attempt += 1
|
||||
if attempt > retries:
|
||||
raise
|
||||
if on_error:
|
||||
if on_error and isinstance(exc, Exception):
|
||||
on_error(exc, attempt)
|
||||
time.sleep(current_delay)
|
||||
current_delay *= backoff_factor
|
||||
@@ -51,6 +53,7 @@ def retry_async(
|
||||
delay_seconds: float = 0.5,
|
||||
backoff_factor: float = 2.0,
|
||||
on_error: Callable[[Exception, int], None] | None = None,
|
||||
retry_exceptions: tuple[type[BaseException], ...] = (Exception,),
|
||||
) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
|
||||
"""Decorator for retrying async functions with exponential backoff.
|
||||
|
||||
@@ -59,6 +62,7 @@ def retry_async(
|
||||
delay_seconds: Initial delay in seconds (default: 0.5)
|
||||
backoff_factor: Multiplier for delay after each retry (default: 2.0)
|
||||
on_error: Optional callback for retry errors
|
||||
retry_exceptions: Tuple of exception types to retry on
|
||||
|
||||
Returns:
|
||||
Decorated async function that retries on exception
|
||||
@@ -77,11 +81,11 @@ def retry_async(
|
||||
while True:
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
except retry_exceptions as exc:
|
||||
attempt += 1
|
||||
if attempt > retries:
|
||||
raise
|
||||
if on_error:
|
||||
if on_error and isinstance(exc, Exception):
|
||||
on_error(exc, attempt)
|
||||
await asyncio.sleep(current_delay)
|
||||
current_delay *= backoff_factor
|
||||
|
||||
42
src/guide/app/utils/urls.py
Normal file
42
src/guide/app/utils/urls.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""URL manipulation utilities."""
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class InvalidUrlError(ValueError):
|
||||
"""Raised when a URL is invalid or malformed."""
|
||||
|
||||
|
||||
def extract_base_url(url: str) -> str:
|
||||
"""Extract base URL (scheme + netloc) from a full URL.
|
||||
|
||||
Args:
|
||||
url: Full URL possibly with path, query, fragment.
|
||||
|
||||
Returns:
|
||||
Base URL like 'https://stg.raindrop.com'.
|
||||
|
||||
Raises:
|
||||
InvalidUrlError: If URL is invalid, missing scheme, or missing netloc.
|
||||
|
||||
Examples:
|
||||
>>> extract_base_url("https://stg.raindrop.com/login?x=1")
|
||||
'https://stg.raindrop.com'
|
||||
>>> extract_base_url("http://localhost:8000/api/v1")
|
||||
'http://localhost:8000'
|
||||
"""
|
||||
if not url:
|
||||
raise InvalidUrlError("Invalid URL: expected non-empty string, got empty string")
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
if not parsed.scheme:
|
||||
raise InvalidUrlError(f"Invalid URL: missing scheme in '{url}'")
|
||||
|
||||
if not parsed.netloc:
|
||||
raise InvalidUrlError(f"Invalid URL: missing netloc in '{url}'")
|
||||
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
|
||||
__all__ = ["extract_base_url", "InvalidUrlError"]
|
||||
@@ -20,28 +20,24 @@ def mock_browser_hosts() -> dict[str, BrowserHostConfig]:
|
||||
from guide.app.core.config import BrowserHostConfig, HostKind
|
||||
|
||||
return {
|
||||
"demo-cdp": BrowserHostConfig(
|
||||
id="demo-cdp",
|
||||
kind=HostKind.CDP,
|
||||
host="192.168.50.185",
|
||||
port=9223,
|
||||
),
|
||||
"demo-extension": BrowserHostConfig(
|
||||
id="demo-extension",
|
||||
kind=HostKind.EXTENSION,
|
||||
port=17373,
|
||||
),
|
||||
"support-cdp": BrowserHostConfig(
|
||||
id="support-cdp",
|
||||
kind=HostKind.CDP,
|
||||
host="192.168.50.108",
|
||||
port=9223,
|
||||
),
|
||||
"support-extension": BrowserHostConfig(
|
||||
id="support-extension",
|
||||
kind=HostKind.EXTENSION,
|
||||
port=17374,
|
||||
),
|
||||
"browserless-cdp": BrowserHostConfig(
|
||||
id="browserless-cdp",
|
||||
kind=HostKind.CDP,
|
||||
host="browserless.lab",
|
||||
port=80,
|
||||
cdp_url="ws://browserless.lab:80/",
|
||||
browser="chromium",
|
||||
),
|
||||
"headless-local": BrowserHostConfig(
|
||||
id="headless-local",
|
||||
kind=HostKind.HEADLESS,
|
||||
@@ -95,7 +91,7 @@ def app_settings(
|
||||
return AppSettings(
|
||||
raindrop_base_url="https://app.raindrop.com",
|
||||
raindrop_graphql_url="https://app.raindrop.com/graphql",
|
||||
default_browser_host_id="demo-headless",
|
||||
default_browser_host_id="browserless-cdp",
|
||||
browser_hosts=mock_browser_hosts,
|
||||
personas=mock_personas,
|
||||
)
|
||||
|
||||
4
tests/fixtures/data_generator.py
vendored
4
tests/fixtures/data_generator.py
vendored
@@ -72,6 +72,8 @@ FormValue = str | int | list[str] | bool
|
||||
class DataConfig:
|
||||
"""Load and provide access to test data configuration."""
|
||||
|
||||
__test__ = False # Tell pytest this is not a test class
|
||||
|
||||
config: TestDataConfig
|
||||
|
||||
def __init__(self, config_path: Path | None = None):
|
||||
@@ -116,6 +118,8 @@ class DataConfig:
|
||||
class DataGenerator:
|
||||
"""Generate realistic test data for form fields."""
|
||||
|
||||
__test__ = False # Tell pytest this is not a test class
|
||||
|
||||
config: DataConfig
|
||||
use_faker: bool
|
||||
faker: FakerLike | None
|
||||
|
||||
@@ -98,7 +98,7 @@ async def test_sourcing_request_uses_pagehelpers_wrapper(
|
||||
mock_form_helpers: None,
|
||||
action_context: ActionContext,
|
||||
) -> None:
|
||||
"""Verify action uses PageHelpers wrapper."""
|
||||
"""Verify action returns structured result with selection details."""
|
||||
action = FillSourcingRequestAction()
|
||||
|
||||
result = await action.run(mock_page_with_helpers, action_context)
|
||||
@@ -106,32 +106,26 @@ async def test_sourcing_request_uses_pagehelpers_wrapper(
|
||||
assert isinstance(result, ActionResult)
|
||||
assert result.details["message"] == "Sourcing request form filled"
|
||||
assert "selection_results" in result.details
|
||||
assert "fields_filled" in result.details
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sourcing_request_calls_wait_for_network_idle(
|
||||
async def test_sourcing_request_completes_all_dropdown_selections(
|
||||
mock_page_with_helpers: MagicMock,
|
||||
mock_dropdown_helpers: None,
|
||||
mock_form_helpers: None,
|
||||
action_context: ActionContext,
|
||||
) -> None:
|
||||
"""Verify wait_for_network_idle() called after dropdown selections."""
|
||||
"""Verify all dropdown selections complete successfully."""
|
||||
action = FillSourcingRequestAction()
|
||||
|
||||
await action.run(mock_page_with_helpers, action_context)
|
||||
result = await action.run(mock_page_with_helpers, action_context)
|
||||
|
||||
# Should call wait_for_load_state("networkidle") after each dropdown operation
|
||||
# Expected calls: commodities, planned, regions, opex_capex, entity = 5 calls
|
||||
wait_calls = [
|
||||
call_item for call_item in mock_page_with_helpers.wait_for_load_state.call_args_list
|
||||
if call_item[0][0] == "networkidle"
|
||||
]
|
||||
# Verify all dropdown selections were captured in results
|
||||
from typing import cast
|
||||
selection_results = cast("dict[str, dict[str, list[str]]]", result.details["selection_results"])
|
||||
|
||||
assert len(wait_calls) >= 5, (
|
||||
f"Expected at least 5 wait_for_load_state('networkidle') calls, "
|
||||
f"got {len(wait_calls)}"
|
||||
)
|
||||
expected_dropdowns = {"commodities", "planned", "regions", "opex_capex", "entity"}
|
||||
assert set(selection_results.keys()) == expected_dropdowns
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -203,18 +197,16 @@ async def test_sourcing_request_returns_structured_results(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sourcing_request_captures_diagnostics_on_error(
|
||||
async def test_sourcing_request_propagates_errors(
|
||||
mock_page_with_helpers: MagicMock,
|
||||
mock_dropdown_helpers: None,
|
||||
action_context: ActionContext,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Verify diagnostics captured when action fails."""
|
||||
from guide.app.errors import ActionExecutionError
|
||||
|
||||
"""Verify errors from form helpers propagate correctly."""
|
||||
# Make fill_date raise an error
|
||||
async def mock_fill_date_error(page: "PageLike", selector: str, value: str) -> None:
|
||||
raise ValueError("Test error")
|
||||
raise ValueError("Test error in fill_date")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"guide.app.actions.intake.sourcing_request.fill_date",
|
||||
@@ -223,14 +215,11 @@ async def test_sourcing_request_captures_diagnostics_on_error(
|
||||
|
||||
action = FillSourcingRequestAction()
|
||||
|
||||
with pytest.raises(ActionExecutionError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await action.run(mock_page_with_helpers, action_context)
|
||||
|
||||
# Verify error message includes context
|
||||
assert "Failed to fill sourcing request form" in str(exc_info.value)
|
||||
assert exc_info.value.details is not None
|
||||
assert "error" in exc_info.value.details
|
||||
assert "debug_info" in exc_info.value.details
|
||||
# Verify original error message preserved
|
||||
assert "Test error in fill_date" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -80,113 +80,140 @@ class TestExtensionPageSendCommand:
|
||||
|
||||
|
||||
class TestExtensionPageClick:
|
||||
"""Test the click method using CLICK action."""
|
||||
"""Test the click method using eval_js."""
|
||||
|
||||
async def test_click_sends_correct_command(self) -> None:
|
||||
"""Click sends structured CLICK command."""
|
||||
"""Click executes JavaScript with click logic via eval_js."""
|
||||
mock_client = MagicMock(spec=ExtensionClient)
|
||||
mock_client.send_message = AsyncMock()
|
||||
page = ExtensionPage(mock_client)
|
||||
|
||||
with patch.object(page, "_send_command", new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {"success": True}
|
||||
await page.click("button.submit")
|
||||
# Direct assignment to instance for proper method interception
|
||||
mock_eval = AsyncMock(return_value={"success": True})
|
||||
page.eval_js = mock_eval # type: ignore[method-assign]
|
||||
await page.click("button.submit")
|
||||
|
||||
mock_send.assert_called_once_with("CLICK", {"selector": "button.submit"})
|
||||
mock_eval.assert_called_once()
|
||||
js_code = mock_eval.call_args[0][0]
|
||||
assert "querySelector" in js_code
|
||||
assert "button.submit" in js_code
|
||||
assert "click" in js_code
|
||||
|
||||
async def test_click_handles_error(self) -> None:
|
||||
"""Click handles error from extension."""
|
||||
"""Click handles error response from JavaScript."""
|
||||
mock_client = MagicMock(spec=ExtensionClient)
|
||||
mock_client.send_message = AsyncMock()
|
||||
page = ExtensionPage(mock_client)
|
||||
|
||||
with patch.object(page, "_send_command", new_callable=AsyncMock) as mock_send:
|
||||
mock_send.side_effect = BrowserConnectionError("Element not found")
|
||||
|
||||
try:
|
||||
await page.click("button.missing")
|
||||
assert False, "Should have raised BrowserConnectionError"
|
||||
except BrowserConnectionError as e:
|
||||
assert "Element not found" in str(e)
|
||||
# Direct assignment to instance for proper method interception
|
||||
mock_eval = AsyncMock(return_value={"success": False, "error": "Element not found"})
|
||||
page.eval_js = mock_eval # type: ignore[method-assign]
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await page.click("button.missing")
|
||||
assert "Element not found" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestExtensionPageFill:
|
||||
"""Test the fill method using FILL action."""
|
||||
"""Test the fill method using eval_js."""
|
||||
|
||||
async def test_fill_sends_correct_command(self) -> None:
|
||||
"""Fill sends structured FILL command."""
|
||||
"""Fill executes JavaScript with React-compatible value setter via eval_js."""
|
||||
mock_client = MagicMock(spec=ExtensionClient)
|
||||
mock_client.send_message = AsyncMock()
|
||||
page = ExtensionPage(mock_client)
|
||||
|
||||
with patch.object(page, "_send_command", new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {"success": True}
|
||||
await page.fill("input[name='email']", "test@example.com")
|
||||
# Direct assignment to instance for proper method interception
|
||||
mock_eval = AsyncMock(return_value=True)
|
||||
page.eval_js = mock_eval # type: ignore[method-assign]
|
||||
await page.fill("input[name='email']", "test@example.com")
|
||||
|
||||
mock_send.assert_called_once_with(
|
||||
"FILL", {"selector": "input[name='email']", "value": "test@example.com"}
|
||||
)
|
||||
mock_eval.assert_called_once()
|
||||
js_code = mock_eval.call_args[0][0]
|
||||
assert "querySelector" in js_code
|
||||
# Selector is escaped for JS strings: input[name='email'] -> input[name=\'email\']
|
||||
assert "input[name=" in js_code and "email" in js_code
|
||||
assert "test@example.com" in js_code
|
||||
assert "nativeSetter" in js_code # React-compatible setter
|
||||
|
||||
async def test_type_sends_fill_command(self) -> None:
|
||||
"""Type method sends FILL command (same as fill)."""
|
||||
"""Type method uses fill implementation (same as fill)."""
|
||||
mock_client = MagicMock(spec=ExtensionClient)
|
||||
mock_client.send_message = AsyncMock()
|
||||
page = ExtensionPage(mock_client)
|
||||
|
||||
with patch.object(page, "_send_command", new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {"success": True}
|
||||
await page.type("input.search", "query text")
|
||||
# Direct assignment to instance for proper method interception
|
||||
mock_eval = AsyncMock(return_value=True)
|
||||
page.eval_js = mock_eval # type: ignore[method-assign]
|
||||
await page.type("input.search", "query text")
|
||||
|
||||
mock_send.assert_called_once_with(
|
||||
"FILL", {"selector": "input.search", "value": "query text"}
|
||||
)
|
||||
mock_eval.assert_called_once()
|
||||
js_code = mock_eval.call_args[0][0]
|
||||
assert "querySelector" in js_code
|
||||
assert "input.search" in js_code
|
||||
assert "query text" in js_code
|
||||
|
||||
|
||||
class TestExtensionPageWaitForSelector:
|
||||
"""Test the wait_for_selector method using WAIT_FOR_SELECTOR action."""
|
||||
"""Test the wait_for_selector method using eval_js polling."""
|
||||
|
||||
async def test_wait_for_selector_default_params(self) -> None:
|
||||
"""Wait for selector with default parameters."""
|
||||
"""Wait for selector with default parameters finds element immediately."""
|
||||
mock_client = MagicMock(spec=ExtensionClient)
|
||||
mock_client.send_message = AsyncMock()
|
||||
page = ExtensionPage(mock_client)
|
||||
|
||||
with patch.object(page, "_send_command", new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {"found": True}
|
||||
await page.wait_for_selector(".loading")
|
||||
# Direct assignment to instance for proper method interception
|
||||
mock_eval = AsyncMock(return_value={"exists": True, "visible": True})
|
||||
page.eval_js = mock_eval # type: ignore[method-assign]
|
||||
result = await page.wait_for_selector(".loading")
|
||||
|
||||
mock_send.assert_called_once()
|
||||
call_args = mock_send.call_args[0]
|
||||
assert call_args[0] == "WAIT_FOR_SELECTOR"
|
||||
assert call_args[1]["selector"] == ".loading"
|
||||
assert call_args[1]["state"] == "attached"
|
||||
assert call_args[1]["timeout"] == 5000
|
||||
assert result is None # Returns None for Playwright compatibility
|
||||
mock_eval.assert_called_once()
|
||||
js_code = mock_eval.call_args[0][0]
|
||||
assert ".loading" in js_code
|
||||
assert "querySelector" in js_code
|
||||
|
||||
async def test_wait_for_selector_custom_timeout(self) -> None:
|
||||
"""Wait for selector with custom timeout."""
|
||||
"""Wait for selector with custom timeout retries after delay."""
|
||||
mock_client = MagicMock(spec=ExtensionClient)
|
||||
mock_client.send_message = AsyncMock()
|
||||
page = ExtensionPage(mock_client)
|
||||
|
||||
with patch.object(page, "_send_command", new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {"found": True}
|
||||
_ = await page.wait_for_selector(".loading", timeout=10000)
|
||||
call_count = 0
|
||||
|
||||
call_args = mock_send.call_args[0][1]
|
||||
assert call_args["timeout"] == 10000
|
||||
async def mock_eval_side_effect(js_code: str, await_promise: bool = True) -> dict[str, bool] | bool:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
_ = await_promise # Unused but required for signature match
|
||||
if call_count == 1:
|
||||
return {"exists": False, "visible": False} # Not found first time
|
||||
return True # Found on retry
|
||||
|
||||
# Direct assignment to instance for proper method interception
|
||||
mock_eval = AsyncMock(side_effect=mock_eval_side_effect)
|
||||
page.eval_js = mock_eval # type: ignore[method-assign]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await page.wait_for_selector(".loading", timeout=100)
|
||||
|
||||
assert result is None # Returns None for Playwright compatibility
|
||||
assert mock_eval.call_count == 2 # Initial check + retry
|
||||
|
||||
async def test_wait_for_selector_visible_state(self) -> None:
|
||||
"""Wait for selector with visible state."""
|
||||
"""Wait for selector with visible state checks visibility."""
|
||||
mock_client = MagicMock(spec=ExtensionClient)
|
||||
mock_client.send_message = AsyncMock()
|
||||
page = ExtensionPage(mock_client)
|
||||
|
||||
with patch.object(page, "_send_command", new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {"found": True}
|
||||
_ = await page.wait_for_selector(".modal", state="visible")
|
||||
# Direct assignment to instance for proper method interception
|
||||
mock_eval = AsyncMock(return_value={"exists": True, "visible": True})
|
||||
page.eval_js = mock_eval # type: ignore[method-assign]
|
||||
result = await page.wait_for_selector(".modal", state="visible")
|
||||
|
||||
call_args = mock_send.call_args[0][1]
|
||||
assert call_args["state"] == "visible"
|
||||
assert result is None # Returns None for Playwright compatibility
|
||||
mock_eval.assert_called_once()
|
||||
js_code = mock_eval.call_args[0][0]
|
||||
assert ".modal" in js_code
|
||||
assert "isVisible" in js_code # Visibility check in JS
|
||||
|
||||
|
||||
class TestExtensionPageClickElementWithText:
|
||||
|
||||
@@ -140,6 +140,20 @@ class TestBrowserHostConfigValidation:
|
||||
assert config.kind == HostKind.CDP
|
||||
assert config.host == "localhost"
|
||||
assert config.port == 9222
|
||||
assert config.cdp_url is None
|
||||
|
||||
def test_cdp_host_config_with_explicit_url(self) -> None:
|
||||
"""Test creating CDP host with explicit websocket endpoint."""
|
||||
from guide.app.core.config import BrowserHostConfig, HostKind
|
||||
|
||||
config = BrowserHostConfig(
|
||||
id="browserless",
|
||||
kind=HostKind.CDP,
|
||||
cdp_url="ws://browserless.lab:80/",
|
||||
)
|
||||
assert config.cdp_url == "ws://browserless.lab:80/"
|
||||
assert config.host is None
|
||||
assert config.port is None
|
||||
|
||||
def test_host_kind_string_coercion(self) -> None:
|
||||
"""Test that HostKind accepts string values."""
|
||||
@@ -173,6 +187,6 @@ class TestAppSettingsDefaults:
|
||||
settings = AppSettings()
|
||||
assert settings.raindrop_base_url == "https://stg.raindrop.com"
|
||||
assert settings.raindrop_graphql_url == "https://raindrop-staging.hasura.app/v1/graphql"
|
||||
assert settings.default_browser_host_id == "demo-cdp"
|
||||
assert settings.default_browser_host_id == "browserless-cdp"
|
||||
assert isinstance(settings.browser_hosts, dict)
|
||||
assert isinstance(settings.personas, dict)
|
||||
|
||||
Reference in New Issue
Block a user