x
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import ClassVar, Protocol, runtime_checkable
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any, ClassVar, Protocol, runtime_checkable
|
||||
|
||||
from playwright.async_api import Page
|
||||
|
||||
@@ -16,22 +16,118 @@ class DemoAction(Protocol):
|
||||
async def run(self, page: Page, context: ActionContext) -> ActionResult: ...
|
||||
|
||||
|
||||
class ActionRegistry:
|
||||
"""Simple mapping of action id -> action instance."""
|
||||
# Global registry of action classes
|
||||
_REGISTERED_ACTIONS: dict[str, type[DemoAction]] = {}
|
||||
|
||||
def __init__(self, actions: Iterable[DemoAction]):
|
||||
self._actions: dict[str, DemoAction] = {action.id: action for action in actions}
|
||||
|
||||
def register_action(cls: type[DemoAction]) -> type[DemoAction]:
|
||||
"""Decorator to register an action class to the global action registry.
|
||||
|
||||
Usage:
|
||||
@register_action
|
||||
class MyAction(DemoAction):
|
||||
id = "my-action"
|
||||
...
|
||||
"""
|
||||
_REGISTERED_ACTIONS[cls.id] = cls
|
||||
return cls
|
||||
|
||||
|
||||
def get_registered_actions() -> Mapping[str, type[DemoAction]]:
|
||||
"""Return a mapping of all registered action classes by id."""
|
||||
return _REGISTERED_ACTIONS.copy()
|
||||
|
||||
|
||||
class ActionRegistry:
|
||||
"""Manages action instances and metadata.
|
||||
|
||||
Supports both explicit action registration and dynamically-registered actions
|
||||
via the @register_action decorator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actions: Iterable[DemoAction] | None = None,
|
||||
action_factories: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize registry with action instances and/or factory functions.
|
||||
|
||||
Args:
|
||||
actions: Iterable of action instances
|
||||
action_factories: Dict mapping action ids to factory functions that create instances
|
||||
"""
|
||||
self._actions: dict[str, DemoAction] = {}
|
||||
self._factories: dict[str, Any] = action_factories or {}
|
||||
|
||||
if actions:
|
||||
for action in actions:
|
||||
self._actions[action.id] = action
|
||||
|
||||
def register(self, action: DemoAction) -> None:
|
||||
"""Register an action instance."""
|
||||
self._actions[action.id] = action
|
||||
|
||||
def get(self, action_id: str) -> DemoAction:
|
||||
if action_id not in self._actions:
|
||||
raise errors.ActionExecutionError(f"Unknown action '{action_id}'")
|
||||
return self._actions[action_id]
|
||||
"""Retrieve an action by id.
|
||||
|
||||
Searches in order:
|
||||
1. Explicitly registered action instances
|
||||
2. Factory functions
|
||||
3. Globally registered action classes (from @register_action decorator)
|
||||
|
||||
Raises ActionExecutionError if the action is not found.
|
||||
"""
|
||||
if action_id in self._actions:
|
||||
return self._actions[action_id]
|
||||
if action_id in self._factories:
|
||||
action = self._factories[action_id]()
|
||||
self._actions[action_id] = action
|
||||
return action
|
||||
|
||||
# Check globally registered actions
|
||||
registered_classes = get_registered_actions()
|
||||
if action_id in registered_classes:
|
||||
action_cls = registered_classes[action_id]
|
||||
action = action_cls()
|
||||
self._actions[action_id] = action
|
||||
return action
|
||||
|
||||
raise errors.ActionExecutionError(f"Unknown action '{action_id}'")
|
||||
|
||||
def list_metadata(self) -> list[ActionMetadata]:
|
||||
return [
|
||||
ActionMetadata(id=action.id, description=action.description, category=action.category)
|
||||
for action in self._actions.values()
|
||||
]
|
||||
"""List metadata for all registered actions.
|
||||
|
||||
Includes:
|
||||
- Explicitly registered action instances
|
||||
- Actions from factory functions
|
||||
- Globally registered action classes (from @register_action decorator)
|
||||
"""
|
||||
seen_ids: set[str] = set()
|
||||
metadata: list[ActionMetadata] = []
|
||||
|
||||
# Add explicit instances
|
||||
for action in self._actions.values():
|
||||
metadata.append(
|
||||
ActionMetadata(id=action.id, description=action.description, category=action.category)
|
||||
)
|
||||
seen_ids.add(action.id)
|
||||
|
||||
# Add factory functions
|
||||
for factory in self._factories.values():
|
||||
action = factory()
|
||||
if action.id not in seen_ids:
|
||||
metadata.append(
|
||||
ActionMetadata(id=action.id, description=action.description, category=action.category)
|
||||
)
|
||||
seen_ids.add(action.id)
|
||||
|
||||
# Add globally registered actions
|
||||
for action_cls in get_registered_actions().values():
|
||||
if action_cls.id not in seen_ids:
|
||||
action = action_cls()
|
||||
metadata.append(
|
||||
ActionMetadata(id=action.id, description=action.description, category=action.category)
|
||||
)
|
||||
seen_ids.add(action.id)
|
||||
|
||||
return metadata
|
||||
|
||||
@@ -2,11 +2,12 @@ from playwright.async_api import Page
|
||||
|
||||
from typing import ClassVar, override
|
||||
|
||||
from guide.app.actions.base import DemoAction
|
||||
from guide.app.actions.base import DemoAction, register_action
|
||||
from guide.app.models.domain import ActionContext, ActionResult
|
||||
from guide.app.strings.service import strings
|
||||
from guide.app.strings.registry import app_strings
|
||||
|
||||
|
||||
@register_action
|
||||
class FillIntakeBasicAction(DemoAction):
|
||||
id: ClassVar[str] = "fill-intake-basic"
|
||||
description: ClassVar[str] = "Fill the intake description and advance to the next step."
|
||||
@@ -14,9 +15,7 @@ class FillIntakeBasicAction(DemoAction):
|
||||
|
||||
@override
|
||||
async def run(self, page: Page, context: ActionContext) -> ActionResult:
|
||||
description_val = strings.text("INTAKE", "CONVEYOR_BELT_REQUEST")
|
||||
if not isinstance(description_val, str):
|
||||
raise ValueError("INTAKE.CONVEYOR_BELT_REQUEST must be a string")
|
||||
await page.fill(strings.selector("INTAKE", "DESCRIPTION_FIELD"), description_val)
|
||||
await page.click(strings.selector("INTAKE", "NEXT_BUTTON"))
|
||||
description_val = app_strings.intake.texts.conveyor_belt_request
|
||||
await page.fill(app_strings.intake.selectors.description_field, description_val)
|
||||
await page.click(app_strings.intake.selectors.next_button)
|
||||
return ActionResult(details={"message": "Intake filled"})
|
||||
|
||||
@@ -1,16 +1,33 @@
|
||||
from guide.app.actions.auth.login import LoginAsPersonaAction
|
||||
from guide.app.actions.base import ActionRegistry, DemoAction
|
||||
from guide.app.actions.intake.basic import FillIntakeBasicAction
|
||||
from guide.app.actions.sourcing.add_suppliers import AddThreeSuppliersAction
|
||||
from guide.app.models.personas import PersonaStore
|
||||
|
||||
|
||||
def _load_registered_actions() -> None:
|
||||
"""Load all action modules to trigger @register_action decorators.
|
||||
|
||||
This function must be called to ensure all action classes are registered.
|
||||
"""
|
||||
import guide.app.actions.intake.basic
|
||||
import guide.app.actions.sourcing.add_suppliers
|
||||
|
||||
# Keep module names alive for registration side effects
|
||||
assert guide.app.actions.intake.basic is not None
|
||||
assert guide.app.actions.sourcing.add_suppliers is not None
|
||||
|
||||
|
||||
def default_registry(persona_store: PersonaStore, login_url: str) -> ActionRegistry:
|
||||
"""Create the default action registry with all registered actions.
|
||||
|
||||
Actions that require dependency injection (like LoginAsPersonaAction) are
|
||||
explicitly instantiated here. Actions decorated with @register_action are
|
||||
automatically discovered and instantiated by the registry.
|
||||
"""
|
||||
_load_registered_actions()
|
||||
actions: list[DemoAction] = [
|
||||
LoginAsPersonaAction(persona_store, login_url),
|
||||
FillIntakeBasicAction(),
|
||||
AddThreeSuppliersAction(),
|
||||
]
|
||||
return ActionRegistry(actions)
|
||||
|
||||
|
||||
__all__ = ["default_registry", "ActionRegistry", "DemoAction"]
|
||||
|
||||
@@ -2,11 +2,12 @@ from playwright.async_api import Page
|
||||
|
||||
from typing import ClassVar, override
|
||||
|
||||
from guide.app.actions.base import DemoAction
|
||||
from guide.app.actions.base import DemoAction, register_action
|
||||
from guide.app.models.domain import ActionContext, ActionResult
|
||||
from guide.app.strings.service import strings
|
||||
from guide.app.strings.registry import app_strings
|
||||
|
||||
|
||||
@register_action
|
||||
class AddThreeSuppliersAction(DemoAction):
|
||||
id: ClassVar[str] = "add-three-suppliers"
|
||||
description: ClassVar[str] = "Adds three default suppliers to the sourcing event."
|
||||
@@ -14,11 +15,8 @@ class AddThreeSuppliersAction(DemoAction):
|
||||
|
||||
@override
|
||||
async def run(self, page: Page, context: ActionContext) -> ActionResult:
|
||||
suppliers_val = strings.text("SUPPLIERS", "DEFAULT_TRIO")
|
||||
if not isinstance(suppliers_val, list):
|
||||
raise ValueError("SUPPLIERS.DEFAULT_TRIO must be a list of strings")
|
||||
suppliers: list[str] = list(suppliers_val)
|
||||
suppliers = app_strings.sourcing.texts.default_trio
|
||||
for supplier in suppliers:
|
||||
await page.fill(strings.selector("SOURCING", "SUPPLIER_SEARCH_INPUT"), supplier)
|
||||
await page.click(strings.selector("SOURCING", "ADD_SUPPLIER_BUTTON"))
|
||||
await page.fill(app_strings.sourcing.selectors.supplier_search_input, supplier)
|
||||
await page.click(app_strings.sourcing.selectors.add_supplier_button)
|
||||
return ActionResult(details={"added_suppliers": list(suppliers)})
|
||||
|
||||
@@ -67,7 +67,7 @@ async def execute_action(
|
||||
|
||||
persona = personas.get(payload.persona_id) if payload.persona_id else None
|
||||
target_host_id = payload.browser_host_id or (persona.browser_host_id if persona else None)
|
||||
target_host_id = target_host_id or browser_client.settings.default_browser_host_id
|
||||
target_host_id = target_host_id or settings.default_browser_host_id
|
||||
|
||||
context = ActionContext(
|
||||
action_id=action_id,
|
||||
|
||||
@@ -2,19 +2,18 @@ from playwright.async_api import Page
|
||||
|
||||
from guide.app.auth.mfa import MfaCodeProvider
|
||||
from guide.app.models.personas.models import DemoPersona
|
||||
from guide.app.strings.service import strings
|
||||
from guide.app.strings.registry import app_strings
|
||||
|
||||
|
||||
async def detect_current_persona(page: Page) -> str | None:
|
||||
"""Return the email/identifier of the currently signed-in user, if visible."""
|
||||
current_selector = strings.selector("AUTH", "CURRENT_USER_DISPLAY")
|
||||
element = page.locator(current_selector)
|
||||
element = page.locator(app_strings.auth.selectors.current_user_display)
|
||||
if await element.count() == 0:
|
||||
return None
|
||||
text = await element.first.text_content()
|
||||
if text is None:
|
||||
return None
|
||||
prefix = strings.label("AUTH", "CURRENT_USER_DISPLAY_PREFIX")
|
||||
prefix = app_strings.auth.labels.current_user_display_prefix
|
||||
if prefix and text.startswith(prefix):
|
||||
return text.removeprefix(prefix).strip()
|
||||
return text.strip()
|
||||
@@ -24,16 +23,15 @@ async def login_with_mfa(page: Page, email: str, mfa_provider: MfaCodeProvider,
|
||||
if login_url:
|
||||
_response = await page.goto(login_url)
|
||||
del _response
|
||||
await page.fill(strings.selector("AUTH", "EMAIL_INPUT"), email)
|
||||
await page.click(strings.selector("AUTH", "SEND_CODE_BUTTON"))
|
||||
await page.fill(app_strings.auth.selectors.email_input, email)
|
||||
await page.click(app_strings.auth.selectors.send_code_button)
|
||||
code = mfa_provider.get_code(email)
|
||||
await page.fill(strings.selector("AUTH", "CODE_INPUT"), code)
|
||||
await page.click(strings.selector("AUTH", "SUBMIT_BUTTON"))
|
||||
await page.fill(app_strings.auth.selectors.code_input, code)
|
||||
await page.click(app_strings.auth.selectors.submit_button)
|
||||
|
||||
|
||||
async def logout(page: Page) -> None:
|
||||
logout_selector = strings.selector("AUTH", "LOGOUT_BUTTON")
|
||||
await page.click(logout_selector)
|
||||
await page.click(app_strings.auth.selectors.logout_button)
|
||||
|
||||
|
||||
async def ensure_persona(page: Page, persona: DemoPersona, mfa_provider: MfaCodeProvider, login_url: str | None = None) -> None:
|
||||
|
||||
@@ -1,88 +1,42 @@
|
||||
import contextlib
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
||||
from playwright.async_api import Page
|
||||
|
||||
from guide.app.core.config import AppSettings, BrowserHostConfig, HostKind
|
||||
from guide.app import errors
|
||||
from guide.app.browser.pool import BrowserPool
|
||||
|
||||
|
||||
class BrowserClient:
|
||||
"""Connector that yields a Playwright Page for either CDP or headless hosts."""
|
||||
"""Provides page access via a persistent browser pool.
|
||||
|
||||
settings: AppSettings
|
||||
This client uses the BrowserPool to efficiently manage connections.
|
||||
Instead of opening/closing browsers per request, the pool maintains
|
||||
long-lived connections and allocates pages on demand.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: AppSettings) -> None:
|
||||
self.settings = settings
|
||||
def __init__(self, pool: BrowserPool) -> None:
|
||||
"""Initialize with a browser pool.
|
||||
|
||||
def _resolve_host(self, host_id: str | None) -> BrowserHostConfig:
|
||||
resolved_id = host_id or self.settings.default_browser_host_id
|
||||
host = self.settings.browser_hosts.get(resolved_id)
|
||||
if not host:
|
||||
known = ", ".join(self.settings.browser_hosts.keys()) or "<none>"
|
||||
raise errors.ConfigError(f"Unknown browser host '{resolved_id}'. Known: {known}")
|
||||
return host
|
||||
Args:
|
||||
pool: The BrowserPool instance to use
|
||||
"""
|
||||
self.pool = pool
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def open_page(self, host_id: str | None = None) -> AsyncIterator[Page]:
|
||||
host = self._resolve_host(host_id)
|
||||
playwright = await async_playwright().start()
|
||||
browser: Browser | None = None
|
||||
context: BrowserContext | None = None
|
||||
try:
|
||||
if host.kind == HostKind.CDP:
|
||||
if not host.host or host.port is None:
|
||||
raise errors.ConfigError("CDP host requires host and port fields.")
|
||||
cdp_url = f"http://{host.host}:{host.port}"
|
||||
try:
|
||||
browser = await playwright.chromium.connect_over_cdp(cdp_url)
|
||||
except Exception as exc: # pragma: no cover - network dependent
|
||||
raise errors.BrowserConnectionError(
|
||||
f"Cannot connect to CDP endpoint {cdp_url}", details={"host_id": host.id}
|
||||
) from exc
|
||||
page = self._pick_raindrop_page(browser)
|
||||
if not page:
|
||||
raise errors.BrowserConnectionError(
|
||||
"No Raindrop page found in connected browser.", details={"host_id": host.id}
|
||||
)
|
||||
else:
|
||||
browser_type = _resolve_browser_type(playwright, host.browser)
|
||||
browser = await browser_type.launch(headless=True)
|
||||
context = await browser.new_context()
|
||||
page = await context.new_page()
|
||||
yield page
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
if context:
|
||||
await context.close()
|
||||
with contextlib.suppress(Exception):
|
||||
if browser:
|
||||
await browser.close()
|
||||
with contextlib.suppress(Exception):
|
||||
await playwright.stop()
|
||||
"""Get a page from the pool for the specified host.
|
||||
|
||||
def _pick_raindrop_page(self, browser: Browser) -> Page | None:
|
||||
target_substr = self.settings.raindrop_base_url
|
||||
pages: list[Page] = []
|
||||
for context in browser.contexts:
|
||||
pages.extend(context.pages)
|
||||
pages = pages or list(browser.contexts[0].pages) if browser.contexts else []
|
||||
return next(
|
||||
(
|
||||
page
|
||||
for page in reversed(pages)
|
||||
if target_substr in (page.url or "")
|
||||
),
|
||||
pages[-1] if pages else None,
|
||||
)
|
||||
The page is obtained from the pool's persistent browser connection.
|
||||
No browser startup/connection overhead occurs on each request.
|
||||
|
||||
Args:
|
||||
host_id: The host identifier, or None for the default host
|
||||
|
||||
Yields:
|
||||
A Playwright Page instance
|
||||
"""
|
||||
page = await self.pool.get_page(host_id)
|
||||
yield page
|
||||
|
||||
|
||||
def _resolve_browser_type(playwright: Playwright, browser: str | None):
|
||||
desired = (browser or "chromium").lower()
|
||||
if desired == "chromium":
|
||||
return playwright.chromium
|
||||
if desired == "firefox":
|
||||
return playwright.firefox
|
||||
if desired == "webkit":
|
||||
return playwright.webkit
|
||||
raise errors.ConfigError(f"Unsupported headless browser type '{browser}'")
|
||||
__all__ = ["BrowserClient"]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
@@ -8,6 +9,8 @@ from typing import ClassVar, TypeAlias, TypeGuard, cast
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_DIR = Path(__file__).resolve().parents[4] / "config"
|
||||
HOSTS_FILE = CONFIG_DIR / "hosts.yaml"
|
||||
PERSONAS_FILE = CONFIG_DIR / "personas.yaml"
|
||||
@@ -17,19 +20,25 @@ RecordList: TypeAlias = list[JsonRecord]
|
||||
|
||||
|
||||
def _coerce_mapping(mapping: Mapping[object, object]) -> dict[str, object]:
|
||||
"""Convert mapping keys to strings."""
|
||||
return {str(key): value for key, value in mapping.items()}
|
||||
|
||||
|
||||
def _is_object_mapping(value: object) -> TypeGuard[Mapping[object, object]]:
|
||||
"""Check if value is a Mapping."""
|
||||
return isinstance(value, Mapping)
|
||||
|
||||
|
||||
class HostKind(str, Enum):
|
||||
"""Browser host kind: CDP or headless."""
|
||||
|
||||
CDP = "cdp"
|
||||
HEADLESS = "headless"
|
||||
|
||||
|
||||
class BrowserHostConfig(BaseModel):
|
||||
"""Configuration for a browser host (CDP or headless)."""
|
||||
|
||||
id: str
|
||||
kind: HostKind
|
||||
host: str | None = None
|
||||
@@ -38,6 +47,8 @@ class BrowserHostConfig(BaseModel):
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
"""Configuration for a demo persona."""
|
||||
|
||||
id: str
|
||||
role: str
|
||||
email: str
|
||||
@@ -46,7 +57,20 @@ class PersonaConfig(BaseModel):
|
||||
|
||||
|
||||
class AppSettings(BaseSettings):
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(env_prefix="RAINDROP_DEMO_")
|
||||
"""Application settings loaded from YAML files + environment variables.
|
||||
|
||||
Configuration sources (in order):
|
||||
1. Environment variables with prefix RAINDROP_DEMO_
|
||||
2. YAML files: config/hosts.yaml, config/personas.yaml
|
||||
3. JSON overrides via RAINDROP_DEMO_BROWSER_HOSTS_JSON, RAINDROP_DEMO_PERSONAS_JSON
|
||||
|
||||
The JSON overrides take precedence over file-based config.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_prefix="RAINDROP_DEMO_",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
raindrop_base_url: str = "https://app.raindrop.com"
|
||||
raindrop_graphql_url: str = "https://app.raindrop.com/graphql"
|
||||
@@ -55,7 +79,8 @@ class AppSettings(BaseSettings):
|
||||
personas: dict[str, PersonaConfig] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def _load_hosts_file(path: Path) -> dict[str, BrowserHostConfig]:
|
||||
def _load_yaml_file(path: Path) -> dict[str, object]:
|
||||
"""Load YAML file, handling missing PyYAML gracefully."""
|
||||
if not path.exists():
|
||||
return {}
|
||||
|
||||
@@ -63,123 +88,118 @@ def _load_hosts_file(path: Path) -> dict[str, BrowserHostConfig]:
|
||||
import yaml
|
||||
except ModuleNotFoundError as exc:
|
||||
raise RuntimeError(
|
||||
"hosts.yaml found but PyYAML is not installed. Add 'pyyaml' to dependencies or remove hosts.yaml."
|
||||
f"{path.name} found but PyYAML is not installed. Add 'pyyaml' to dependencies."
|
||||
) from exc
|
||||
|
||||
data_raw: object = yaml.safe_load(path.read_text()) or {}
|
||||
hosts: dict[str, BrowserHostConfig] = {}
|
||||
for item in _normalize_records(data_raw, key_name="hosts"):
|
||||
host = BrowserHostConfig.model_validate(item)
|
||||
hosts[host.id] = host
|
||||
return hosts
|
||||
return cast(dict[str, object], yaml.safe_load(path.read_text()) or {})
|
||||
|
||||
|
||||
def _parse_json_hosts(value: str) -> dict[str, BrowserHostConfig]:
|
||||
decoded_raw: object = cast(object, json.loads(value))
|
||||
decoded = _ensure_record_list(decoded_raw, "RAINDROP_DEMO_BROWSER_HOSTS_JSON")
|
||||
hosts: dict[str, BrowserHostConfig] = {}
|
||||
for item in decoded:
|
||||
host = BrowserHostConfig.model_validate(item)
|
||||
hosts[host.id] = host
|
||||
return hosts
|
||||
|
||||
|
||||
def _parse_json_personas(value: str) -> dict[str, PersonaConfig]:
|
||||
decoded_raw: object = cast(object, json.loads(value))
|
||||
decoded = _ensure_record_list(decoded_raw, "RAINDROP_DEMO_PERSONAS_JSON")
|
||||
personas: dict[str, PersonaConfig] = {}
|
||||
for item in decoded:
|
||||
persona = PersonaConfig.model_validate(item)
|
||||
personas[persona.id] = persona
|
||||
return personas
|
||||
|
||||
|
||||
def _load_personas_file(path: Path) -> dict[str, PersonaConfig]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ModuleNotFoundError as exc:
|
||||
raise RuntimeError(
|
||||
"personas.yaml found but PyYAML is not installed. Add 'pyyaml' to dependencies or remove personas.yaml."
|
||||
) from exc
|
||||
|
||||
data_raw: object = yaml.safe_load(path.read_text()) or {}
|
||||
personas: dict[str, PersonaConfig] = {}
|
||||
for item in _normalize_records(data_raw, key_name="personas"):
|
||||
persona = PersonaConfig.model_validate(item)
|
||||
personas[persona.id] = persona
|
||||
return personas
|
||||
|
||||
|
||||
def _normalize_records(data_raw: object, key_name: str) -> RecordList:
|
||||
if isinstance(data_raw, Mapping):
|
||||
mapping_raw: dict[str, object] = _coerce_mapping(cast(Mapping[object, object], data_raw))
|
||||
content: object = mapping_raw.get(key_name, mapping_raw)
|
||||
def _normalize_host_records(data: object) -> RecordList:
|
||||
"""Normalize host records from YAML or list format."""
|
||||
if isinstance(data, Mapping):
|
||||
mapping = _coerce_mapping(cast(Mapping[object, object], data))
|
||||
content = mapping.get("hosts", mapping)
|
||||
else:
|
||||
content = data_raw
|
||||
content = data
|
||||
|
||||
records: RecordList = []
|
||||
if isinstance(content, Mapping):
|
||||
mapping_content: dict[str, object] = _coerce_mapping(cast(Mapping[object, object], content))
|
||||
records: RecordList = []
|
||||
mapping_content = _coerce_mapping(cast(Mapping[object, object], content))
|
||||
for key, value in mapping_content.items():
|
||||
mapping = _ensure_mapping(value)
|
||||
if "id" not in mapping:
|
||||
mapping["id"] = key
|
||||
records.append(mapping)
|
||||
return records
|
||||
|
||||
if isinstance(content, list):
|
||||
list_content: list[object] = cast(list[object], content)
|
||||
list_records: RecordList = []
|
||||
for item in list_content:
|
||||
record = _coerce_mapping(cast(Mapping[object, object], value)) if _is_object_mapping(value) else {}
|
||||
if "id" not in record:
|
||||
record["id"] = key
|
||||
records.append(record)
|
||||
elif isinstance(content, list):
|
||||
for item in cast(list[object], content):
|
||||
if _is_object_mapping(item):
|
||||
mapping_item = _coerce_mapping(item)
|
||||
list_records.append(dict(mapping_item))
|
||||
return list_records
|
||||
records.append(_coerce_mapping(cast(Mapping[object, object], item)))
|
||||
|
||||
return []
|
||||
return records
|
||||
|
||||
|
||||
def _ensure_mapping(value: object) -> JsonRecord:
|
||||
if isinstance(value, Mapping):
|
||||
return _coerce_mapping(cast(Mapping[object, object], value))
|
||||
return {}
|
||||
def _normalize_persona_records(data: object) -> RecordList:
|
||||
"""Normalize persona records from YAML or list format."""
|
||||
if isinstance(data, Mapping):
|
||||
mapping = _coerce_mapping(cast(Mapping[object, object], data))
|
||||
content = mapping.get("personas", mapping)
|
||||
else:
|
||||
content = data
|
||||
|
||||
|
||||
def _ensure_record_list(value: object, source_label: str) -> RecordList:
|
||||
if isinstance(value, list):
|
||||
list_content: list[object] = cast(list[object], value)
|
||||
list_records: RecordList = []
|
||||
for item in list_content:
|
||||
records: RecordList = []
|
||||
if isinstance(content, Mapping):
|
||||
mapping_content = _coerce_mapping(cast(Mapping[object, object], content))
|
||||
for key, value in mapping_content.items():
|
||||
record = _coerce_mapping(cast(Mapping[object, object], value)) if _is_object_mapping(value) else {}
|
||||
if "id" not in record:
|
||||
record["id"] = key
|
||||
records.append(record)
|
||||
elif isinstance(content, list):
|
||||
for item in cast(list[object], content):
|
||||
if _is_object_mapping(item):
|
||||
mapping_item = _coerce_mapping(item)
|
||||
list_records.append(dict(mapping_item))
|
||||
return list_records
|
||||
raise ValueError(f"{source_label} must be a JSON array of objects")
|
||||
records.append(_coerce_mapping(cast(Mapping[object, object], item)))
|
||||
|
||||
return records
|
||||
|
||||
|
||||
def load_settings() -> AppSettings:
|
||||
"""Load application settings from files and environment.
|
||||
|
||||
Configuration is loaded in this order (later overrides earlier):
|
||||
1. Default values in AppSettings model
|
||||
2. YAML files (config/hosts.yaml, config/personas.yaml)
|
||||
3. Environment variables (RAINDROP_DEMO_* prefix)
|
||||
4. JSON overrides via environment (RAINDROP_DEMO_BROWSER_HOSTS_JSON, etc.)
|
||||
"""
|
||||
settings = AppSettings()
|
||||
merged_hosts: dict[str, BrowserHostConfig] = {}
|
||||
merged_personas: dict[str, PersonaConfig] = {}
|
||||
|
||||
merged_hosts |= _load_hosts_file(HOSTS_FILE)
|
||||
merged_personas |= _load_personas_file(PERSONAS_FILE)
|
||||
# Load from YAML files
|
||||
hosts_data = _load_yaml_file(HOSTS_FILE)
|
||||
personas_data = _load_yaml_file(PERSONAS_FILE)
|
||||
|
||||
if env_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"):
|
||||
merged_hosts.update(_parse_json_hosts(env_json))
|
||||
hosts_dict: dict[str, BrowserHostConfig] = {}
|
||||
for record in _normalize_host_records(hosts_data):
|
||||
host = BrowserHostConfig.model_validate(record)
|
||||
hosts_dict[host.id] = host
|
||||
|
||||
if env_personas := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"):
|
||||
merged_personas.update(_parse_json_personas(env_personas))
|
||||
personas_dict: dict[str, PersonaConfig] = {}
|
||||
for record in _normalize_persona_records(personas_data):
|
||||
persona = PersonaConfig.model_validate(record)
|
||||
personas_dict[persona.id] = persona
|
||||
|
||||
if merged_hosts or merged_personas:
|
||||
# Load JSON overrides from environment
|
||||
if browser_hosts_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"):
|
||||
try:
|
||||
decoded_hosts: object = json.loads(browser_hosts_json)
|
||||
if not isinstance(decoded_hosts, list):
|
||||
raise ValueError("RAINDROP_DEMO_BROWSER_HOSTS_JSON must be a JSON array")
|
||||
for record in cast(list[object], decoded_hosts):
|
||||
if _is_object_mapping(record):
|
||||
host = BrowserHostConfig.model_validate(record)
|
||||
hosts_dict[host.id] = host
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
_logger.warning(f"Failed to parse RAINDROP_DEMO_BROWSER_HOSTS_JSON: {exc}")
|
||||
|
||||
if personas_json := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"):
|
||||
try:
|
||||
decoded_personas: object = json.loads(personas_json)
|
||||
if not isinstance(decoded_personas, list):
|
||||
raise ValueError("RAINDROP_DEMO_PERSONAS_JSON must be a JSON array")
|
||||
for record in cast(list[object], decoded_personas):
|
||||
if _is_object_mapping(record):
|
||||
persona = PersonaConfig.model_validate(record)
|
||||
personas_dict[persona.id] = persona
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
_logger.warning(f"Failed to parse RAINDROP_DEMO_PERSONAS_JSON: {exc}")
|
||||
|
||||
# Update settings with loaded configuration
|
||||
if hosts_dict or personas_dict:
|
||||
settings = settings.model_copy(
|
||||
update={
|
||||
"browser_hosts": merged_hosts or settings.browser_hosts,
|
||||
"personas": merged_personas or settings.personas,
|
||||
"browser_hosts": hosts_dict or settings.browser_hosts,
|
||||
"personas": personas_dict or settings.personas,
|
||||
}
|
||||
)
|
||||
|
||||
_logger.info(f"Loaded {len(settings.browser_hosts)} browser hosts, {len(settings.personas)} personas")
|
||||
return settings
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi.exception_handlers import http_exception_handler
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from guide.app.actions.registry import default_registry
|
||||
from guide.app.browser.client import BrowserClient
|
||||
from guide.app.browser.pool import BrowserPool
|
||||
from guide.app.core.config import AppSettings, load_settings
|
||||
from guide.app.core.logging import configure_logging
|
||||
from guide.app.api import router as api_router
|
||||
from guide.app import errors
|
||||
from guide.app.models.personas import PersonaStore
|
||||
from fastapi import Request
|
||||
from fastapi.exception_handlers import http_exception_handler
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -18,16 +19,30 @@ def create_app() -> FastAPI:
|
||||
settings = load_settings()
|
||||
persona_store = PersonaStore(settings)
|
||||
registry = default_registry(persona_store, settings.raindrop_base_url)
|
||||
browser_client = BrowserClient(settings)
|
||||
|
||||
# Create browser pool for efficient connection management
|
||||
browser_pool = BrowserPool(settings)
|
||||
browser_client = BrowserClient(browser_pool)
|
||||
|
||||
app = FastAPI(title="Raindrop Demo Automation", version="0.1.0")
|
||||
app.state.settings = settings
|
||||
app.state.action_registry = registry
|
||||
app.state.browser_client = browser_client
|
||||
app.state.browser_pool = browser_pool
|
||||
app.state.persona_store = persona_store
|
||||
|
||||
# Dependency overrides so FastAPI deps can pull settings without globals
|
||||
app.dependency_overrides = {AppSettings: lambda: settings}
|
||||
|
||||
# Startup/shutdown lifecycle for browser pool
|
||||
@app.on_event("startup")
|
||||
async def startup_browser_pool() -> None:
|
||||
await browser_pool.initialize()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_browser_pool() -> None:
|
||||
await browser_pool.close()
|
||||
|
||||
app.include_router(api_router)
|
||||
app.add_exception_handler(errors.GuideError, guide_exception_handler)
|
||||
app.add_exception_handler(Exception, general_exception_handler)
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
from guide.app.models.personas.models import DemoPersona
|
||||
from guide.app.models.types import JSONValue
|
||||
from guide.app.raindrop.graphql import GraphQLClient
|
||||
from guide.app.raindrop.types import GetIntakeRequestResponse, CreateIntakeRequestResponse, IntakeRequestData
|
||||
from guide.app.strings import graphql as gql_strings
|
||||
|
||||
|
||||
async def get_intake_request(
|
||||
client: GraphQLClient, persona: DemoPersona, request_id: str
|
||||
) -> dict[str, JSONValue]:
|
||||
) -> IntakeRequestData:
|
||||
"""Fetch an intake request by ID.
|
||||
|
||||
Args:
|
||||
client: GraphQL client
|
||||
persona: Persona making the request
|
||||
request_id: The intake request ID
|
||||
|
||||
Returns:
|
||||
IntakeRequestData: Type-safe response data
|
||||
"""
|
||||
variables: dict[str, JSONValue] = {"id": request_id}
|
||||
data = await client.execute(
|
||||
query=gql_strings.GET_INTAKE_REQUEST,
|
||||
@@ -14,13 +25,25 @@ async def get_intake_request(
|
||||
persona=persona,
|
||||
operation_name="GetIntakeRequest",
|
||||
)
|
||||
result = data.get("intakeRequest")
|
||||
return result if isinstance(result, dict) else {}
|
||||
response = GetIntakeRequestResponse.model_validate(data)
|
||||
if not response.intake_request:
|
||||
raise ValueError(f"No intake request found with id: {request_id}")
|
||||
return response.intake_request
|
||||
|
||||
|
||||
async def create_intake_request(
|
||||
client: GraphQLClient, persona: DemoPersona, payload: dict[str, JSONValue]
|
||||
) -> dict[str, JSONValue]:
|
||||
) -> IntakeRequestData:
|
||||
"""Create a new intake request.
|
||||
|
||||
Args:
|
||||
client: GraphQL client
|
||||
persona: Persona making the request
|
||||
payload: Intake request input data
|
||||
|
||||
Returns:
|
||||
IntakeRequestData: The created intake request
|
||||
"""
|
||||
variables: dict[str, JSONValue] = {"input": payload}
|
||||
data = await client.execute(
|
||||
query=gql_strings.CREATE_INTAKE_REQUEST,
|
||||
@@ -28,5 +51,7 @@ async def create_intake_request(
|
||||
persona=persona,
|
||||
operation_name="CreateIntakeRequest",
|
||||
)
|
||||
result = data.get("createIntakeRequest")
|
||||
return result if isinstance(result, dict) else {}
|
||||
response = CreateIntakeRequestResponse.model_validate(data)
|
||||
if not response.create_intake_request:
|
||||
raise ValueError("Failed to create intake request")
|
||||
return response.create_intake_request
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
from guide.app.models.personas.models import DemoPersona
|
||||
from guide.app.models.types import JSONValue
|
||||
from guide.app.raindrop.graphql import GraphQLClient
|
||||
from guide.app.raindrop.types import ListSuppliersResponse, AddSupplierResponse, SupplierData
|
||||
from guide.app.strings import graphql as gql_strings
|
||||
|
||||
|
||||
async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int = 10) -> list[dict[str, JSONValue]]:
|
||||
async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int = 10) -> list[SupplierData]:
|
||||
"""Fetch a list of suppliers.
|
||||
|
||||
Args:
|
||||
client: GraphQL client
|
||||
persona: Persona making the request
|
||||
limit: Maximum number of suppliers to return
|
||||
|
||||
Returns:
|
||||
list[SupplierData]: Type-safe list of suppliers
|
||||
"""
|
||||
variables: dict[str, JSONValue] = {"limit": limit}
|
||||
data = await client.execute(
|
||||
query=gql_strings.LIST_SUPPLIERS,
|
||||
@@ -12,19 +23,23 @@ async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int
|
||||
persona=persona,
|
||||
operation_name="ListSuppliers",
|
||||
)
|
||||
suppliers = data.get("suppliers")
|
||||
if not isinstance(suppliers, list):
|
||||
return []
|
||||
filtered: list[dict[str, JSONValue]] = []
|
||||
for item in suppliers:
|
||||
if isinstance(item, dict):
|
||||
filtered.append(item)
|
||||
return filtered
|
||||
response = ListSuppliersResponse.model_validate(data)
|
||||
return response.suppliers
|
||||
|
||||
|
||||
async def add_supplier(
|
||||
client: GraphQLClient, persona: DemoPersona, supplier: dict[str, JSONValue]
|
||||
) -> dict[str, JSONValue]:
|
||||
) -> SupplierData:
|
||||
"""Add a supplier to the sourcing event.
|
||||
|
||||
Args:
|
||||
client: GraphQL client
|
||||
persona: Persona making the request
|
||||
supplier: Supplier input data
|
||||
|
||||
Returns:
|
||||
SupplierData: The created supplier
|
||||
"""
|
||||
variables: dict[str, JSONValue] = {"input": supplier}
|
||||
data = await client.execute(
|
||||
query=gql_strings.ADD_SUPPLIER,
|
||||
@@ -32,5 +47,7 @@ async def add_supplier(
|
||||
persona=persona,
|
||||
operation_name="AddSupplier",
|
||||
)
|
||||
result = data.get("addSupplier")
|
||||
return result if isinstance(result, dict) else {}
|
||||
response = AddSupplierResponse.model_validate(data)
|
||||
if not response.add_supplier:
|
||||
raise ValueError("Failed to add supplier")
|
||||
return response.add_supplier
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
from typing import TypeAlias, cast
|
||||
|
||||
from guide.app.strings.graphql import (
|
||||
ADD_SUPPLIER,
|
||||
CREATE_INTAKE_REQUEST,
|
||||
GET_INTAKE_REQUEST,
|
||||
LIST_SUPPLIERS,
|
||||
)
|
||||
from guide.app.strings.demo_texts import DemoTexts, EventTexts, IntakeTexts, SupplierTexts
|
||||
from guide.app.strings.labels import AuthLabels, IntakeLabels, Labels, SourcingLabels
|
||||
from guide.app.strings.selectors import AuthSelectors, IntakeSelectors, NavigationSelectors, Selectors, SourcingSelectors
|
||||
|
||||
SelectorNamespace: TypeAlias = type[IntakeSelectors] | type[SourcingSelectors] | type[NavigationSelectors] | type[AuthSelectors]
|
||||
LabelNamespace: TypeAlias = type[IntakeLabels] | type[SourcingLabels] | type[AuthLabels]
|
||||
TextNamespace: TypeAlias = type[IntakeTexts] | type[SupplierTexts] | type[EventTexts]
|
||||
TextValue: TypeAlias = str | list[str]
|
||||
|
||||
_SELECTORS: dict[str, SelectorNamespace] = {
|
||||
"INTAKE": Selectors.INTAKE,
|
||||
"SOURCING": Selectors.SOURCING,
|
||||
"NAVIGATION": Selectors.NAVIGATION,
|
||||
"AUTH": Selectors.AUTH,
|
||||
}
|
||||
|
||||
_LABELS: dict[str, LabelNamespace] = {
|
||||
"INTAKE": Labels.INTAKE,
|
||||
"SOURCING": Labels.SOURCING,
|
||||
"AUTH": Labels.AUTH,
|
||||
}
|
||||
|
||||
_TEXTS: dict[str, TextNamespace] = {
|
||||
"INTAKE": DemoTexts.INTAKE,
|
||||
"SUPPLIERS": DemoTexts.SUPPLIERS,
|
||||
"EVENTS": DemoTexts.EVENTS,
|
||||
}
|
||||
|
||||
_GQL = {
|
||||
"GET_INTAKE_REQUEST": GET_INTAKE_REQUEST,
|
||||
"CREATE_INTAKE_REQUEST": CREATE_INTAKE_REQUEST,
|
||||
"LIST_SUPPLIERS": LIST_SUPPLIERS,
|
||||
"ADD_SUPPLIER": ADD_SUPPLIER,
|
||||
}
|
||||
|
||||
|
||||
class Strings:
|
||||
"""Unified accessor for selectors, labels, texts, and GraphQL queries."""
|
||||
|
||||
@staticmethod
|
||||
def selector(domain: str, name: str) -> str:
|
||||
ns = _SELECTORS.get(domain.upper())
|
||||
if ns is None:
|
||||
raise KeyError(f"Unknown selector domain '{domain}'")
|
||||
value = cast(str | None, getattr(ns, name, None))
|
||||
return _as_str(value, f"selector {domain}.{name}")
|
||||
|
||||
@staticmethod
|
||||
def label(domain: str, name: str) -> str:
|
||||
ns = _LABELS.get(domain.upper())
|
||||
if ns is None:
|
||||
raise KeyError(f"Unknown label domain '{domain}'")
|
||||
value = cast(str | None, getattr(ns, name, None))
|
||||
return _as_str(value, f"label {domain}.{name}")
|
||||
|
||||
@staticmethod
|
||||
def text(domain: str, name: str) -> TextValue:
|
||||
ns = _TEXTS.get(domain.upper())
|
||||
if ns is None:
|
||||
raise KeyError(f"Unknown text domain '{domain}'")
|
||||
value: object | None = getattr(ns, name, None)
|
||||
if value is None:
|
||||
raise KeyError(f"Unknown text {domain}.{name}")
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
value_list: list[object] = cast(list[object], value)
|
||||
str_values: list[str] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, str):
|
||||
raise TypeError(f"text {domain}.{name} must be a string or list of strings")
|
||||
str_values.append(item)
|
||||
return str_values
|
||||
raise TypeError(f"text {domain}.{name} must be a string or list of strings")
|
||||
|
||||
@staticmethod
|
||||
def gql(name: str) -> str:
|
||||
value = _GQL.get(name)
|
||||
return _as_str(value, f"graphql string {name}")
|
||||
|
||||
|
||||
def _as_str(value: object, label: str) -> str:
|
||||
if not isinstance(value, str): # pragma: no cover
|
||||
raise TypeError(f"{label} must be a string")
|
||||
return value
|
||||
|
||||
|
||||
strings = Strings()
|
||||
|
||||
__all__ = ["Strings", "strings"]
|
||||
Reference in New Issue
Block a user