Add bearer token handling to GraphQL client and enhance session management

- Introduced ExtractedToken dataclass for structured token extraction.
- Implemented discover_auth_tokens function to scan localStorage for auth-related keys.
- Enhanced extract_bearer_token function to prioritize token extraction based on common patterns.
- Updated GraphQLClient to accept bearer_token in headers for API requests.
- Improved JSON parsing and error handling in validate_persona_from_storage function.
This commit is contained in:
2025-12-03 05:19:19 +00:00
parent f5e50f88f9
commit d0ca9c3aa7
3 changed files with 260 additions and 15 deletions

View File

@@ -0,0 +1,55 @@
"""Diagnostic action to discover auth tokens in localStorage."""
from typing import ClassVar, override
from guide.app.actions.base import DemoAction, register_action
from guide.app.auth.session import discover_auth_tokens, extract_bearer_token
from guide.app.browser.types import PageLike
from guide.app.models.domain import ActionContext, ActionResult
@register_action
class DiscoverTokensAction(DemoAction):
"""Discover auth-related localStorage keys and extract bearer token."""
id: ClassVar[str] = "discover-tokens"
description: ClassVar[str] = "Discover auth-related localStorage keys."
category: ClassVar[str] = "diagnose"
@override
async def run(self, page: PageLike, context: ActionContext) -> ActionResult:
"""Scan localStorage for auth tokens and attempt bearer extraction."""
try:
tokens = await discover_auth_tokens(page)
bearer = await extract_bearer_token(page)
# Mask token values for security (show only first/last 4 chars)
masked_tokens: dict[str, str] = {}
for key, value in tokens.items():
if len(value) > 12:
masked_tokens[key] = f"{value[:4]}...{value[-4:]}"
else:
masked_tokens[key] = "****"
return ActionResult(
details={
"discovered_keys": list(tokens.keys()),
"token_count": len(tokens),
"masked_values": masked_tokens,
"bearer_token_found": bearer is not None,
"bearer_source_key": bearer.source_key if bearer else None,
"bearer_preview": (
f"{bearer.value[:8]}...{bearer.value[-4:]}"
if bearer and len(bearer.value) > 16
else None
),
}
)
except Exception as e:
return ActionResult(
details={
"error": str(e),
"token_count": 0,
"bearer_token_found": False,
}
)

View File

@@ -1,4 +1,6 @@
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import cast
from guide.app.auth.mfa import MfaCodeProvider
@@ -8,6 +10,15 @@ from guide.app.models.personas.models import DemoPersona
from guide.app.strings.registry import app_strings
@dataclass(frozen=True)
class ExtractedToken:
"""Bearer token extracted from browser session."""
value: str
source_key: str
extracted_at: datetime
async def detect_current_persona(page: PageLike) -> str | None:
"""Return the email/identifier of the currently signed-in user, if visible."""
element = page.locator(app_strings.auth.current_user_display)
@@ -23,7 +34,10 @@ async def detect_current_persona(page: PageLike) -> str | None:
async def login_with_mfa(
page: PageLike, email: str, mfa_provider: MfaCodeProvider, login_url: str | None = None
page: PageLike,
email: str,
mfa_provider: MfaCodeProvider,
login_url: str | None = None,
) -> None:
"""Log in with MFA. Only proceeds if email input exists after navigation."""
email_input = page.locator(app_strings.auth.email_input)
@@ -114,26 +128,33 @@ async def validate_persona_from_storage(
details={"storage_key": storage_key, "expected_email": expected_email},
)
parsed: str | dict[str, object]
try:
parsed = json.loads(raw)
# json.loads returns Any; cast to union of possible JSON types
loaded = cast(
dict[str, object] | list[object] | str | int | float | bool | None,
json.loads(raw),
)
if isinstance(loaded, dict):
parsed = loaded
elif isinstance(loaded, str):
parsed = loaded
else:
# For list, int, float, bool, None - treat as raw string
parsed = raw
except json.JSONDecodeError:
parsed = raw
# Allow either JSON object with an email field or a plain string payload.
stored_email: str | None
if isinstance(parsed, dict):
parsed_dict = cast(dict[str, object], parsed)
email_value = parsed_dict.get(email_field)
email_value = parsed.get(email_field)
stored_email = email_value if isinstance(email_value, str) else None
else:
stored_email = parsed if isinstance(parsed, str) else None
stored_email = parsed
if not isinstance(stored_email, str):
if isinstance(parsed, dict):
payload_type_name: str = "dict"
elif isinstance(parsed, str):
payload_type_name = "str"
else:
payload_type_name = parsed.__class__.__name__
if stored_email is None:
payload_type_name = "dict" if isinstance(parsed, dict) else "str"
raise PersonaError(
"localStorage user record does not contain an email",
details={
@@ -154,3 +175,164 @@ async def validate_persona_from_storage(
"actual_email": stored_email,
},
)
_JS_DISCOVER_AUTH_TOKENS = """
(() => {
const patterns = ['token', 'jwt', 'bearer', 'auth', 'access', 'session', 'credential'];
const results = {};
for (let i = 0; i < localStorage.length; i++) {
const key = localStorage.key(i);
if (!key) continue;
const lowerKey = key.toLowerCase();
if (patterns.some(p => lowerKey.includes(p))) {
try {
const value = localStorage.getItem(key);
if (value) results[key] = value;
} catch (e) {}
}
}
return results;
})();
"""
async def discover_auth_tokens(page: PageLike) -> dict[str, str]:
"""Scan localStorage for auth-related keys and return all matches.
Search for keys containing patterns like 'token', 'jwt', 'bearer', 'auth',
'access', 'session', or 'credential'.
Args:
page: Page-like object (Playwright Page or ExtensionPage).
Returns:
Dictionary mapping localStorage key names to their values.
"""
result = await page.evaluate(_JS_DISCOVER_AUTH_TOKENS)
if not isinstance(result, dict):
return {}
result_dict = cast(dict[str, object], result)
return {key: str(value) for key, value in result_dict.items() if value is not None}
def _is_jwt_format(value: str) -> bool:
"""Check if value looks like a JWT (three base64 segments separated by dots)."""
parts = value.split(".")
if len(parts) != 3:
return False
# Each part should be non-empty and contain valid base64-ish characters
for part in parts:
if not part or not all(c.isalnum() or c in "-_=" for c in part):
return False
return True
def _extract_token_from_json(value: str) -> str | None:
"""Try to extract a token from a JSON object value."""
try:
# json.loads returns Any; cast to union of possible JSON types
loaded = cast(
dict[str, object] | list[object] | str | int | float | bool | None,
json.loads(value),
)
except json.JSONDecodeError:
return None
if not isinstance(loaded, dict):
return None
parsed_dict = loaded
# Priority order for token fields within JSON
token_fields = [
"access_token",
"accessToken",
"token",
"bearer",
"id_token",
"idToken",
"jwt",
]
for field in token_fields:
token_value = parsed_dict.get(field)
if isinstance(token_value, str) and token_value:
return token_value
return None
async def extract_bearer_token(page: PageLike) -> ExtractedToken | None:
"""Extract bearer token from localStorage, trying common key patterns.
Priority order for token selection:
1. Keys containing 'access_token' or 'accessToken'
2. Keys containing 'bearer'
3. Keys containing 'token' (excluding 'refresh')
4. JWT-formatted values in any auth-related key
For each key, if the value is JSON, attempts to extract token fields.
Otherwise uses the raw value if it looks like a token.
Args:
page: Page-like object (Playwright Page or ExtensionPage).
Returns:
ExtractedToken with the token value and source key, or None if not found.
"""
tokens = await discover_auth_tokens(page)
if not tokens:
return None
# Priority-based key matching
priority_patterns: list[tuple[str, bool]] = [
("access_token", False), # (pattern, exclude_refresh)
("accesstoken", False),
("bearer", False),
("id_token", False),
("idtoken", False),
("token", True), # exclude refresh tokens
]
for pattern, exclude_refresh in priority_patterns:
for key, value in tokens.items():
lower_key = key.lower()
if pattern not in lower_key:
continue
if exclude_refresh and "refresh" in lower_key:
continue
# Try to extract from JSON first
extracted = _extract_token_from_json(value)
if extracted:
return ExtractedToken(
value=extracted,
source_key=key,
extracted_at=datetime.now(timezone.utc),
)
# Use raw value if it looks like a token (JWT or long alphanumeric)
if _is_jwt_format(value) or (len(value) > 20 and value.isalnum()):
return ExtractedToken(
value=value,
source_key=key,
extracted_at=datetime.now(timezone.utc),
)
# Fallback: find any JWT-formatted value
for key, value in tokens.items():
extracted = _extract_token_from_json(value)
if extracted and _is_jwt_format(extracted):
return ExtractedToken(
value=extracted,
source_key=key,
extracted_at=datetime.now(timezone.utc),
)
if _is_jwt_format(value):
return ExtractedToken(
value=value,
source_key=key,
extracted_at=datetime.now(timezone.utc),
)
return None

View File

@@ -19,9 +19,10 @@ class GraphQLClient:
variables: Mapping[str, object] | None,
persona: DemoPersona | None,
operation_name: str | None = None,
bearer_token: str | None = None,
) -> dict[str, object]:
url = self._settings.raindrop_graphql_url
headers = self._build_headers(persona)
headers = self._build_headers(persona, bearer_token=bearer_token)
async with httpx.AsyncClient(timeout=10.0) as client:
try:
resp = await client.post(
@@ -53,8 +54,15 @@ class GraphQLClient:
payload = data.get("data", {})
return cast(dict[str, object], payload) if isinstance(payload, dict) else {}
def _build_headers(self, persona: DemoPersona | None) -> dict[str, str]:
def _build_headers(
self,
persona: DemoPersona | None,
*,
bearer_token: str | None = None,
) -> dict[str, str]:
headers: dict[str, str] = {"Content-Type": "application/json"}
# TODO: attach persona/service auth tokens when available in config
if bearer_token:
headers["Authorization"] = f"Bearer {bearer_token}"
# Reserved for future persona-specific auth
_ = persona
return headers