# rag-manager/tests/openapi_mocks.py
from __future__ import annotations
import copy
import json as json_module
import time
from collections.abc import Mapping
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
class OpenAPISpec:
"""Utility for generating example payloads from an OpenAPI document."""
def __init__(self, raw: Mapping[str, Any]):
self._raw = raw
self._paths = raw.get("paths", {})
self._components = raw.get("components", {})
@classmethod
def from_file(cls, path: Path) -> OpenAPISpec:
with path.open("r", encoding="utf-8") as handle:
raw = json_module.load(handle)
return cls(raw)
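    # Illustrative usage (filename and ref name are placeholders):
    #     spec = OpenAPISpec.from_file(Path("tests/fixtures/openapi.json"))
    #     schema = spec.resolve_ref("#/components/schemas/SomeModel")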
def resolve_ref(self, ref: str) -> Mapping[str, Any]:
if not ref.startswith("#/"):
raise ValueError(f"Unsupported $ref format: {ref}")
        parts = ref.removeprefix("#/").split("/")
node: Any = self._raw
for part in parts:
if not isinstance(node, Mapping) or part not in node:
raise KeyError(f"Unable to resolve reference: {ref}")
node = node[part]
if not isinstance(node, Mapping):
raise TypeError(f"Reference {ref} did not resolve to an object")
return node
def generate(self, schema: Mapping[str, Any] | None, name: str = "value") -> Any:
if schema is None:
return None
# Handle references
if "$ref" in schema:
resolved = self.resolve_ref(schema["$ref"])
return self.generate(resolved, name=name)
        # Handle explicit values (example, default, enum); compare against None so
        # falsy values such as 0, "" or False are not silently dropped.
        if (explicit_value := self._get_explicit_value(schema)) is not None:
            return explicit_value
        # Handle schema combinations (anyOf, oneOf, allOf)
        if (combination_result := self._handle_schema_combinations(schema, name)) is not None:
            return combination_result
# Handle typed schemas
return self._generate_by_type(schema, name)
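    # A minimal sketch of what generate() yields for an inline schema:
    #     spec.generate({
    #         "type": "object",
    #         "properties": {
    #             "id": {"type": "string", "format": "uuid"},
    #             "count": {"type": "integer"},
    #         },
    #     })
    #     # -> {"id": "00000000-0000-4000-8000-000000000000", "count": 1}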
def _get_explicit_value(self, schema: Mapping[str, Any]) -> Any:
"""Get explicit values from schema (example, default, enum)."""
if "example" in schema:
return copy.deepcopy(schema["example"])
if "default" in schema and schema["default"] is not None:
return copy.deepcopy(schema["default"])
if "enum" in schema and schema["enum"]:
return copy.deepcopy(schema["enum"][0])
return None
def _handle_schema_combinations(self, schema: Mapping[str, Any], name: str) -> Any:
"""Handle anyOf, oneOf, allOf schema combinations."""
if "anyOf" in schema:
return self._handle_any_of(schema["anyOf"], name)
if "oneOf" in schema:
return self.generate(schema["oneOf"][0], name=name)
if "allOf" in schema:
return self._handle_all_of(schema["allOf"], name)
return None
def _handle_any_of(self, options: list[Any], name: str) -> Any:
"""Handle anyOf schema combinations."""
for option in options:
candidate = self.generate(option, name=name)
if candidate is not None:
return candidate
return None
def _handle_all_of(self, options: list[Any], name: str) -> dict[str, Any]:
"""Handle allOf schema combinations."""
result: dict[str, Any] = {}
for option in options:
fragment = self.generate(option, name=name)
if isinstance(fragment, Mapping):
result |= fragment
return result
def _generate_by_type(self, schema: Mapping[str, Any], name: str) -> Any:
"""Generate value based on schema type."""
type_name = schema.get("type")
if type_name == "object":
return self._generate_object(schema, name)
if type_name == "array":
return self._generate_array(schema, name)
if type_name == "string":
return self._generate_string(schema, name)
if type_name == "integer":
return self._generate_integer(schema)
if type_name == "number":
return 1.0
if type_name == "boolean":
return True
return None if type_name == "null" else {}
def _generate_object(self, schema: Mapping[str, Any], name: str) -> dict[str, Any]:
"""Generate object from schema."""
properties = schema.get("properties", {})
required = schema.get("required", [])
result: dict[str, Any] = {}
keys = list(properties.keys()) or list(required)
for key in keys:
prop_schema = properties.get(key, {})
result[key] = self.generate(prop_schema, name=key)
# Handle additional properties
additional = schema.get("additionalProperties")
if additional and isinstance(additional, Mapping) and not properties:
result["key"] = self.generate(additional, name="key")
return result
def _generate_array(self, schema: Mapping[str, Any], name: str) -> list[Any]:
"""Generate array from schema."""
item_schema = schema.get("items", {})
item = self.generate(item_schema, name=name)
return [] if item is None else [item]
def _generate_string(self, schema: Mapping[str, Any], name: str) -> str:
"""Generate string with format considerations."""
format_hint = schema.get("format")
if format_hint == "date-time":
return "2024-01-01T00:00:00+00:00"
if format_hint == "date":
return "2024-01-01"
if format_hint == "uuid":
return "00000000-0000-4000-8000-000000000000"
return "https://example.com" if format_hint == "uri" else f"{name}-value"
def _generate_integer(self, schema: Mapping[str, Any]) -> int:
"""Generate integer with minimum consideration."""
minimum = schema.get("minimum")
return int(minimum) if minimum is not None else 1
def _split_path(self, path: str) -> list[str]:
if path in {"", "/"}:
return []
return [segment for segment in path.strip("/").split("/") if segment]
def find_operation(
self, method: str, path: str
) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
method = method.lower()
normalized = "/" if path in {"", "/"} else "/" + path.strip("/")
actual_segments = self._split_path(normalized)
for template, operations in self._paths.items():
operation = operations.get(method)
if operation is None:
continue
template_path = template or "/"
template_segments = self._split_path(template_path)
if len(template_segments) != len(actual_segments):
continue
params: dict[str, str] = {}
matched = True
for template_part, actual_part in zip(template_segments, actual_segments, strict=False):
if template_part.startswith("{") and template_part.endswith("}"):
params[template_part[1:-1]] = actual_part
elif template_part != actual_part:
matched = False
break
if matched:
return operation, params
return None, {}
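    # Path matching sketch (the template is an assumption about the spec's paths):
    #     operation, params = spec.find_operation("get", "/api/v1/knowledge/kb-1")
    #     # a "/api/v1/knowledge/{id}" template would match -> params == {"id": "kb-1"}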
def build_response(
self,
operation: Mapping[str, Any],
status: str | None = None,
) -> tuple[int, Any]:
responses = operation.get("responses", {})
target_status = status or ("200" if "200" in responses else next(iter(responses), "200"))
status_code = 200
try:
status_code = int(target_status)
except ValueError:
pass
response_entry = responses.get(target_status, {})
content = response_entry.get("content", {})
media = next(
(
content[candidate]
for candidate in (
"application/json",
"application/problem+json",
"text/plain",
)
if candidate in content
),
None,
)
schema = media.get("schema") if media else None
payload = self.generate(schema, name="response")
return status_code, payload
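    # Sketch: build a canned response for an operation returned by find_operation.
    #     status_code, payload = spec.build_response(operation)
    #     # -> (200, <payload generated from the 200 "application/json" schema>)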
def generate_from_ref(
self,
ref: str,
overrides: Mapping[str, Any] | None = None,
) -> Any:
base = self.generate({"$ref": ref})
if overrides is None or not isinstance(base, Mapping):
return copy.deepcopy(base)
merged: dict[str, Any] = dict(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
merged[key] = self.generate_from_mapping(
merged[key],
value,
)
else:
merged[key] = copy.deepcopy(value)
return merged
def generate_from_mapping(
self,
base: Mapping[str, Any],
overrides: Mapping[str, Any],
) -> dict[str, Any]:
result: dict[str, Any] = dict(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(result.get(key), Mapping):
result[key] = self.generate_from_mapping(result[key], value)
else:
result[key] = copy.deepcopy(value)
return result
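    # Deep-merge sketch (inline values for illustration):
    #     spec.generate_from_mapping({"meta": {"a": 1}, "name": "x"}, {"meta": {"b": 2}})
    #     # -> {"meta": {"a": 1, "b": 2}, "name": "x"}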
class OpenAPIMockService:
"""Base class for stateful mock services backed by an OpenAPI spec."""
def __init__(self, base_url: str, spec: OpenAPISpec):
self.base_url = base_url.rstrip("/")
self.spec = spec
@staticmethod
def _normalize_path(path: str) -> str:
return "/" if not path or path == "/" else "/" + path.strip("/")
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
operation, _ = self.spec.find_operation(method, path)
if operation is None:
return 404, {"detail": f"Unhandled {method.upper()} {path}"}
return self.spec.build_response(operation)
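# Illustrative usage of the base service (base URL and path are placeholders):
#     service = OpenAPIMockService("http://mock.test", spec)
#     status, body = service.handle(
#         method="GET", path="/health", json=None, params=None, files=None
#     )
#     # -> a spec-derived payload, or (404, {"detail": "Unhandled GET /health"})
#     #    when the path matches no operation in the spec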
class OpenWebUIMockService(OpenAPIMockService):
"""Stateful mock capturing OpenWebUI knowledge and file operations."""
def __init__(self, base_url: str, spec: OpenAPISpec):
super().__init__(base_url, spec)
self._knowledge: dict[str, dict[str, Any]] = {}
self._files: dict[str, dict[str, Any]] = {}
self._knowledge_counter = 1
self._file_counter = 1
self._default_user = "user-1"
@staticmethod
def _timestamp() -> int:
return int(time.time())
def ensure_knowledge(
self,
*,
name: str,
description: str = "",
knowledge_id: str | None = None,
) -> dict[str, Any]:
identifier = knowledge_id or f"kb-{self._knowledge_counter}"
self._knowledge_counter += 1
entry = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
ts = self._timestamp()
entry.update(
{
"id": identifier,
"user_id": self._default_user,
"name": name,
"description": description or entry.get("description") or "",
"created_at": ts,
"updated_at": ts,
"data": entry.get("data") or {},
"meta": entry.get("meta") or {},
"access_control": entry.get("access_control") or {},
"files": [],
}
)
self._knowledge[identifier] = entry
return copy.deepcopy(entry)
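    # Seeding sketch (name and description are placeholders):
    #     kb = openwebui.ensure_knowledge(name="docs", description="Project docs")
    #     # kb["id"] == "kb-1" on a fresh service; pass knowledge_id=... to pin the id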
def create_file(
self,
*,
filename: str,
user_id: str | None = None,
file_id: str | None = None,
) -> dict[str, Any]:
identifier = file_id or f"file-{self._file_counter}"
self._file_counter += 1
entry = self.spec.generate_from_ref("#/components/schemas/FileModelResponse")
ts = self._timestamp()
entry.update(
{
"id": identifier,
"user_id": user_id or self._default_user,
"filename": filename,
"meta": entry.get("meta") or {},
"created_at": ts,
"updated_at": ts,
}
)
self._files[identifier] = entry
return copy.deepcopy(entry)
def _build_file_metadata(self, file_id: str) -> dict[str, Any]:
metadata = self.spec.generate_from_ref("#/components/schemas/FileMetadataResponse")
source = self._files.get(file_id)
ts = self._timestamp()
metadata.update(
{
"id": file_id,
"meta": (source or {}).get("meta", {}),
"created_at": ts,
"updated_at": ts,
}
)
return metadata
def get_knowledge(self, knowledge_id: str) -> dict[str, Any] | None:
entry = self._knowledge.get(knowledge_id)
return copy.deepcopy(entry) if entry is not None else None
def find_knowledge_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
return next(
(
(identifier, copy.deepcopy(entry))
for identifier, entry in self._knowledge.items()
if entry.get("name") == name
),
None,
)
def attach_existing_file(self, knowledge_id: str, file_id: str) -> None:
if knowledge_id not in self._knowledge:
raise KeyError(f"Knowledge {knowledge_id} not found")
knowledge = self._knowledge[knowledge_id]
metadata = self._build_file_metadata(file_id)
knowledge.setdefault("files", [])
knowledge["files"] = [item for item in knowledge["files"] if item.get("id") != file_id]
knowledge["files"].append(metadata)
knowledge["updated_at"] = self._timestamp()
def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
payload.update(
{
"id": knowledge["id"],
"user_id": knowledge["user_id"],
"name": knowledge["name"],
"description": knowledge.get("description", ""),
"data": copy.deepcopy(knowledge.get("data", {})),
"meta": copy.deepcopy(knowledge.get("meta", {})),
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
"created_at": knowledge.get("created_at", self._timestamp()),
"updated_at": knowledge.get("updated_at", self._timestamp()),
"files": copy.deepcopy(knowledge.get("files", [])),
}
)
return payload
def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeFilesResponse")
payload.update(
{
"id": knowledge["id"],
"user_id": knowledge["user_id"],
"name": knowledge["name"],
"description": knowledge.get("description", ""),
"data": copy.deepcopy(knowledge.get("data", {})),
"meta": copy.deepcopy(knowledge.get("meta", {})),
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
"created_at": knowledge.get("created_at", self._timestamp()),
"updated_at": knowledge.get("updated_at", self._timestamp()),
"files": copy.deepcopy(knowledge.get("files", [])),
}
)
return payload
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
normalized = self._normalize_path(path)
segments = [segment for segment in normalized.strip("/").split("/") if segment]
# Handle knowledge endpoints
if segments[:3] == ["api", "v1", "knowledge"]:
return self._handle_knowledge_endpoints(method, segments, json)
# Handle file endpoints
if segments[:3] == ["api", "v1", "files"]:
return self._handle_file_endpoints(method, segments, files)
# Delegate to parent
return super().handle(method=method, path=path, json=json, params=params, files=files)
def _handle_knowledge_endpoints(
self, method: str, segments: list[str], json: Mapping[str, Any] | None
) -> tuple[int, Any]:
"""Handle knowledge-related API endpoints."""
method_upper = method.upper()
segment_count = len(segments)
# List knowledge endpoints
if (
method_upper == "GET"
and segment_count in {3, 4}
and (segment_count == 3 or segments[3] == "list")
):
return self._list_knowledge()
# Create knowledge endpoint
if method_upper == "POST" and segment_count == 4 and segments[3] == "create":
return self._create_knowledge(json or {})
# Knowledge-specific endpoints
if segment_count >= 4:
return self._handle_specific_knowledge(method_upper, segments, json)
return 404, {"detail": "Knowledge endpoint not found"}
def _list_knowledge(self) -> tuple[int, Any]:
"""List all knowledge entries."""
body = [self._knowledge_user_response(entry) for entry in self._knowledge.values()]
return 200, body
def _create_knowledge(self, payload: Mapping[str, Any]) -> tuple[int, Any]:
"""Create a new knowledge entry."""
entry = self.ensure_knowledge(
name=str(payload.get("name", "knowledge")),
description=str(payload.get("description", "")),
)
return 200, entry
def _handle_specific_knowledge(
self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None
) -> tuple[int, Any]:
"""Handle endpoints for specific knowledge entries."""
knowledge_id = segments[3]
knowledge = self._knowledge.get(knowledge_id)
if knowledge is None:
return 404, {"detail": f"Knowledge {knowledge_id} not found"}
segment_count = len(segments)
# Get knowledge details
if method_upper == "GET" and segment_count == 4:
return 200, self._knowledge_files_response(knowledge)
# File operations
if segment_count == 6 and segments[4] == "file":
return self._handle_knowledge_file_operations(
method_upper, segments[5], knowledge, json or {}
)
# Delete knowledge
if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete":
self._knowledge.pop(knowledge_id, None)
return 200, True
return 404, {"detail": "Knowledge operation not found"}
def _handle_knowledge_file_operations(
self,
method_upper: str,
operation: str,
knowledge: dict[str, Any],
payload: Mapping[str, Any],
) -> tuple[int, Any]:
"""Handle file operations on knowledge entries."""
if method_upper != "POST":
return 405, {"detail": "Method not allowed"}
if operation == "add":
return self._add_file_to_knowledge(knowledge, payload)
if operation == "remove":
return self._remove_file_from_knowledge(knowledge, payload)
return 404, {"detail": "File operation not found"}
def _add_file_to_knowledge(
self, knowledge: dict[str, Any], payload: Mapping[str, Any]
) -> tuple[int, Any]:
"""Add a file to a knowledge entry."""
file_id = str(payload.get("file_id", ""))
if not file_id:
return 422, {"detail": "file_id is required"}
if file_id not in self._files:
self.create_file(filename=f"{file_id}.txt")
metadata = self._build_file_metadata(file_id)
knowledge.setdefault("files", [])
knowledge["files"] = [item for item in knowledge["files"] if item.get("id") != file_id]
knowledge["files"].append(metadata)
knowledge["updated_at"] = self._timestamp()
return 200, self._knowledge_files_response(knowledge)
def _remove_file_from_knowledge(
self, knowledge: dict[str, Any], payload: Mapping[str, Any]
) -> tuple[int, Any]:
"""Remove a file from a knowledge entry."""
file_id = str(payload.get("file_id", ""))
knowledge["files"] = [
item for item in knowledge.get("files", []) if item.get("id") != file_id
]
knowledge["updated_at"] = self._timestamp()
return 200, self._knowledge_files_response(knowledge)
def _handle_file_endpoints(
self, method: str, segments: list[str], files: Any
) -> tuple[int, Any]:
"""Handle file-related API endpoints."""
method_upper = method.upper()
# Upload file
if method_upper == "POST" and len(segments) == 3:
return self._upload_file(files)
# Delete file
if method_upper == "DELETE" and len(segments) == 4:
return self._delete_file(segments[3])
return 404, {"detail": "File endpoint not found"}
def _upload_file(self, files: Any) -> tuple[int, Any]:
"""Handle file upload."""
filename = "uploaded.txt"
if files and isinstance(files, Mapping):
file_entry = files.get("file")
if isinstance(file_entry, tuple) and len(file_entry) >= 1:
filename = str(file_entry[0]) or filename
entry = self.create_file(filename=filename)
return 200, entry
def _delete_file(self, file_id: str) -> tuple[int, Any]:
"""Delete a file and remove from all knowledge entries."""
self._files.pop(file_id, None)
for knowledge in self._knowledge.values():
knowledge["files"] = [
item for item in knowledge.get("files", []) if item.get("id") != file_id
]
return 200, {"deleted": True}
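# Request-level sketch (ids are placeholders; "kb-1" must already exist):
#     status, body = openwebui.handle(
#         method="POST",
#         path="/api/v1/knowledge/kb-1/file/add",
#         json={"file_id": "file-1"},
#         params=None,
#         files=None,
#     )
#     # -> (200, KnowledgeFilesResponse payload listing "file-1" under "files")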
class R2RMockService(OpenAPIMockService):
"""Stateful mock of core R2R collection endpoints."""
def __init__(self, base_url: str, spec: OpenAPISpec):
super().__init__(base_url, spec)
self._collections: dict[str, dict[str, Any]] = {}
self._documents: dict[str, dict[str, Any]] = {}
@staticmethod
def _iso_now() -> str:
return datetime.now(tz=UTC).isoformat()
def create_collection(
self,
*,
name: str,
description: str = "",
collection_id: str | None = None,
) -> dict[str, Any]:
identifier = collection_id or str(uuid4())
entry = self.spec.generate_from_ref("#/components/schemas/CollectionResponse")
timestamp = self._iso_now()
entry.update(
{
"id": identifier,
"owner_id": entry.get("owner_id") or str(uuid4()),
"name": name,
"description": description or entry.get("description") or "",
"graph_cluster_status": entry.get("graph_cluster_status") or "idle",
"graph_sync_status": entry.get("graph_sync_status") or "synced",
"created_at": timestamp,
"updated_at": timestamp,
"user_count": entry.get("user_count", 1) or 1,
"document_count": entry.get("document_count", 0) or 0,
"documents": entry.get("documents", []),
}
)
self._collections[identifier] = entry
return copy.deepcopy(entry)
def get_collection(self, collection_id: str) -> dict[str, Any] | None:
entry = self._collections.get(collection_id)
return copy.deepcopy(entry) if entry is not None else None
def find_collection_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
return next(
(
(identifier, copy.deepcopy(entry))
for identifier, entry in self._collections.items()
if entry.get("name") == name
),
None,
)
def _set_collection_document_ids(
self, collection_id: str, document_id: str, *, add: bool
) -> None:
collection = self._collections.get(collection_id)
if collection is None:
collection = self.create_collection(
name=f"Collection {collection_id}", collection_id=collection_id
)
documents = collection.setdefault("documents", [])
if add:
if document_id not in documents:
documents.append(document_id)
else:
documents[:] = [doc for doc in documents if doc != document_id]
collection["document_count"] = len(documents)
def create_document(
self,
*,
document_id: str,
content: str,
metadata: dict[str, Any],
collection_ids: list[str],
) -> dict[str, Any]:
entry = self.spec.generate_from_ref("#/components/schemas/DocumentResponse")
timestamp = self._iso_now()
entry.update(
{
"id": document_id,
"owner_id": entry.get("owner_id") or str(uuid4()),
"collection_ids": collection_ids,
"metadata": copy.deepcopy(metadata),
"document_type": entry.get("document_type") or "text",
"version": entry.get("version") or "1.0",
"size_in_bytes": metadata.get("char_count") or len(content.encode("utf-8")),
"created_at": entry.get("created_at") or timestamp,
"updated_at": timestamp,
"summary": entry.get("summary"),
"summary_embedding": entry.get("summary_embedding") or None,
"chunks": entry.get("chunks") or [],
"content": content,
}
)
entry["__content"] = content
self._documents[document_id] = entry
for collection_id in collection_ids:
self._set_collection_document_ids(collection_id, document_id, add=True)
return copy.deepcopy(entry)
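    # Seeding sketch (names, ids and content are placeholders):
    #     collection = r2r.create_collection(name="kb-docs")
    #     doc = r2r.create_document(
    #         document_id=str(uuid4()),
    #         content="hello world",
    #         metadata={"source": "unit-test"},
    #         collection_ids=[collection["id"]],
    #     )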
def get_document(self, document_id: str) -> dict[str, Any] | None:
entry = self._documents.get(document_id)
return None if entry is None else copy.deepcopy(entry)
def delete_document(self, document_id: str) -> bool:
entry = self._documents.pop(document_id, None)
if entry is None:
return False
for collection_id in entry.get("collection_ids", []):
self._set_collection_document_ids(collection_id, document_id, add=False)
return True
def append_document_metadata(
self, document_id: str, metadata_list: list[dict[str, Any]]
) -> dict[str, Any] | None:
entry = self._documents.get(document_id)
if entry is None:
return None
metadata = entry.setdefault("metadata", {})
for item in metadata_list:
for key, value in item.items():
metadata[key] = value
entry["updated_at"] = self._iso_now()
return copy.deepcopy(entry)
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
normalized = self._normalize_path(path)
method_upper = method.upper()
# Handle collection endpoints
if normalized == "/v3/collections":
return self._handle_collections_endpoint(method_upper, json)
# Handle document endpoints
segments = [segment for segment in normalized.strip("/").split("/") if segment]
if segments[:2] == ["v3", "documents"]:
return self._handle_documents_endpoint(method_upper, segments, json, files)
# Delegate to parent
return super().handle(method=method, path=path, json=json, params=params, files=files)
def _handle_collections_endpoint(
self, method_upper: str, json: Mapping[str, Any] | None
) -> tuple[int, Any]:
"""Handle collection-related endpoints."""
if method_upper == "GET":
return self._list_collections()
if method_upper == "POST":
return self._create_collection_endpoint(json or {})
return 405, {"detail": "Method not allowed"}
def _list_collections(self) -> tuple[int, Any]:
"""List all collections."""
payload = self.spec.generate_from_ref(
"#/components/schemas/PaginatedR2RResult_list_CollectionResponse__"
)
results = []
for _identifier, entry in self._collections.items():
clone = copy.deepcopy(entry)
clone["document_count"] = len(clone.get("documents", []))
results.append(clone)
payload.update(
{
"results": results,
"total_entries": len(results),
}
)
return 200, payload
def _create_collection_endpoint(self, body: Mapping[str, Any]) -> tuple[int, Any]:
"""Create a new collection."""
entry = self.create_collection(
name=str(body.get("name", "Collection")),
description=str(body.get("description", "")),
)
payload = self.spec.generate_from_ref("#/components/schemas/R2RResults_CollectionResponse_")
payload.update({"results": entry})
return 200, payload
def _handle_documents_endpoint(
self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any
) -> tuple[int, Any]:
"""Handle document-related endpoints."""
if method_upper == "POST" and len(segments) == 2:
return self._create_document_endpoint(files)
if len(segments) == 3:
doc_id = segments[2]
if method_upper == "GET":
return self._get_document_endpoint(doc_id)
if method_upper == "DELETE":
return self._delete_document_endpoint(doc_id)
if method_upper == "PATCH" and len(segments) == 4 and segments[3] == "metadata":
return self._update_document_metadata(segments[2], json)
return 404, {"detail": "Document endpoint not found"}
def _create_document_endpoint(self, files: Any) -> tuple[int, Any]:
"""Create a new document from files."""
doc_data = self._extract_document_data_from_files(files)
document = self.create_document(**doc_data)
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_IngestionResponse_"
)
ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse")
ingestion.update(
{
"message": ingestion.get("message") or "Ingestion task queued successfully.",
"document_id": document["id"],
"task_id": ingestion.get("task_id") or str(uuid4()),
}
)
response_payload["results"] = ingestion
return 202, response_payload
def _extract_document_data_from_files(self, files: Any) -> dict[str, Any]:
"""Extract document creation data from files."""
metadata_raw = {}
content = ""
doc_id = str(uuid4())
collection_ids: list[str] = []
if not isinstance(files, Mapping):
return {
"document_id": doc_id,
"content": content,
"metadata": metadata_raw,
"collection_ids": collection_ids,
}
# Extract metadata
if "metadata" in files:
metadata_raw = self._extract_metadata_from_files(files["metadata"])
# Extract content
if "raw_text" in files:
content = self._extract_content_from_files(files["raw_text"])
# Extract document ID
if "id" in files:
doc_id = self._extract_doc_id_from_files(files["id"]) or doc_id
# Extract collection IDs
if "collection_ids" in files:
collection_ids = self._extract_collection_ids_from_files(files["collection_ids"])
# Fallback collection resolution
if not collection_ids:
collection_ids = self._resolve_collection_ids_from_metadata(metadata_raw)
return {
"document_id": doc_id,
"content": content,
"metadata": metadata_raw,
"collection_ids": collection_ids,
}
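    # The "files" argument mirrors the multipart mapping that HTTP clients such as
    # requests/httpx send: each key maps to a (filename, payload) tuple, e.g.
    #     {"raw_text": (None, "hello"), "metadata": (None, '{"collection_name": "kb-docs"}')}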
def _extract_metadata_from_files(self, metadata_entry: Any) -> dict[str, Any]:
"""Extract metadata from files entry."""
if isinstance(metadata_entry, tuple) and len(metadata_entry) >= 2:
try:
return json_module.loads(metadata_entry[1] or "{}")
except json_module.JSONDecodeError:
return {}
return {}
def _extract_content_from_files(self, raw_text_entry: Any) -> str:
"""Extract content from files entry."""
if (
isinstance(raw_text_entry, tuple)
and len(raw_text_entry) >= 2
and raw_text_entry[1] is not None
):
return str(raw_text_entry[1])
return ""
def _extract_doc_id_from_files(self, id_entry: Any) -> str | None:
"""Extract document ID from files entry."""
if isinstance(id_entry, tuple) and len(id_entry) >= 2 and id_entry[1]:
return str(id_entry[1])
return None
def _extract_collection_ids_from_files(self, coll_entry: Any) -> list[str]:
"""Extract collection IDs from files entry."""
if isinstance(coll_entry, tuple) and len(coll_entry) >= 2 and coll_entry[1]:
try:
parsed = json_module.loads(coll_entry[1])
if isinstance(parsed, list):
return [str(item) for item in parsed]
except json_module.JSONDecodeError:
pass
return []
def _resolve_collection_ids_from_metadata(self, metadata_raw: dict[str, Any]) -> list[str]:
"""Resolve collection IDs from metadata or use default."""
name = metadata_raw.get("collection_name") or metadata_raw.get("collection")
if isinstance(name, str):
if located := self.find_collection_by_name(name):
return [located[0]]
return [next(iter(self._collections))] if self._collections else []
def _get_document_endpoint(self, doc_id: str) -> tuple[int, Any]:
"""Get a document by ID."""
document = self.get_document(doc_id)
if document is None:
return 404, {"detail": f"Document {doc_id} not found"}
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_DocumentResponse_"
)
response_payload["results"] = document
return 200, response_payload
def _delete_document_endpoint(self, doc_id: str) -> tuple[int, Any]:
"""Delete a document by ID."""
success = self.delete_document(doc_id)
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_GenericBooleanResponse_"
)
results = response_payload.get("results")
if isinstance(results, Mapping):
results = dict(results)
results.update({"success": success})
else:
results = {"success": success}
response_payload["results"] = results
return (200 if success else 404), response_payload
    def _update_document_metadata(
        self, doc_id: str, json: Any
    ) -> tuple[int, Any]:
        """Append metadata entries (a list of dicts per the R2R API) to a document."""
        metadata_list = [dict(item) for item in json] if isinstance(json, list) else []
document = self.append_document_metadata(doc_id, metadata_list)
if document is None:
return 404, {"detail": f"Document {doc_id} not found"}
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_DocumentResponse_"
)
response_payload["results"] = document
return 200, response_payload
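# Request-level sketch for the document endpoints (continuing the seeding sketch above):
#     status, body = r2r.handle(
#         method="GET", path="/v3/documents/" + doc["id"],
#         json=None, params=None, files=None,
#     )
#     # -> (200, R2RResults payload with the stored document under "results")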
class FirecrawlMockService(OpenAPIMockService):
"""Stateful mock for Firecrawl map and scrape endpoints."""
def __init__(self, base_url: str, spec: OpenAPISpec) -> None:
super().__init__(base_url, spec)
self._maps: dict[str, list[str]] = {}
self._pages: dict[str, dict[str, Any]] = {}
def register_map_result(self, origin: str, links: list[str]) -> None:
self._maps[origin] = list(links)
def register_page(
self,
url: str,
*,
markdown: str | None = None,
html: str | None = None,
metadata: Mapping[str, Any] | None = None,
links: list[str] | None = None,
) -> None:
self._pages[url] = {
"markdown": markdown,
"html": html,
"metadata": dict(metadata or {}),
"links": list(links or []),
}
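    # Seeding sketch (URLs and content are placeholders):
    #     firecrawl.register_map_result("https://example.com", ["https://example.com/a"])
    #     firecrawl.register_page("https://example.com/a", markdown="# Page A")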
def _build_map_payload(self, target_url: str, limit: int | None) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/MapResponse")
links = self._maps.get(target_url, [target_url])
if limit is not None:
try:
limit_value = int(limit)
if limit_value >= 0:
links = links[:limit_value]
except (TypeError, ValueError):
pass
payload["success"] = True
payload["links"] = [{"url": link} for link in links]
return payload
def _default_metadata(self, url: str) -> dict[str, Any]:
metadata = self.spec.generate_from_ref("#/components/schemas/ScrapeMetadata")
metadata.update(
{
"url": url,
"sourceURL": url,
"scrapeId": metadata.get("scrapeId") or str(uuid4()),
"statusCode": metadata.get("statusCode", 200) or 200,
"contentType": metadata.get("contentType", "text/html") or "text/html",
"creditsUsed": metadata.get("creditsUsed", 1) or 1,
}
)
return metadata
def _build_scrape_payload(self, target_url: str, formats: list[str] | None) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/ScrapeResponse")
payload["success"] = True
page_info = self._pages.get(target_url, {})
data = payload.get("data", {})
markdown = page_info.get("markdown") or f"# Content for {target_url}\n"
html = page_info.get("html") or f"<h1>Content for {target_url}</h1>"
data.update(
{
"markdown": markdown,
"html": html,
"rawHtml": page_info.get("rawHtml", html),
"links": page_info.get("links", []),
}
)
metadata_payload = self._default_metadata(target_url)
metadata_payload.update(page_info.get("metadata", {}))
data["metadata"] = metadata_payload
payload["data"] = data
return payload
def map_response(self, url: str, limit: int | None = None) -> dict[str, Any]:
return self._build_map_payload(url, limit)
def scrape_response(self, url: str, formats: list[str] | None = None) -> dict[str, Any]:
return self._build_scrape_payload(url, formats)
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
normalized = self._normalize_path(path)
method_upper = method.upper()
if normalized == "/v2/map" and method_upper == "POST":
body = json or {}
target_url = str(body.get("url", ""))
limit = body.get("limit")
return 200, self._build_map_payload(target_url, limit)
if normalized == "/v2/scrape" and method_upper == "POST":
body = json or {}
target_url = str(body.get("url", ""))
formats = body.get("formats")
return 200, self._build_scrape_payload(target_url, formats)
return super().handle(method=method, path=path, json=json, params=params, files=files)
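# Request-level sketch (given the register_map_result call sketched above):
#     status, body = firecrawl.handle(
#         method="POST", path="/v2/map",
#         json={"url": "https://example.com", "limit": 1},
#         params=None, files=None,
#     )
#     # -> (200, MapResponse payload with body["links"] == [{"url": "https://example.com/a"}])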