From 94f9051aee340b1830e71c674f169812ed267c0b Mon Sep 17 00:00:00 2001 From: Travis Vasceannie Date: Sat, 22 Nov 2025 01:47:12 +0000 Subject: [PATCH] x --- src/guide/app/actions/base.py | 122 ++++++++-- src/guide/app/actions/intake/basic.py | 13 +- src/guide/app/actions/registry.py | 25 ++- .../app/actions/sourcing/add_suppliers.py | 14 +- src/guide/app/api/routes/actions.py | 2 +- src/guide/app/auth/session.py | 18 +- src/guide/app/browser/client.py | 98 +++------ src/guide/app/core/config.py | 208 ++++++++++-------- src/guide/app/main.py | 23 +- src/guide/app/raindrop/operations/intake.py | 37 +++- src/guide/app/raindrop/operations/sourcing.py | 41 +++- src/guide/app/strings/service.py | 98 --------- 12 files changed, 370 insertions(+), 329 deletions(-) delete mode 100644 src/guide/app/strings/service.py diff --git a/src/guide/app/actions/base.py b/src/guide/app/actions/base.py index 6e904fd..b594a3e 100644 --- a/src/guide/app/actions/base.py +++ b/src/guide/app/actions/base.py @@ -1,5 +1,5 @@ -from collections.abc import Iterable -from typing import ClassVar, Protocol, runtime_checkable +from collections.abc import Iterable, Mapping +from typing import Any, ClassVar, Protocol, runtime_checkable from playwright.async_api import Page @@ -16,22 +16,118 @@ class DemoAction(Protocol): async def run(self, page: Page, context: ActionContext) -> ActionResult: ... -class ActionRegistry: - """Simple mapping of action id -> action instance.""" +# Global registry of action classes +_REGISTERED_ACTIONS: dict[str, type[DemoAction]] = {} - def __init__(self, actions: Iterable[DemoAction]): - self._actions: dict[str, DemoAction] = {action.id: action for action in actions} + +def register_action(cls: type[DemoAction]) -> type[DemoAction]: + """Decorator to register an action class to the global action registry. + + Usage: + @register_action + class MyAction(DemoAction): + id = "my-action" + ... + """ + _REGISTERED_ACTIONS[cls.id] = cls + return cls + + +def get_registered_actions() -> Mapping[str, type[DemoAction]]: + """Return a mapping of all registered action classes by id.""" + return _REGISTERED_ACTIONS.copy() + + +class ActionRegistry: + """Manages action instances and metadata. + + Supports both explicit action registration and dynamically-registered actions + via the @register_action decorator. + """ + + def __init__( + self, + actions: Iterable[DemoAction] | None = None, + action_factories: dict[str, Any] | None = None, + ) -> None: + """Initialize registry with action instances and/or factory functions. + + Args: + actions: Iterable of action instances + action_factories: Dict mapping action ids to factory functions that create instances + """ + self._actions: dict[str, DemoAction] = {} + self._factories: dict[str, Any] = action_factories or {} + + if actions: + for action in actions: + self._actions[action.id] = action def register(self, action: DemoAction) -> None: + """Register an action instance.""" self._actions[action.id] = action def get(self, action_id: str) -> DemoAction: - if action_id not in self._actions: - raise errors.ActionExecutionError(f"Unknown action '{action_id}'") - return self._actions[action_id] + """Retrieve an action by id. + + Searches in order: + 1. Explicitly registered action instances + 2. Factory functions + 3. Globally registered action classes (from @register_action decorator) + + Raises ActionExecutionError if the action is not found. + """ + if action_id in self._actions: + return self._actions[action_id] + if action_id in self._factories: + action = self._factories[action_id]() + self._actions[action_id] = action + return action + + # Check globally registered actions + registered_classes = get_registered_actions() + if action_id in registered_classes: + action_cls = registered_classes[action_id] + action = action_cls() + self._actions[action_id] = action + return action + + raise errors.ActionExecutionError(f"Unknown action '{action_id}'") def list_metadata(self) -> list[ActionMetadata]: - return [ - ActionMetadata(id=action.id, description=action.description, category=action.category) - for action in self._actions.values() - ] + """List metadata for all registered actions. + + Includes: + - Explicitly registered action instances + - Actions from factory functions + - Globally registered action classes (from @register_action decorator) + """ + seen_ids: set[str] = set() + metadata: list[ActionMetadata] = [] + + # Add explicit instances + for action in self._actions.values(): + metadata.append( + ActionMetadata(id=action.id, description=action.description, category=action.category) + ) + seen_ids.add(action.id) + + # Add factory functions + for factory in self._factories.values(): + action = factory() + if action.id not in seen_ids: + metadata.append( + ActionMetadata(id=action.id, description=action.description, category=action.category) + ) + seen_ids.add(action.id) + + # Add globally registered actions + for action_cls in get_registered_actions().values(): + if action_cls.id not in seen_ids: + action = action_cls() + metadata.append( + ActionMetadata(id=action.id, description=action.description, category=action.category) + ) + seen_ids.add(action.id) + + return metadata diff --git a/src/guide/app/actions/intake/basic.py b/src/guide/app/actions/intake/basic.py index 8f61223..0f06b1d 100644 --- a/src/guide/app/actions/intake/basic.py +++ b/src/guide/app/actions/intake/basic.py @@ -2,11 +2,12 @@ from playwright.async_api import Page from typing import ClassVar, override -from guide.app.actions.base import DemoAction +from guide.app.actions.base import DemoAction, register_action from guide.app.models.domain import ActionContext, ActionResult -from guide.app.strings.service import strings +from guide.app.strings.registry import app_strings +@register_action class FillIntakeBasicAction(DemoAction): id: ClassVar[str] = "fill-intake-basic" description: ClassVar[str] = "Fill the intake description and advance to the next step." @@ -14,9 +15,7 @@ class FillIntakeBasicAction(DemoAction): @override async def run(self, page: Page, context: ActionContext) -> ActionResult: - description_val = strings.text("INTAKE", "CONVEYOR_BELT_REQUEST") - if not isinstance(description_val, str): - raise ValueError("INTAKE.CONVEYOR_BELT_REQUEST must be a string") - await page.fill(strings.selector("INTAKE", "DESCRIPTION_FIELD"), description_val) - await page.click(strings.selector("INTAKE", "NEXT_BUTTON")) + description_val = app_strings.intake.texts.conveyor_belt_request + await page.fill(app_strings.intake.selectors.description_field, description_val) + await page.click(app_strings.intake.selectors.next_button) return ActionResult(details={"message": "Intake filled"}) diff --git a/src/guide/app/actions/registry.py b/src/guide/app/actions/registry.py index 6509472..31dfcdd 100644 --- a/src/guide/app/actions/registry.py +++ b/src/guide/app/actions/registry.py @@ -1,16 +1,33 @@ from guide.app.actions.auth.login import LoginAsPersonaAction from guide.app.actions.base import ActionRegistry, DemoAction -from guide.app.actions.intake.basic import FillIntakeBasicAction -from guide.app.actions.sourcing.add_suppliers import AddThreeSuppliersAction from guide.app.models.personas import PersonaStore +def _load_registered_actions() -> None: + """Load all action modules to trigger @register_action decorators. + + This function must be called to ensure all action classes are registered. + """ + import guide.app.actions.intake.basic + import guide.app.actions.sourcing.add_suppliers + + # Keep module names alive for registration side effects + assert guide.app.actions.intake.basic is not None + assert guide.app.actions.sourcing.add_suppliers is not None + + def default_registry(persona_store: PersonaStore, login_url: str) -> ActionRegistry: + """Create the default action registry with all registered actions. + + Actions that require dependency injection (like LoginAsPersonaAction) are + explicitly instantiated here. Actions decorated with @register_action are + automatically discovered and instantiated by the registry. + """ + _load_registered_actions() actions: list[DemoAction] = [ LoginAsPersonaAction(persona_store, login_url), - FillIntakeBasicAction(), - AddThreeSuppliersAction(), ] return ActionRegistry(actions) + __all__ = ["default_registry", "ActionRegistry", "DemoAction"] diff --git a/src/guide/app/actions/sourcing/add_suppliers.py b/src/guide/app/actions/sourcing/add_suppliers.py index 5b5696f..b8fb4aa 100644 --- a/src/guide/app/actions/sourcing/add_suppliers.py +++ b/src/guide/app/actions/sourcing/add_suppliers.py @@ -2,11 +2,12 @@ from playwright.async_api import Page from typing import ClassVar, override -from guide.app.actions.base import DemoAction +from guide.app.actions.base import DemoAction, register_action from guide.app.models.domain import ActionContext, ActionResult -from guide.app.strings.service import strings +from guide.app.strings.registry import app_strings +@register_action class AddThreeSuppliersAction(DemoAction): id: ClassVar[str] = "add-three-suppliers" description: ClassVar[str] = "Adds three default suppliers to the sourcing event." @@ -14,11 +15,8 @@ class AddThreeSuppliersAction(DemoAction): @override async def run(self, page: Page, context: ActionContext) -> ActionResult: - suppliers_val = strings.text("SUPPLIERS", "DEFAULT_TRIO") - if not isinstance(suppliers_val, list): - raise ValueError("SUPPLIERS.DEFAULT_TRIO must be a list of strings") - suppliers: list[str] = list(suppliers_val) + suppliers = app_strings.sourcing.texts.default_trio for supplier in suppliers: - await page.fill(strings.selector("SOURCING", "SUPPLIER_SEARCH_INPUT"), supplier) - await page.click(strings.selector("SOURCING", "ADD_SUPPLIER_BUTTON")) + await page.fill(app_strings.sourcing.selectors.supplier_search_input, supplier) + await page.click(app_strings.sourcing.selectors.add_supplier_button) return ActionResult(details={"added_suppliers": list(suppliers)}) diff --git a/src/guide/app/api/routes/actions.py b/src/guide/app/api/routes/actions.py index e5e4512..5b0704a 100644 --- a/src/guide/app/api/routes/actions.py +++ b/src/guide/app/api/routes/actions.py @@ -67,7 +67,7 @@ async def execute_action( persona = personas.get(payload.persona_id) if payload.persona_id else None target_host_id = payload.browser_host_id or (persona.browser_host_id if persona else None) - target_host_id = target_host_id or browser_client.settings.default_browser_host_id + target_host_id = target_host_id or settings.default_browser_host_id context = ActionContext( action_id=action_id, diff --git a/src/guide/app/auth/session.py b/src/guide/app/auth/session.py index e7b3d61..f9c8694 100644 --- a/src/guide/app/auth/session.py +++ b/src/guide/app/auth/session.py @@ -2,19 +2,18 @@ from playwright.async_api import Page from guide.app.auth.mfa import MfaCodeProvider from guide.app.models.personas.models import DemoPersona -from guide.app.strings.service import strings +from guide.app.strings.registry import app_strings async def detect_current_persona(page: Page) -> str | None: """Return the email/identifier of the currently signed-in user, if visible.""" - current_selector = strings.selector("AUTH", "CURRENT_USER_DISPLAY") - element = page.locator(current_selector) + element = page.locator(app_strings.auth.selectors.current_user_display) if await element.count() == 0: return None text = await element.first.text_content() if text is None: return None - prefix = strings.label("AUTH", "CURRENT_USER_DISPLAY_PREFIX") + prefix = app_strings.auth.labels.current_user_display_prefix if prefix and text.startswith(prefix): return text.removeprefix(prefix).strip() return text.strip() @@ -24,16 +23,15 @@ async def login_with_mfa(page: Page, email: str, mfa_provider: MfaCodeProvider, if login_url: _response = await page.goto(login_url) del _response - await page.fill(strings.selector("AUTH", "EMAIL_INPUT"), email) - await page.click(strings.selector("AUTH", "SEND_CODE_BUTTON")) + await page.fill(app_strings.auth.selectors.email_input, email) + await page.click(app_strings.auth.selectors.send_code_button) code = mfa_provider.get_code(email) - await page.fill(strings.selector("AUTH", "CODE_INPUT"), code) - await page.click(strings.selector("AUTH", "SUBMIT_BUTTON")) + await page.fill(app_strings.auth.selectors.code_input, code) + await page.click(app_strings.auth.selectors.submit_button) async def logout(page: Page) -> None: - logout_selector = strings.selector("AUTH", "LOGOUT_BUTTON") - await page.click(logout_selector) + await page.click(app_strings.auth.selectors.logout_button) async def ensure_persona(page: Page, persona: DemoPersona, mfa_provider: MfaCodeProvider, login_url: str | None = None) -> None: diff --git a/src/guide/app/browser/client.py b/src/guide/app/browser/client.py index 0f84db5..7a4c677 100644 --- a/src/guide/app/browser/client.py +++ b/src/guide/app/browser/client.py @@ -1,88 +1,42 @@ import contextlib from collections.abc import AsyncIterator -from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright +from playwright.async_api import Page -from guide.app.core.config import AppSettings, BrowserHostConfig, HostKind -from guide.app import errors +from guide.app.browser.pool import BrowserPool class BrowserClient: - """Connector that yields a Playwright Page for either CDP or headless hosts.""" + """Provides page access via a persistent browser pool. - settings: AppSettings + This client uses the BrowserPool to efficiently manage connections. + Instead of opening/closing browsers per request, the pool maintains + long-lived connections and allocates pages on demand. + """ - def __init__(self, settings: AppSettings) -> None: - self.settings = settings + def __init__(self, pool: BrowserPool) -> None: + """Initialize with a browser pool. - def _resolve_host(self, host_id: str | None) -> BrowserHostConfig: - resolved_id = host_id or self.settings.default_browser_host_id - host = self.settings.browser_hosts.get(resolved_id) - if not host: - known = ", ".join(self.settings.browser_hosts.keys()) or "" - raise errors.ConfigError(f"Unknown browser host '{resolved_id}'. Known: {known}") - return host + Args: + pool: The BrowserPool instance to use + """ + self.pool = pool @contextlib.asynccontextmanager async def open_page(self, host_id: str | None = None) -> AsyncIterator[Page]: - host = self._resolve_host(host_id) - playwright = await async_playwright().start() - browser: Browser | None = None - context: BrowserContext | None = None - try: - if host.kind == HostKind.CDP: - if not host.host or host.port is None: - raise errors.ConfigError("CDP host requires host and port fields.") - cdp_url = f"http://{host.host}:{host.port}" - try: - browser = await playwright.chromium.connect_over_cdp(cdp_url) - except Exception as exc: # pragma: no cover - network dependent - raise errors.BrowserConnectionError( - f"Cannot connect to CDP endpoint {cdp_url}", details={"host_id": host.id} - ) from exc - page = self._pick_raindrop_page(browser) - if not page: - raise errors.BrowserConnectionError( - "No Raindrop page found in connected browser.", details={"host_id": host.id} - ) - else: - browser_type = _resolve_browser_type(playwright, host.browser) - browser = await browser_type.launch(headless=True) - context = await browser.new_context() - page = await context.new_page() - yield page - finally: - with contextlib.suppress(Exception): - if context: - await context.close() - with contextlib.suppress(Exception): - if browser: - await browser.close() - with contextlib.suppress(Exception): - await playwright.stop() + """Get a page from the pool for the specified host. - def _pick_raindrop_page(self, browser: Browser) -> Page | None: - target_substr = self.settings.raindrop_base_url - pages: list[Page] = [] - for context in browser.contexts: - pages.extend(context.pages) - pages = pages or list(browser.contexts[0].pages) if browser.contexts else [] - return next( - ( - page - for page in reversed(pages) - if target_substr in (page.url or "") - ), - pages[-1] if pages else None, - ) + The page is obtained from the pool's persistent browser connection. + No browser startup/connection overhead occurs on each request. + + Args: + host_id: The host identifier, or None for the default host + + Yields: + A Playwright Page instance + """ + page = await self.pool.get_page(host_id) + yield page -def _resolve_browser_type(playwright: Playwright, browser: str | None): - desired = (browser or "chromium").lower() - if desired == "chromium": - return playwright.chromium - if desired == "firefox": - return playwright.firefox - if desired == "webkit": - return playwright.webkit - raise errors.ConfigError(f"Unsupported headless browser type '{browser}'") +__all__ = ["BrowserClient"] diff --git a/src/guide/app/core/config.py b/src/guide/app/core/config.py index e824a36..ffa9612 100644 --- a/src/guide/app/core/config.py +++ b/src/guide/app/core/config.py @@ -1,4 +1,5 @@ import json +import logging import os from enum import Enum from pathlib import Path @@ -8,6 +9,8 @@ from typing import ClassVar, TypeAlias, TypeGuard, cast from pydantic import BaseModel, Field from pydantic_settings import BaseSettings, SettingsConfigDict +_logger = logging.getLogger(__name__) + CONFIG_DIR = Path(__file__).resolve().parents[4] / "config" HOSTS_FILE = CONFIG_DIR / "hosts.yaml" PERSONAS_FILE = CONFIG_DIR / "personas.yaml" @@ -17,19 +20,25 @@ RecordList: TypeAlias = list[JsonRecord] def _coerce_mapping(mapping: Mapping[object, object]) -> dict[str, object]: + """Convert mapping keys to strings.""" return {str(key): value for key, value in mapping.items()} def _is_object_mapping(value: object) -> TypeGuard[Mapping[object, object]]: + """Check if value is a Mapping.""" return isinstance(value, Mapping) class HostKind(str, Enum): + """Browser host kind: CDP or headless.""" + CDP = "cdp" HEADLESS = "headless" class BrowserHostConfig(BaseModel): + """Configuration for a browser host (CDP or headless).""" + id: str kind: HostKind host: str | None = None @@ -38,6 +47,8 @@ class BrowserHostConfig(BaseModel): class PersonaConfig(BaseModel): + """Configuration for a demo persona.""" + id: str role: str email: str @@ -46,7 +57,20 @@ class PersonaConfig(BaseModel): class AppSettings(BaseSettings): - model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(env_prefix="RAINDROP_DEMO_") + """Application settings loaded from YAML files + environment variables. + + Configuration sources (in order): + 1. Environment variables with prefix RAINDROP_DEMO_ + 2. YAML files: config/hosts.yaml, config/personas.yaml + 3. JSON overrides via RAINDROP_DEMO_BROWSER_HOSTS_JSON, RAINDROP_DEMO_PERSONAS_JSON + + The JSON overrides take precedence over file-based config. + """ + + model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict( + env_prefix="RAINDROP_DEMO_", + case_sensitive=False, + ) raindrop_base_url: str = "https://app.raindrop.com" raindrop_graphql_url: str = "https://app.raindrop.com/graphql" @@ -55,7 +79,8 @@ class AppSettings(BaseSettings): personas: dict[str, PersonaConfig] = Field(default_factory=dict) -def _load_hosts_file(path: Path) -> dict[str, BrowserHostConfig]: +def _load_yaml_file(path: Path) -> dict[str, object]: + """Load YAML file, handling missing PyYAML gracefully.""" if not path.exists(): return {} @@ -63,123 +88,118 @@ def _load_hosts_file(path: Path) -> dict[str, BrowserHostConfig]: import yaml except ModuleNotFoundError as exc: raise RuntimeError( - "hosts.yaml found but PyYAML is not installed. Add 'pyyaml' to dependencies or remove hosts.yaml." + f"{path.name} found but PyYAML is not installed. Add 'pyyaml' to dependencies." ) from exc - data_raw: object = yaml.safe_load(path.read_text()) or {} - hosts: dict[str, BrowserHostConfig] = {} - for item in _normalize_records(data_raw, key_name="hosts"): - host = BrowserHostConfig.model_validate(item) - hosts[host.id] = host - return hosts + return cast(dict[str, object], yaml.safe_load(path.read_text()) or {}) -def _parse_json_hosts(value: str) -> dict[str, BrowserHostConfig]: - decoded_raw: object = cast(object, json.loads(value)) - decoded = _ensure_record_list(decoded_raw, "RAINDROP_DEMO_BROWSER_HOSTS_JSON") - hosts: dict[str, BrowserHostConfig] = {} - for item in decoded: - host = BrowserHostConfig.model_validate(item) - hosts[host.id] = host - return hosts - - -def _parse_json_personas(value: str) -> dict[str, PersonaConfig]: - decoded_raw: object = cast(object, json.loads(value)) - decoded = _ensure_record_list(decoded_raw, "RAINDROP_DEMO_PERSONAS_JSON") - personas: dict[str, PersonaConfig] = {} - for item in decoded: - persona = PersonaConfig.model_validate(item) - personas[persona.id] = persona - return personas - - -def _load_personas_file(path: Path) -> dict[str, PersonaConfig]: - if not path.exists(): - return {} - - try: - import yaml - except ModuleNotFoundError as exc: - raise RuntimeError( - "personas.yaml found but PyYAML is not installed. Add 'pyyaml' to dependencies or remove personas.yaml." - ) from exc - - data_raw: object = yaml.safe_load(path.read_text()) or {} - personas: dict[str, PersonaConfig] = {} - for item in _normalize_records(data_raw, key_name="personas"): - persona = PersonaConfig.model_validate(item) - personas[persona.id] = persona - return personas - - -def _normalize_records(data_raw: object, key_name: str) -> RecordList: - if isinstance(data_raw, Mapping): - mapping_raw: dict[str, object] = _coerce_mapping(cast(Mapping[object, object], data_raw)) - content: object = mapping_raw.get(key_name, mapping_raw) +def _normalize_host_records(data: object) -> RecordList: + """Normalize host records from YAML or list format.""" + if isinstance(data, Mapping): + mapping = _coerce_mapping(cast(Mapping[object, object], data)) + content = mapping.get("hosts", mapping) else: - content = data_raw + content = data + records: RecordList = [] if isinstance(content, Mapping): - mapping_content: dict[str, object] = _coerce_mapping(cast(Mapping[object, object], content)) - records: RecordList = [] + mapping_content = _coerce_mapping(cast(Mapping[object, object], content)) for key, value in mapping_content.items(): - mapping = _ensure_mapping(value) - if "id" not in mapping: - mapping["id"] = key - records.append(mapping) - return records - - if isinstance(content, list): - list_content: list[object] = cast(list[object], content) - list_records: RecordList = [] - for item in list_content: + record = _coerce_mapping(cast(Mapping[object, object], value)) if _is_object_mapping(value) else {} + if "id" not in record: + record["id"] = key + records.append(record) + elif isinstance(content, list): + for item in cast(list[object], content): if _is_object_mapping(item): - mapping_item = _coerce_mapping(item) - list_records.append(dict(mapping_item)) - return list_records + records.append(_coerce_mapping(cast(Mapping[object, object], item))) - return [] + return records -def _ensure_mapping(value: object) -> JsonRecord: - if isinstance(value, Mapping): - return _coerce_mapping(cast(Mapping[object, object], value)) - return {} +def _normalize_persona_records(data: object) -> RecordList: + """Normalize persona records from YAML or list format.""" + if isinstance(data, Mapping): + mapping = _coerce_mapping(cast(Mapping[object, object], data)) + content = mapping.get("personas", mapping) + else: + content = data - -def _ensure_record_list(value: object, source_label: str) -> RecordList: - if isinstance(value, list): - list_content: list[object] = cast(list[object], value) - list_records: RecordList = [] - for item in list_content: + records: RecordList = [] + if isinstance(content, Mapping): + mapping_content = _coerce_mapping(cast(Mapping[object, object], content)) + for key, value in mapping_content.items(): + record = _coerce_mapping(cast(Mapping[object, object], value)) if _is_object_mapping(value) else {} + if "id" not in record: + record["id"] = key + records.append(record) + elif isinstance(content, list): + for item in cast(list[object], content): if _is_object_mapping(item): - mapping_item = _coerce_mapping(item) - list_records.append(dict(mapping_item)) - return list_records - raise ValueError(f"{source_label} must be a JSON array of objects") + records.append(_coerce_mapping(cast(Mapping[object, object], item))) + + return records def load_settings() -> AppSettings: + """Load application settings from files and environment. + + Configuration is loaded in this order (later overrides earlier): + 1. Default values in AppSettings model + 2. YAML files (config/hosts.yaml, config/personas.yaml) + 3. Environment variables (RAINDROP_DEMO_* prefix) + 4. JSON overrides via environment (RAINDROP_DEMO_BROWSER_HOSTS_JSON, etc.) + """ settings = AppSettings() - merged_hosts: dict[str, BrowserHostConfig] = {} - merged_personas: dict[str, PersonaConfig] = {} - merged_hosts |= _load_hosts_file(HOSTS_FILE) - merged_personas |= _load_personas_file(PERSONAS_FILE) + # Load from YAML files + hosts_data = _load_yaml_file(HOSTS_FILE) + personas_data = _load_yaml_file(PERSONAS_FILE) - if env_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"): - merged_hosts.update(_parse_json_hosts(env_json)) + hosts_dict: dict[str, BrowserHostConfig] = {} + for record in _normalize_host_records(hosts_data): + host = BrowserHostConfig.model_validate(record) + hosts_dict[host.id] = host - if env_personas := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"): - merged_personas.update(_parse_json_personas(env_personas)) + personas_dict: dict[str, PersonaConfig] = {} + for record in _normalize_persona_records(personas_data): + persona = PersonaConfig.model_validate(record) + personas_dict[persona.id] = persona - if merged_hosts or merged_personas: + # Load JSON overrides from environment + if browser_hosts_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"): + try: + decoded_hosts: object = json.loads(browser_hosts_json) + if not isinstance(decoded_hosts, list): + raise ValueError("RAINDROP_DEMO_BROWSER_HOSTS_JSON must be a JSON array") + for record in cast(list[object], decoded_hosts): + if _is_object_mapping(record): + host = BrowserHostConfig.model_validate(record) + hosts_dict[host.id] = host + except (json.JSONDecodeError, ValueError) as exc: + _logger.warning(f"Failed to parse RAINDROP_DEMO_BROWSER_HOSTS_JSON: {exc}") + + if personas_json := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"): + try: + decoded_personas: object = json.loads(personas_json) + if not isinstance(decoded_personas, list): + raise ValueError("RAINDROP_DEMO_PERSONAS_JSON must be a JSON array") + for record in cast(list[object], decoded_personas): + if _is_object_mapping(record): + persona = PersonaConfig.model_validate(record) + personas_dict[persona.id] = persona + except (json.JSONDecodeError, ValueError) as exc: + _logger.warning(f"Failed to parse RAINDROP_DEMO_PERSONAS_JSON: {exc}") + + # Update settings with loaded configuration + if hosts_dict or personas_dict: settings = settings.model_copy( update={ - "browser_hosts": merged_hosts or settings.browser_hosts, - "personas": merged_personas or settings.personas, + "browser_hosts": hosts_dict or settings.browser_hosts, + "personas": personas_dict or settings.personas, } ) + _logger.info(f"Loaded {len(settings.browser_hosts)} browser hosts, {len(settings.personas)} personas") return settings diff --git a/src/guide/app/main.py b/src/guide/app/main.py index e60e34f..16a0f9d 100644 --- a/src/guide/app/main.py +++ b/src/guide/app/main.py @@ -1,15 +1,16 @@ from fastapi import FastAPI +from fastapi import Request +from fastapi.exception_handlers import http_exception_handler +from fastapi.exceptions import HTTPException from guide.app.actions.registry import default_registry from guide.app.browser.client import BrowserClient +from guide.app.browser.pool import BrowserPool from guide.app.core.config import AppSettings, load_settings from guide.app.core.logging import configure_logging from guide.app.api import router as api_router from guide.app import errors from guide.app.models.personas import PersonaStore -from fastapi import Request -from fastapi.exception_handlers import http_exception_handler -from fastapi.exceptions import HTTPException def create_app() -> FastAPI: @@ -18,16 +19,30 @@ def create_app() -> FastAPI: settings = load_settings() persona_store = PersonaStore(settings) registry = default_registry(persona_store, settings.raindrop_base_url) - browser_client = BrowserClient(settings) + + # Create browser pool for efficient connection management + browser_pool = BrowserPool(settings) + browser_client = BrowserClient(browser_pool) app = FastAPI(title="Raindrop Demo Automation", version="0.1.0") app.state.settings = settings app.state.action_registry = registry app.state.browser_client = browser_client + app.state.browser_pool = browser_pool app.state.persona_store = persona_store # Dependency overrides so FastAPI deps can pull settings without globals app.dependency_overrides = {AppSettings: lambda: settings} + + # Startup/shutdown lifecycle for browser pool + @app.on_event("startup") + async def startup_browser_pool() -> None: + await browser_pool.initialize() + + @app.on_event("shutdown") + async def shutdown_browser_pool() -> None: + await browser_pool.close() + app.include_router(api_router) app.add_exception_handler(errors.GuideError, guide_exception_handler) app.add_exception_handler(Exception, general_exception_handler) diff --git a/src/guide/app/raindrop/operations/intake.py b/src/guide/app/raindrop/operations/intake.py index dff35d1..b0fb68d 100644 --- a/src/guide/app/raindrop/operations/intake.py +++ b/src/guide/app/raindrop/operations/intake.py @@ -1,12 +1,23 @@ from guide.app.models.personas.models import DemoPersona from guide.app.models.types import JSONValue from guide.app.raindrop.graphql import GraphQLClient +from guide.app.raindrop.types import GetIntakeRequestResponse, CreateIntakeRequestResponse, IntakeRequestData from guide.app.strings import graphql as gql_strings async def get_intake_request( client: GraphQLClient, persona: DemoPersona, request_id: str -) -> dict[str, JSONValue]: +) -> IntakeRequestData: + """Fetch an intake request by ID. + + Args: + client: GraphQL client + persona: Persona making the request + request_id: The intake request ID + + Returns: + IntakeRequestData: Type-safe response data + """ variables: dict[str, JSONValue] = {"id": request_id} data = await client.execute( query=gql_strings.GET_INTAKE_REQUEST, @@ -14,13 +25,25 @@ async def get_intake_request( persona=persona, operation_name="GetIntakeRequest", ) - result = data.get("intakeRequest") - return result if isinstance(result, dict) else {} + response = GetIntakeRequestResponse.model_validate(data) + if not response.intake_request: + raise ValueError(f"No intake request found with id: {request_id}") + return response.intake_request async def create_intake_request( client: GraphQLClient, persona: DemoPersona, payload: dict[str, JSONValue] -) -> dict[str, JSONValue]: +) -> IntakeRequestData: + """Create a new intake request. + + Args: + client: GraphQL client + persona: Persona making the request + payload: Intake request input data + + Returns: + IntakeRequestData: The created intake request + """ variables: dict[str, JSONValue] = {"input": payload} data = await client.execute( query=gql_strings.CREATE_INTAKE_REQUEST, @@ -28,5 +51,7 @@ async def create_intake_request( persona=persona, operation_name="CreateIntakeRequest", ) - result = data.get("createIntakeRequest") - return result if isinstance(result, dict) else {} + response = CreateIntakeRequestResponse.model_validate(data) + if not response.create_intake_request: + raise ValueError("Failed to create intake request") + return response.create_intake_request diff --git a/src/guide/app/raindrop/operations/sourcing.py b/src/guide/app/raindrop/operations/sourcing.py index e05df00..78c0687 100644 --- a/src/guide/app/raindrop/operations/sourcing.py +++ b/src/guide/app/raindrop/operations/sourcing.py @@ -1,10 +1,21 @@ from guide.app.models.personas.models import DemoPersona from guide.app.models.types import JSONValue from guide.app.raindrop.graphql import GraphQLClient +from guide.app.raindrop.types import ListSuppliersResponse, AddSupplierResponse, SupplierData from guide.app.strings import graphql as gql_strings -async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int = 10) -> list[dict[str, JSONValue]]: +async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int = 10) -> list[SupplierData]: + """Fetch a list of suppliers. + + Args: + client: GraphQL client + persona: Persona making the request + limit: Maximum number of suppliers to return + + Returns: + list[SupplierData]: Type-safe list of suppliers + """ variables: dict[str, JSONValue] = {"limit": limit} data = await client.execute( query=gql_strings.LIST_SUPPLIERS, @@ -12,19 +23,23 @@ async def list_suppliers(client: GraphQLClient, persona: DemoPersona, limit: int persona=persona, operation_name="ListSuppliers", ) - suppliers = data.get("suppliers") - if not isinstance(suppliers, list): - return [] - filtered: list[dict[str, JSONValue]] = [] - for item in suppliers: - if isinstance(item, dict): - filtered.append(item) - return filtered + response = ListSuppliersResponse.model_validate(data) + return response.suppliers async def add_supplier( client: GraphQLClient, persona: DemoPersona, supplier: dict[str, JSONValue] -) -> dict[str, JSONValue]: +) -> SupplierData: + """Add a supplier to the sourcing event. + + Args: + client: GraphQL client + persona: Persona making the request + supplier: Supplier input data + + Returns: + SupplierData: The created supplier + """ variables: dict[str, JSONValue] = {"input": supplier} data = await client.execute( query=gql_strings.ADD_SUPPLIER, @@ -32,5 +47,7 @@ async def add_supplier( persona=persona, operation_name="AddSupplier", ) - result = data.get("addSupplier") - return result if isinstance(result, dict) else {} + response = AddSupplierResponse.model_validate(data) + if not response.add_supplier: + raise ValueError("Failed to add supplier") + return response.add_supplier diff --git a/src/guide/app/strings/service.py b/src/guide/app/strings/service.py deleted file mode 100644 index 7d0f58b..0000000 --- a/src/guide/app/strings/service.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import TypeAlias, cast - -from guide.app.strings.graphql import ( - ADD_SUPPLIER, - CREATE_INTAKE_REQUEST, - GET_INTAKE_REQUEST, - LIST_SUPPLIERS, -) -from guide.app.strings.demo_texts import DemoTexts, EventTexts, IntakeTexts, SupplierTexts -from guide.app.strings.labels import AuthLabels, IntakeLabels, Labels, SourcingLabels -from guide.app.strings.selectors import AuthSelectors, IntakeSelectors, NavigationSelectors, Selectors, SourcingSelectors - -SelectorNamespace: TypeAlias = type[IntakeSelectors] | type[SourcingSelectors] | type[NavigationSelectors] | type[AuthSelectors] -LabelNamespace: TypeAlias = type[IntakeLabels] | type[SourcingLabels] | type[AuthLabels] -TextNamespace: TypeAlias = type[IntakeTexts] | type[SupplierTexts] | type[EventTexts] -TextValue: TypeAlias = str | list[str] - -_SELECTORS: dict[str, SelectorNamespace] = { - "INTAKE": Selectors.INTAKE, - "SOURCING": Selectors.SOURCING, - "NAVIGATION": Selectors.NAVIGATION, - "AUTH": Selectors.AUTH, -} - -_LABELS: dict[str, LabelNamespace] = { - "INTAKE": Labels.INTAKE, - "SOURCING": Labels.SOURCING, - "AUTH": Labels.AUTH, -} - -_TEXTS: dict[str, TextNamespace] = { - "INTAKE": DemoTexts.INTAKE, - "SUPPLIERS": DemoTexts.SUPPLIERS, - "EVENTS": DemoTexts.EVENTS, -} - -_GQL = { - "GET_INTAKE_REQUEST": GET_INTAKE_REQUEST, - "CREATE_INTAKE_REQUEST": CREATE_INTAKE_REQUEST, - "LIST_SUPPLIERS": LIST_SUPPLIERS, - "ADD_SUPPLIER": ADD_SUPPLIER, -} - - -class Strings: - """Unified accessor for selectors, labels, texts, and GraphQL queries.""" - - @staticmethod - def selector(domain: str, name: str) -> str: - ns = _SELECTORS.get(domain.upper()) - if ns is None: - raise KeyError(f"Unknown selector domain '{domain}'") - value = cast(str | None, getattr(ns, name, None)) - return _as_str(value, f"selector {domain}.{name}") - - @staticmethod - def label(domain: str, name: str) -> str: - ns = _LABELS.get(domain.upper()) - if ns is None: - raise KeyError(f"Unknown label domain '{domain}'") - value = cast(str | None, getattr(ns, name, None)) - return _as_str(value, f"label {domain}.{name}") - - @staticmethod - def text(domain: str, name: str) -> TextValue: - ns = _TEXTS.get(domain.upper()) - if ns is None: - raise KeyError(f"Unknown text domain '{domain}'") - value: object | None = getattr(ns, name, None) - if value is None: - raise KeyError(f"Unknown text {domain}.{name}") - if isinstance(value, str): - return value - if isinstance(value, list): - value_list: list[object] = cast(list[object], value) - str_values: list[str] = [] - for item in value_list: - if not isinstance(item, str): - raise TypeError(f"text {domain}.{name} must be a string or list of strings") - str_values.append(item) - return str_values - raise TypeError(f"text {domain}.{name} must be a string or list of strings") - - @staticmethod - def gql(name: str) -> str: - value = _GQL.get(name) - return _as_str(value, f"graphql string {name}") - - -def _as_str(value: object, label: str) -> str: - if not isinstance(value, str): # pragma: no cover - raise TypeError(f"{label} must be a string") - return value - - -strings = Strings() - -__all__ = ["Strings", "strings"]