This commit is contained in:
2025-11-22 01:47:12 +00:00
parent b79a545a4d
commit 94f9051aee
12 changed files with 370 additions and 329 deletions

View File

@@ -1,5 +1,5 @@
from collections.abc import Iterable
from typing import ClassVar, Protocol, runtime_checkable
from collections.abc import Iterable, Mapping
from typing import Any, ClassVar, Protocol, runtime_checkable
from playwright.async_api import Page
@@ -16,22 +16,118 @@ class DemoAction(Protocol):
async def run(self, page: Page, context: ActionContext) -> ActionResult: ...
class ActionRegistry:
"""Simple mapping of action id -> action instance."""
# Global registry of action classes
_REGISTERED_ACTIONS: dict[str, type[DemoAction]] = {}
def __init__(self, actions: Iterable[DemoAction]):
self._actions: dict[str, DemoAction] = {action.id: action for action in actions}
def register_action(cls: type[DemoAction]) -> type[DemoAction]:
"""Decorator to register an action class to the global action registry.
Usage:
@register_action
class MyAction(DemoAction):
id = "my-action"
...
"""
_REGISTERED_ACTIONS[cls.id] = cls
return cls
def get_registered_actions() -> Mapping[str, type[DemoAction]]:
"""Return a mapping of all registered action classes by id."""
return _REGISTERED_ACTIONS.copy()
class ActionRegistry:
"""Manages action instances and metadata.
Supports both explicit action registration and dynamically-registered actions
via the @register_action decorator.
"""
def __init__(
self,
actions: Iterable[DemoAction] | None = None,
action_factories: dict[str, Any] | None = None,
) -> None:
"""Initialize registry with action instances and/or factory functions.
Args:
actions: Iterable of action instances
action_factories: Dict mapping action ids to factory functions that create instances
"""
self._actions: dict[str, DemoAction] = {}
self._factories: dict[str, Any] = action_factories or {}
if actions:
for action in actions:
self._actions[action.id] = action
def register(self, action: DemoAction) -> None:
"""Register an action instance."""
self._actions[action.id] = action
def get(self, action_id: str) -> DemoAction:
if action_id not in self._actions:
raise errors.ActionExecutionError(f"Unknown action '{action_id}'")
return self._actions[action_id]
"""Retrieve an action by id.
Searches in order:
1. Explicitly registered action instances
2. Factory functions
3. Globally registered action classes (from @register_action decorator)
Raises ActionExecutionError if the action is not found.
"""
if action_id in self._actions:
return self._actions[action_id]
if action_id in self._factories:
action = self._factories[action_id]()
self._actions[action_id] = action
return action
# Check globally registered actions
registered_classes = get_registered_actions()
if action_id in registered_classes:
action_cls = registered_classes[action_id]
action = action_cls()
self._actions[action_id] = action
return action
raise errors.ActionExecutionError(f"Unknown action '{action_id}'")
def list_metadata(self) -> list[ActionMetadata]:
return [
ActionMetadata(id=action.id, description=action.description, category=action.category)
for action in self._actions.values()
]
"""List metadata for all registered actions.
Includes:
- Explicitly registered action instances
- Actions from factory functions
- Globally registered action classes (from @register_action decorator)
"""
seen_ids: set[str] = set()
metadata: list[ActionMetadata] = []
# Add explicit instances
for action in self._actions.values():
metadata.append(
ActionMetadata(id=action.id, description=action.description, category=action.category)
)
seen_ids.add(action.id)
# Add factory functions
for factory in self._factories.values():
action = factory()
if action.id not in seen_ids:
metadata.append(
ActionMetadata(id=action.id, description=action.description, category=action.category)
)
seen_ids.add(action.id)
# Add globally registered actions
for action_cls in get_registered_actions().values():
if action_cls.id not in seen_ids:
action = action_cls()
metadata.append(
ActionMetadata(id=action.id, description=action.description, category=action.category)
)
seen_ids.add(action.id)
return metadata

View File

@@ -2,11 +2,12 @@ from playwright.async_api import Page
from typing import ClassVar, override
from guide.app.actions.base import DemoAction
from guide.app.actions.base import DemoAction, register_action
from guide.app.models.domain import ActionContext, ActionResult
from guide.app.strings.service import strings
from guide.app.strings.registry import app_strings
@register_action
class FillIntakeBasicAction(DemoAction):
id: ClassVar[str] = "fill-intake-basic"
description: ClassVar[str] = "Fill the intake description and advance to the next step."
@@ -14,9 +15,7 @@ class FillIntakeBasicAction(DemoAction):
@override
async def run(self, page: Page, context: ActionContext) -> ActionResult:
description_val = strings.text("INTAKE", "CONVEYOR_BELT_REQUEST")
if not isinstance(description_val, str):
raise ValueError("INTAKE.CONVEYOR_BELT_REQUEST must be a string")
await page.fill(strings.selector("INTAKE", "DESCRIPTION_FIELD"), description_val)
await page.click(strings.selector("INTAKE", "NEXT_BUTTON"))
description_val = app_strings.intake.texts.conveyor_belt_request
await page.fill(app_strings.intake.selectors.description_field, description_val)
await page.click(app_strings.intake.selectors.next_button)
return ActionResult(details={"message": "Intake filled"})

View File

@@ -1,16 +1,33 @@
from guide.app.actions.auth.login import LoginAsPersonaAction
from guide.app.actions.base import ActionRegistry, DemoAction
from guide.app.actions.intake.basic import FillIntakeBasicAction
from guide.app.actions.sourcing.add_suppliers import AddThreeSuppliersAction
from guide.app.models.personas import PersonaStore
def _load_registered_actions() -> None:
"""Load all action modules to trigger @register_action decorators.
This function must be called to ensure all action classes are registered.
"""
import guide.app.actions.intake.basic
import guide.app.actions.sourcing.add_suppliers
# Keep module names alive for registration side effects
assert guide.app.actions.intake.basic is not None
assert guide.app.actions.sourcing.add_suppliers is not None
def default_registry(persona_store: PersonaStore, login_url: str) -> ActionRegistry:
"""Create the default action registry with all registered actions.
Actions that require dependency injection (like LoginAsPersonaAction) are
explicitly instantiated here. Actions decorated with @register_action are
automatically discovered and instantiated by the registry.
"""
_load_registered_actions()
actions: list[DemoAction] = [
LoginAsPersonaAction(persona_store, login_url),
FillIntakeBasicAction(),
AddThreeSuppliersAction(),
]
return ActionRegistry(actions)
__all__ = ["default_registry", "ActionRegistry", "DemoAction"]

View File

@@ -2,11 +2,12 @@ from playwright.async_api import Page
from typing import ClassVar, override
from guide.app.actions.base import DemoAction
from guide.app.actions.base import DemoAction, register_action
from guide.app.models.domain import ActionContext, ActionResult
from guide.app.strings.service import strings
from guide.app.strings.registry import app_strings
@register_action
class AddThreeSuppliersAction(DemoAction):
id: ClassVar[str] = "add-three-suppliers"
description: ClassVar[str] = "Adds three default suppliers to the sourcing event."
@@ -14,11 +15,8 @@ class AddThreeSuppliersAction(DemoAction):
@override
async def run(self, page: Page, context: ActionContext) -> ActionResult:
suppliers_val = strings.text("SUPPLIERS", "DEFAULT_TRIO")
if not isinstance(suppliers_val, list):
raise ValueError("SUPPLIERS.DEFAULT_TRIO must be a list of strings")
suppliers: list[str] = list(suppliers_val)
suppliers = app_strings.sourcing.texts.default_trio
for supplier in suppliers:
await page.fill(strings.selector("SOURCING", "SUPPLIER_SEARCH_INPUT"), supplier)
await page.click(strings.selector("SOURCING", "ADD_SUPPLIER_BUTTON"))
await page.fill(app_strings.sourcing.selectors.supplier_search_input, supplier)
await page.click(app_strings.sourcing.selectors.add_supplier_button)
return ActionResult(details={"added_suppliers": list(suppliers)})

View File

@@ -67,7 +67,7 @@ async def execute_action(
persona = personas.get(payload.persona_id) if payload.persona_id else None
target_host_id = payload.browser_host_id or (persona.browser_host_id if persona else None)
target_host_id = target_host_id or browser_client.settings.default_browser_host_id
target_host_id = target_host_id or settings.default_browser_host_id
context = ActionContext(
action_id=action_id,

View File

@@ -2,19 +2,18 @@ from playwright.async_api import Page
from guide.app.auth.mfa import MfaCodeProvider
from guide.app.models.personas.models import DemoPersona
from guide.app.strings.service import strings
from guide.app.strings.registry import app_strings
async def detect_current_persona(page: Page) -> str | None:
"""Return the email/identifier of the currently signed-in user, if visible."""
current_selector = strings.selector("AUTH", "CURRENT_USER_DISPLAY")
element = page.locator(current_selector)
element = page.locator(app_strings.auth.selectors.current_user_display)
if await element.count() == 0:
return None
text = await element.first.text_content()
if text is None:
return None
prefix = strings.label("AUTH", "CURRENT_USER_DISPLAY_PREFIX")
prefix = app_strings.auth.labels.current_user_display_prefix
if prefix and text.startswith(prefix):
return text.removeprefix(prefix).strip()
return text.strip()
@@ -24,16 +23,15 @@ async def login_with_mfa(page: Page, email: str, mfa_provider: MfaCodeProvider,
if login_url:
_response = await page.goto(login_url)
del _response
await page.fill(strings.selector("AUTH", "EMAIL_INPUT"), email)
await page.click(strings.selector("AUTH", "SEND_CODE_BUTTON"))
await page.fill(app_strings.auth.selectors.email_input, email)
await page.click(app_strings.auth.selectors.send_code_button)
code = mfa_provider.get_code(email)
await page.fill(strings.selector("AUTH", "CODE_INPUT"), code)
await page.click(strings.selector("AUTH", "SUBMIT_BUTTON"))
await page.fill(app_strings.auth.selectors.code_input, code)
await page.click(app_strings.auth.selectors.submit_button)
async def logout(page: Page) -> None:
logout_selector = strings.selector("AUTH", "LOGOUT_BUTTON")
await page.click(logout_selector)
await page.click(app_strings.auth.selectors.logout_button)
async def ensure_persona(page: Page, persona: DemoPersona, mfa_provider: MfaCodeProvider, login_url: str | None = None) -> None:

View File

@@ -1,88 +1,42 @@
import contextlib
from collections.abc import AsyncIterator
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
from playwright.async_api import Page
from guide.app.core.config import AppSettings, BrowserHostConfig, HostKind
from guide.app import errors
from guide.app.browser.pool import BrowserPool
class BrowserClient:
"""Connector that yields a Playwright Page for either CDP or headless hosts."""
"""Provides page access via a persistent browser pool.
settings: AppSettings
This client uses the BrowserPool to efficiently manage connections.
Instead of opening/closing browsers per request, the pool maintains
long-lived connections and allocates pages on demand.
"""
def __init__(self, settings: AppSettings) -> None:
self.settings = settings
def __init__(self, pool: BrowserPool) -> None:
"""Initialize with a browser pool.
def _resolve_host(self, host_id: str | None) -> BrowserHostConfig:
resolved_id = host_id or self.settings.default_browser_host_id
host = self.settings.browser_hosts.get(resolved_id)
if not host:
known = ", ".join(self.settings.browser_hosts.keys()) or "<none>"
raise errors.ConfigError(f"Unknown browser host '{resolved_id}'. Known: {known}")
return host
Args:
pool: The BrowserPool instance to use
"""
self.pool = pool
@contextlib.asynccontextmanager
async def open_page(self, host_id: str | None = None) -> AsyncIterator[Page]:
host = self._resolve_host(host_id)
playwright = await async_playwright().start()
browser: Browser | None = None
context: BrowserContext | None = None
try:
if host.kind == HostKind.CDP:
if not host.host or host.port is None:
raise errors.ConfigError("CDP host requires host and port fields.")
cdp_url = f"http://{host.host}:{host.port}"
try:
browser = await playwright.chromium.connect_over_cdp(cdp_url)
except Exception as exc: # pragma: no cover - network dependent
raise errors.BrowserConnectionError(
f"Cannot connect to CDP endpoint {cdp_url}", details={"host_id": host.id}
) from exc
page = self._pick_raindrop_page(browser)
if not page:
raise errors.BrowserConnectionError(
"No Raindrop page found in connected browser.", details={"host_id": host.id}
)
else:
browser_type = _resolve_browser_type(playwright, host.browser)
browser = await browser_type.launch(headless=True)
context = await browser.new_context()
page = await context.new_page()
yield page
finally:
with contextlib.suppress(Exception):
if context:
await context.close()
with contextlib.suppress(Exception):
if browser:
await browser.close()
with contextlib.suppress(Exception):
await playwright.stop()
"""Get a page from the pool for the specified host.
def _pick_raindrop_page(self, browser: Browser) -> Page | None:
target_substr = self.settings.raindrop_base_url
pages: list[Page] = []
for context in browser.contexts:
pages.extend(context.pages)
pages = pages or list(browser.contexts[0].pages) if browser.contexts else []
return next(
(
page
for page in reversed(pages)
if target_substr in (page.url or "")
),
pages[-1] if pages else None,
)
The page is obtained from the pool's persistent browser connection.
No browser startup/connection overhead occurs on each request.
Args:
host_id: The host identifier, or None for the default host
Yields:
A Playwright Page instance
"""
page = await self.pool.get_page(host_id)
yield page
def _resolve_browser_type(playwright: Playwright, browser: str | None):
desired = (browser or "chromium").lower()
if desired == "chromium":
return playwright.chromium
if desired == "firefox":
return playwright.firefox
if desired == "webkit":
return playwright.webkit
raise errors.ConfigError(f"Unsupported headless browser type '{browser}'")
__all__ = ["BrowserClient"]

View File

@@ -1,4 +1,5 @@
import json
import logging
import os
from enum import Enum
from pathlib import Path
@@ -8,6 +9,8 @@ from typing import ClassVar, TypeAlias, TypeGuard, cast
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
_logger = logging.getLogger(__name__)
CONFIG_DIR = Path(__file__).resolve().parents[4] / "config"
HOSTS_FILE = CONFIG_DIR / "hosts.yaml"
PERSONAS_FILE = CONFIG_DIR / "personas.yaml"
@@ -17,19 +20,25 @@ RecordList: TypeAlias = list[JsonRecord]
def _coerce_mapping(mapping: Mapping[object, object]) -> dict[str, object]:
"""Convert mapping keys to strings."""
return {str(key): value for key, value in mapping.items()}
def _is_object_mapping(value: object) -> TypeGuard[Mapping[object, object]]:
"""Check if value is a Mapping."""
return isinstance(value, Mapping)
class HostKind(str, Enum):
"""Browser host kind: CDP or headless."""
CDP = "cdp"
HEADLESS = "headless"
class BrowserHostConfig(BaseModel):
"""Configuration for a browser host (CDP or headless)."""
id: str
kind: HostKind
host: str | None = None
@@ -38,6 +47,8 @@ class BrowserHostConfig(BaseModel):
class PersonaConfig(BaseModel):
"""Configuration for a demo persona."""
id: str
role: str
email: str
@@ -46,7 +57,20 @@ class PersonaConfig(BaseModel):
class AppSettings(BaseSettings):
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(env_prefix="RAINDROP_DEMO_")
"""Application settings loaded from YAML files + environment variables.
Configuration sources (in order):
1. Environment variables with prefix RAINDROP_DEMO_
2. YAML files: config/hosts.yaml, config/personas.yaml
3. JSON overrides via RAINDROP_DEMO_BROWSER_HOSTS_JSON, RAINDROP_DEMO_PERSONAS_JSON
The JSON overrides take precedence over file-based config.
"""
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_prefix="RAINDROP_DEMO_",
case_sensitive=False,
)
raindrop_base_url: str = "https://app.raindrop.com"
raindrop_graphql_url: str = "https://app.raindrop.com/graphql"
@@ -55,7 +79,8 @@ class AppSettings(BaseSettings):
personas: dict[str, PersonaConfig] = Field(default_factory=dict)
def _load_hosts_file(path: Path) -> dict[str, BrowserHostConfig]:
def _load_yaml_file(path: Path) -> dict[str, object]:
"""Load YAML file, handling missing PyYAML gracefully."""
if not path.exists():
return {}
@@ -63,123 +88,118 @@ def _load_hosts_file(path: Path) -> dict[str, BrowserHostConfig]:
import yaml
except ModuleNotFoundError as exc:
raise RuntimeError(
"hosts.yaml found but PyYAML is not installed. Add 'pyyaml' to dependencies or remove hosts.yaml."
f"{path.name} found but PyYAML is not installed. Add 'pyyaml' to dependencies."
) from exc
data_raw: object = yaml.safe_load(path.read_text()) or {}
hosts: dict[str, BrowserHostConfig] = {}
for item in _normalize_records(data_raw, key_name="hosts"):
host = BrowserHostConfig.model_validate(item)
hosts[host.id] = host
return hosts
return cast(dict[str, object], yaml.safe_load(path.read_text()) or {})
def _parse_json_hosts(value: str) -> dict[str, BrowserHostConfig]:
decoded_raw: object = cast(object, json.loads(value))
decoded = _ensure_record_list(decoded_raw, "RAINDROP_DEMO_BROWSER_HOSTS_JSON")
hosts: dict[str, BrowserHostConfig] = {}
for item in decoded:
host = BrowserHostConfig.model_validate(item)
hosts[host.id] = host
return hosts
def _parse_json_personas(value: str) -> dict[str, PersonaConfig]:
decoded_raw: object = cast(object, json.loads(value))
decoded = _ensure_record_list(decoded_raw, "RAINDROP_DEMO_PERSONAS_JSON")
personas: dict[str, PersonaConfig] = {}
for item in decoded:
persona = PersonaConfig.model_validate(item)
personas[persona.id] = persona
return personas
def _load_personas_file(path: Path) -> dict[str, PersonaConfig]:
if not path.exists():
return {}
try:
import yaml
except ModuleNotFoundError as exc:
raise RuntimeError(
"personas.yaml found but PyYAML is not installed. Add 'pyyaml' to dependencies or remove personas.yaml."
) from exc
data_raw: object = yaml.safe_load(path.read_text()) or {}
personas: dict[str, PersonaConfig] = {}
for item in _normalize_records(data_raw, key_name="personas"):
persona = PersonaConfig.model_validate(item)
personas[persona.id] = persona
return personas
def _normalize_records(data_raw: object, key_name: str) -> RecordList:
if isinstance(data_raw, Mapping):
mapping_raw: dict[str, object] = _coerce_mapping(cast(Mapping[object, object], data_raw))
content: object = mapping_raw.get(key_name, mapping_raw)
def _normalize_host_records(data: object) -> RecordList:
"""Normalize host records from YAML or list format."""
if isinstance(data, Mapping):
mapping = _coerce_mapping(cast(Mapping[object, object], data))
content = mapping.get("hosts", mapping)
else:
content = data_raw
content = data
records: RecordList = []
if isinstance(content, Mapping):
mapping_content: dict[str, object] = _coerce_mapping(cast(Mapping[object, object], content))
records: RecordList = []
mapping_content = _coerce_mapping(cast(Mapping[object, object], content))
for key, value in mapping_content.items():
mapping = _ensure_mapping(value)
if "id" not in mapping:
mapping["id"] = key
records.append(mapping)
return records
if isinstance(content, list):
list_content: list[object] = cast(list[object], content)
list_records: RecordList = []
for item in list_content:
record = _coerce_mapping(cast(Mapping[object, object], value)) if _is_object_mapping(value) else {}
if "id" not in record:
record["id"] = key
records.append(record)
elif isinstance(content, list):
for item in cast(list[object], content):
if _is_object_mapping(item):
mapping_item = _coerce_mapping(item)
list_records.append(dict(mapping_item))
return list_records
records.append(_coerce_mapping(cast(Mapping[object, object], item)))
return []
return records
def _ensure_mapping(value: object) -> JsonRecord:
if isinstance(value, Mapping):
return _coerce_mapping(cast(Mapping[object, object], value))
return {}
def _normalize_persona_records(data: object) -> RecordList:
"""Normalize persona records from YAML or list format."""
if isinstance(data, Mapping):
mapping = _coerce_mapping(cast(Mapping[object, object], data))
content = mapping.get("personas", mapping)
else:
content = data
def _ensure_record_list(value: object, source_label: str) -> RecordList:
if isinstance(value, list):
list_content: list[object] = cast(list[object], value)
list_records: RecordList = []
for item in list_content:
records: RecordList = []
if isinstance(content, Mapping):
mapping_content = _coerce_mapping(cast(Mapping[object, object], content))
for key, value in mapping_content.items():
record = _coerce_mapping(cast(Mapping[object, object], value)) if _is_object_mapping(value) else {}
if "id" not in record:
record["id"] = key
records.append(record)
elif isinstance(content, list):
for item in cast(list[object], content):
if _is_object_mapping(item):
mapping_item = _coerce_mapping(item)
list_records.append(dict(mapping_item))
return list_records
raise ValueError(f"{source_label} must be a JSON array of objects")
records.append(_coerce_mapping(cast(Mapping[object, object], item)))
return records
def load_settings() -> AppSettings:
"""Load application settings from files and environment.
Configuration is loaded in this order (later overrides earlier):
1. Default values in AppSettings model
2. YAML files (config/hosts.yaml, config/personas.yaml)
3. Environment variables (RAINDROP_DEMO_* prefix)
4. JSON overrides via environment (RAINDROP_DEMO_BROWSER_HOSTS_JSON, etc.)
"""
settings = AppSettings()
merged_hosts: dict[str, BrowserHostConfig] = {}
merged_personas: dict[str, PersonaConfig] = {}
merged_hosts |= _load_hosts_file(HOSTS_FILE)
merged_personas |= _load_personas_file(PERSONAS_FILE)
# Load from YAML files
hosts_data = _load_yaml_file(HOSTS_FILE)
personas_data = _load_yaml_file(PERSONAS_FILE)
if env_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"):
merged_hosts.update(_parse_json_hosts(env_json))
hosts_dict: dict[str, BrowserHostConfig] = {}
for record in _normalize_host_records(hosts_data):
host = BrowserHostConfig.model_validate(record)
hosts_dict[host.id] = host
if env_personas := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"):
merged_personas.update(_parse_json_personas(env_personas))
personas_dict: dict[str, PersonaConfig] = {}
for record in _normalize_persona_records(personas_data):
persona = PersonaConfig.model_validate(record)
personas_dict[persona.id] = persona
if merged_hosts or merged_personas:
# Load JSON overrides from environment
if browser_hosts_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"):
try:
decoded_hosts: object = json.loads(browser_hosts_json)
if not isinstance(decoded_hosts, list):
raise ValueError("RAINDROP_DEMO_BROWSER_HOSTS_JSON must be a JSON array")
for record in cast(list[object], decoded_hosts):
if _is_object_mapping(record):
host = BrowserHostConfig.model_validate(record)
hosts_dict[host.id] = host
except (json.JSONDecodeError, ValueError) as exc:
_logger.warning(f"Failed to parse RAINDROP_DEMO_BROWSER_HOSTS_JSON: {exc}")
if personas_json := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"):
try:
decoded_personas: object = json.loads(personas_json)
if not isinstance(decoded_personas, list):
raise ValueError("RAINDROP_DEMO_PERSONAS_JSON must be a JSON array")
for record in cast(list[object], decoded_personas):
if _is_object_mapping(record):
persona = PersonaConfig.model_validate(record)
personas_dict[persona.id] = persona
except (json.JSONDecodeError, ValueError) as exc:
_logger.warning(f"Failed to parse RAINDROP_DEMO_PERSONAS_JSON: {exc}")
# Update settings with loaded configuration
if hosts_dict or personas_dict:
settings = settings.model_copy(
update={
"browser_hosts": merged_hosts or settings.browser_hosts,
"personas": merged_personas or settings.personas,
"browser_hosts": hosts_dict or settings.browser_hosts,
"personas": personas_dict or settings.personas,
}
)
_logger.info(f"Loaded {len(settings.browser_hosts)} browser hosts, {len(settings.personas)} personas")
return settings

View File

@@ -1,15 +1,16 @@
from fastapi import FastAPI
from fastapi import Request
from fastapi.exception_handlers import http_exception_handler
from fastapi.exceptions import HTTPException
from guide.app.actions.registry import default_registry
from guide.app.browser.client import BrowserClient
from guide.app.browser.pool import BrowserPool
from guide.app.core.config import AppSettings, load_settings
from guide.app.core.logging import configure_logging
from guide.app.api import router as api_router
from guide.app import errors
from guide.app.models.personas import PersonaStore
from fastapi import Request
from fastapi.exception_handlers import http_exception_handler
from fastapi.exceptions import HTTPException
def create_app() -> FastAPI:
@@ -18,16 +19,30 @@ def create_app() -> FastAPI:
settings = load_settings()
persona_store = PersonaStore(settings)
registry = default_registry(persona_store, settings.raindrop_base_url)
browser_client = BrowserClient(settings)
# Create browser pool for efficient connection management
browser_pool = BrowserPool(settings)
browser_client = BrowserClient(browser_pool)
app = FastAPI(title="Raindrop Demo Automation", version="0.1.0")
app.state.settings = settings
app.state.action_registry = registry
app.state.browser_client = browser_client
app.state.browser_pool = browser_pool
app.state.persona_store = persona_store
# Dependency overrides so FastAPI deps can pull settings without globals
app.dependency_overrides = {AppSettings: lambda: settings}
# Startup/shutdown lifecycle for browser pool
@app.on_event("startup")
async def startup_browser_pool() -> None:
await browser_pool.initialize()
@app.on_event("shutdown")
async def shutdown_browser_pool() -> None:
await browser_pool.close()
app.include_router(api_router)
app.add_exception_handler(errors.GuideError, guide_exception_handler)
app.add_exception_handler(Exception, general_exception_handler)

View File

@@ -1,12 +1,23 @@
from guide.app.models.personas.models import DemoPersona
from guide.app.models.types import JSONValue
from guide.app.raindrop.graphql import GraphQLClient
from guide.app.raindrop.types import GetIntakeRequestResponse, CreateIntakeRequestResponse, IntakeRequestData
from guide.app.strings import graphql as gql_strings
async def get_intake_request(
client: GraphQLClient, persona: DemoPersona, request_id: str
) -> dict[str, JSONValue]:
) -> IntakeRequestData:
"""Fetch an intake request by ID.
Args:
client: GraphQL client
persona: Persona making the request
request_id: The intake request ID
Returns:
IntakeRequestData: Type-safe response data
"""
variables: dict[str, JSONValue] = {"id": request_id}
data = await client.execute(
query=gql_strings.GET_INTAKE_REQUEST,
@@ -14,13 +25,25 @@ async def get_intake_request(
persona=persona,
operation_name="GetIntakeRequest",
)
result = data.get("intakeRequest")
return result if isinstance(result, dict) else {}
response = GetIntakeRequestResponse.model_validate(data)
if not response.intake_request:
raise ValueError(f"No intake request found with id: {request_id}")
return response.intake_request
async def create_intake_request(
client: GraphQLClient, persona: DemoPersona, payload: dict[str, JSONValue]
) -> dict[str, JSONValue]:
) -> IntakeRequestData:
"""Create a new intake request.
Args:
client: GraphQL client
persona: Persona making the request
payload: Intake request input data
Returns:
IntakeRequestData: The created intake request
"""
variables: dict[str, JSONValue] = {"input": payload}
data = await client.execute(
query=gql_strings.CREATE_INTAKE_REQUEST,
@@ -28,5 +51,7 @@ async def create_intake_request(
persona=persona,
operation_name="CreateIntakeRequest",
)
result = data.get("createIntakeRequest")
return result if isinstance(result, dict) else {}
response = CreateIntakeRequestResponse.model_validate(data)
if not response.create_intake_request:
raise ValueError("Failed to create intake request")
return response.create_intake_request

View File

@@ -1,10 +1,21 @@
from guide.app.models.personas.models import DemoPersona
from guide.app.models.types import JSONValue
from guide.app.raindrop.graphql import GraphQLClient
from guide.app.raindrop.types import ListSuppliersResponse, AddSupplierResponse, SupplierData
from guide.app.strings import graphql as gql_strings
async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int = 10) -> list[dict[str, JSONValue]]:
async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int = 10) -> list[SupplierData]:
"""Fetch a list of suppliers.
Args:
client: GraphQL client
persona: Persona making the request
limit: Maximum number of suppliers to return
Returns:
list[SupplierData]: Type-safe list of suppliers
"""
variables: dict[str, JSONValue] = {"limit": limit}
data = await client.execute(
query=gql_strings.LIST_SUPPLIERS,
@@ -12,19 +23,23 @@ async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int
persona=persona,
operation_name="ListSuppliers",
)
suppliers = data.get("suppliers")
if not isinstance(suppliers, list):
return []
filtered: list[dict[str, JSONValue]] = []
for item in suppliers:
if isinstance(item, dict):
filtered.append(item)
return filtered
response = ListSuppliersResponse.model_validate(data)
return response.suppliers
async def add_supplier(
client: GraphQLClient, persona: DemoPersona, supplier: dict[str, JSONValue]
) -> dict[str, JSONValue]:
) -> SupplierData:
"""Add a supplier to the sourcing event.
Args:
client: GraphQL client
persona: Persona making the request
supplier: Supplier input data
Returns:
SupplierData: The created supplier
"""
variables: dict[str, JSONValue] = {"input": supplier}
data = await client.execute(
query=gql_strings.ADD_SUPPLIER,
@@ -32,5 +47,7 @@ async def add_supplier(
persona=persona,
operation_name="AddSupplier",
)
result = data.get("addSupplier")
return result if isinstance(result, dict) else {}
response = AddSupplierResponse.model_validate(data)
if not response.add_supplier:
raise ValueError("Failed to add supplier")
return response.add_supplier

View File

@@ -1,98 +0,0 @@
from typing import TypeAlias, cast
from guide.app.strings.graphql import (
ADD_SUPPLIER,
CREATE_INTAKE_REQUEST,
GET_INTAKE_REQUEST,
LIST_SUPPLIERS,
)
from guide.app.strings.demo_texts import DemoTexts, EventTexts, IntakeTexts, SupplierTexts
from guide.app.strings.labels import AuthLabels, IntakeLabels, Labels, SourcingLabels
from guide.app.strings.selectors import AuthSelectors, IntakeSelectors, NavigationSelectors, Selectors, SourcingSelectors
SelectorNamespace: TypeAlias = type[IntakeSelectors] | type[SourcingSelectors] | type[NavigationSelectors] | type[AuthSelectors]
LabelNamespace: TypeAlias = type[IntakeLabels] | type[SourcingLabels] | type[AuthLabels]
TextNamespace: TypeAlias = type[IntakeTexts] | type[SupplierTexts] | type[EventTexts]
TextValue: TypeAlias = str | list[str]
_SELECTORS: dict[str, SelectorNamespace] = {
"INTAKE": Selectors.INTAKE,
"SOURCING": Selectors.SOURCING,
"NAVIGATION": Selectors.NAVIGATION,
"AUTH": Selectors.AUTH,
}
_LABELS: dict[str, LabelNamespace] = {
"INTAKE": Labels.INTAKE,
"SOURCING": Labels.SOURCING,
"AUTH": Labels.AUTH,
}
_TEXTS: dict[str, TextNamespace] = {
"INTAKE": DemoTexts.INTAKE,
"SUPPLIERS": DemoTexts.SUPPLIERS,
"EVENTS": DemoTexts.EVENTS,
}
_GQL = {
"GET_INTAKE_REQUEST": GET_INTAKE_REQUEST,
"CREATE_INTAKE_REQUEST": CREATE_INTAKE_REQUEST,
"LIST_SUPPLIERS": LIST_SUPPLIERS,
"ADD_SUPPLIER": ADD_SUPPLIER,
}
class Strings:
"""Unified accessor for selectors, labels, texts, and GraphQL queries."""
@staticmethod
def selector(domain: str, name: str) -> str:
ns = _SELECTORS.get(domain.upper())
if ns is None:
raise KeyError(f"Unknown selector domain '{domain}'")
value = cast(str | None, getattr(ns, name, None))
return _as_str(value, f"selector {domain}.{name}")
@staticmethod
def label(domain: str, name: str) -> str:
ns = _LABELS.get(domain.upper())
if ns is None:
raise KeyError(f"Unknown label domain '{domain}'")
value = cast(str | None, getattr(ns, name, None))
return _as_str(value, f"label {domain}.{name}")
@staticmethod
def text(domain: str, name: str) -> TextValue:
ns = _TEXTS.get(domain.upper())
if ns is None:
raise KeyError(f"Unknown text domain '{domain}'")
value: object | None = getattr(ns, name, None)
if value is None:
raise KeyError(f"Unknown text {domain}.{name}")
if isinstance(value, str):
return value
if isinstance(value, list):
value_list: list[object] = cast(list[object], value)
str_values: list[str] = []
for item in value_list:
if not isinstance(item, str):
raise TypeError(f"text {domain}.{name} must be a string or list of strings")
str_values.append(item)
return str_values
raise TypeError(f"text {domain}.{name} must be a string or list of strings")
@staticmethod
def gql(name: str) -> str:
value = _GQL.get(name)
return _as_str(value, f"graphql string {name}")
def _as_str(value: object, label: str) -> str:
if not isinstance(value, str): # pragma: no cover
raise TypeError(f"{label} must be a string")
return value
strings = Strings()
__all__ = ["Strings", "strings"]