This commit is contained in:
2025-11-22 11:48:00 +00:00
parent 64d19d07f4
commit 285fa338f9
2 changed files with 42 additions and 19 deletions

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from inspect import Parameter, signature from inspect import Parameter, signature
from typing import ClassVar from typing import ClassVar, override, cast
from playwright.async_api import Page from playwright.async_api import Page
@@ -84,6 +84,7 @@ class CompositeAction(DemoAction):
self.registry: ActionRegistry = registry self.registry: ActionRegistry = registry
self.context: ActionContext | None = None self.context: ActionContext | None = None
@override
async def run(self, page: Page, context: ActionContext) -> ActionResult: async def run(self, page: Page, context: ActionContext) -> ActionResult:
"""Execute all child actions in sequence. """Execute all child actions in sequence.
@@ -216,7 +217,7 @@ class ActionRegistry:
Parameter.VAR_KEYWORD, Parameter.VAR_KEYWORD,
) )
# Check if parameter has a default value # Check if parameter has a default value
has_default = param.default is not Parameter.empty has_default = cast(object, param.default) != Parameter.empty
if not is_var_param and not has_default: if not is_var_param and not has_default:
msg = ( msg = (
f"Action '{action_cls.id}' requires dependency '{param_name}' " f"Action '{action_cls.id}' requires dependency '{param_name}' "

View File

@@ -4,7 +4,7 @@ import os
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from collections.abc import Mapping from collections.abc import Mapping
from typing import ClassVar, TypeAlias from typing import ClassVar, TypeAlias, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -82,7 +82,7 @@ def _load_yaml_file(path: Path) -> dict[str, object]:
) from exc ) from exc
# Load content and validate structure # Load content and validate structure
loaded = yaml.safe_load(path.read_text()) loaded = cast(object | None, yaml.safe_load(path.read_text()))
if loaded is None: if loaded is None:
return {} return {}
@@ -90,7 +90,9 @@ def _load_yaml_file(path: Path) -> dict[str, object]:
if not isinstance(loaded, dict): if not isinstance(loaded, dict):
return {} return {}
result: dict[str, object] = {str(key): value for key, value in loaded.items()} # Explicitly cast to dict after isinstance check for type narrowing
loaded_dict = cast(dict[object, object], loaded)
result: dict[str, object] = {str(key): value for key, value in loaded_dict.items()}
return result return result
@@ -100,8 +102,10 @@ def _normalize_host_records(data: object) -> RecordList:
Processes data into a list of typed records by handling dict and list formats. Processes data into a list of typed records by handling dict and list formats.
""" """
# Extract content from wrapper mapping if present # Extract content from wrapper mapping if present
content: object
if isinstance(data, Mapping): if isinstance(data, Mapping):
mapping = _coerce_mapping(data) data_mapping = cast(Mapping[object, object], data)
mapping = _coerce_mapping(data_mapping)
content = mapping.get("hosts", mapping) content = mapping.get("hosts", mapping)
else: else:
content = data content = data
@@ -110,19 +114,26 @@ def _normalize_host_records(data: object) -> RecordList:
# Process mapping format (dict of hosts keyed by ID) # Process mapping format (dict of hosts keyed by ID)
if isinstance(content, Mapping): if isinstance(content, Mapping):
mapping_content = _coerce_mapping(content) content_mapping = cast(Mapping[object, object], content)
mapping_content = _coerce_mapping(content_mapping)
for key, value in mapping_content.items(): for key, value in mapping_content.items():
# Check if value is a mapping before using it as one # Check if value is a mapping before using it as one
record = _coerce_mapping(value) if isinstance(value, Mapping) else {} if isinstance(value, Mapping):
value_mapping = cast(Mapping[object, object], value)
record: JsonRecord = _coerce_mapping(value_mapping)
else:
record = {}
# Ensure record has an id field # Ensure record has an id field
if "id" not in record: if "id" not in record:
record["id"] = key record["id"] = key
records.append(record) records.append(record)
elif isinstance(content, list): elif isinstance(content, list):
for item in content: content_list = cast(list[object], content)
for item in content_list:
if isinstance(item, Mapping): if isinstance(item, Mapping):
records.append(_coerce_mapping(item)) item_mapping = cast(Mapping[object, object], item)
records.append(_coerce_mapping(item_mapping))
return records return records
@@ -133,8 +144,10 @@ def _normalize_persona_records(data: object) -> RecordList:
Processes data into a list of typed records by handling dict and list formats. Processes data into a list of typed records by handling dict and list formats.
""" """
# Extract content from wrapper mapping if present # Extract content from wrapper mapping if present
content: object
if isinstance(data, Mapping): if isinstance(data, Mapping):
mapping = _coerce_mapping(data) data_mapping = cast(Mapping[object, object], data)
mapping = _coerce_mapping(data_mapping)
content = mapping.get("personas", mapping) content = mapping.get("personas", mapping)
else: else:
content = data content = data
@@ -143,19 +156,26 @@ def _normalize_persona_records(data: object) -> RecordList:
# Process mapping format (dict of personas keyed by ID) # Process mapping format (dict of personas keyed by ID)
if isinstance(content, Mapping): if isinstance(content, Mapping):
mapping_content = _coerce_mapping(content) content_mapping = cast(Mapping[object, object], content)
mapping_content = _coerce_mapping(content_mapping)
for key, value in mapping_content.items(): for key, value in mapping_content.items():
# Check if value is a mapping before using it as one # Check if value is a mapping before using it as one
record = _coerce_mapping(value) if isinstance(value, Mapping) else {} if isinstance(value, Mapping):
value_mapping = cast(Mapping[object, object], value)
record: JsonRecord = _coerce_mapping(value_mapping)
else:
record = {}
# Ensure record has an id field # Ensure record has an id field
if "id" not in record: if "id" not in record:
record["id"] = key record["id"] = key
records.append(record) records.append(record)
elif isinstance(content, list): elif isinstance(content, list):
for item in content: content_list = cast(list[object], content)
for item in content_list:
if isinstance(item, Mapping): if isinstance(item, Mapping):
records.append(_coerce_mapping(item)) item_mapping = cast(Mapping[object, object], item)
records.append(_coerce_mapping(item_mapping))
return records return records
@@ -189,13 +209,14 @@ def load_settings() -> AppSettings:
if browser_hosts_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"): if browser_hosts_json := os.environ.get("RAINDROP_DEMO_BROWSER_HOSTS_JSON"):
try: try:
# Validate JSON is a list and process each record # Validate JSON is a list and process each record
decoded = json.loads(browser_hosts_json) decoded = cast(object, json.loads(browser_hosts_json))
if not isinstance(decoded, list): if not isinstance(decoded, list):
raise ValueError( raise ValueError(
"RAINDROP_DEMO_BROWSER_HOSTS_JSON must be a JSON array" "RAINDROP_DEMO_BROWSER_HOSTS_JSON must be a JSON array"
) )
# Iterate only over validated list # Iterate only over validated list
for item in decoded: decoded_list = cast(list[object], decoded)
for item in decoded_list:
if isinstance(item, Mapping): if isinstance(item, Mapping):
host = BrowserHostConfig.model_validate(item) host = BrowserHostConfig.model_validate(item)
hosts_dict[host.id] = host hosts_dict[host.id] = host
@@ -205,11 +226,12 @@ def load_settings() -> AppSettings:
if personas_json := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"): if personas_json := os.environ.get("RAINDROP_DEMO_PERSONAS_JSON"):
try: try:
# Validate JSON is a list and process each record # Validate JSON is a list and process each record
decoded = json.loads(personas_json) decoded = cast(object, json.loads(personas_json))
if not isinstance(decoded, list): if not isinstance(decoded, list):
raise ValueError("RAINDROP_DEMO_PERSONAS_JSON must be a JSON array") raise ValueError("RAINDROP_DEMO_PERSONAS_JSON must be a JSON array")
# Iterate only over validated list # Iterate only over validated list
for item in decoded: decoded_list = cast(list[object], decoded)
for item in decoded_list:
if isinstance(item, Mapping): if isinstance(item, Mapping):
persona = DemoPersona.model_validate(item) persona = DemoPersona.model_validate(item)
personas_dict[persona.id] = persona personas_dict[persona.id] = persona