1062 lines
40 KiB
Python
1062 lines
40 KiB
Python
from __future__ import annotations
|
|
|
|
import copy
|
|
import json as json_module
|
|
import time
|
|
from collections.abc import Mapping
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
|
|
class OpenAPISpec:
|
|
"""Utility for generating example payloads from an OpenAPI document."""
|
|
|
|
def __init__(self, raw: Mapping[str, Any]):
|
|
self._raw = raw
|
|
self._paths = raw.get("paths", {})
|
|
self._components = raw.get("components", {})
|
|
|
|
@classmethod
|
|
def from_file(cls, path: Path) -> OpenAPISpec:
|
|
with path.open("r", encoding="utf-8") as handle:
|
|
raw = json_module.load(handle)
|
|
return cls(raw)
|
|
|
|
def resolve_ref(self, ref: str) -> Mapping[str, Any]:
|
|
if not ref.startswith("#/"):
|
|
raise ValueError(f"Unsupported $ref format: {ref}")
|
|
parts = ref.lstrip("#/").split("/")
|
|
node: Any = self._raw
|
|
for part in parts:
|
|
if not isinstance(node, Mapping) or part not in node:
|
|
raise KeyError(f"Unable to resolve reference: {ref}")
|
|
node = node[part]
|
|
if not isinstance(node, Mapping):
|
|
raise TypeError(f"Reference {ref} did not resolve to an object")
|
|
return node
|
|
|
|
def generate(self, schema: Mapping[str, Any] | None, name: str = "value") -> Any:
|
|
if schema is None:
|
|
return None
|
|
|
|
# Handle references
|
|
if "$ref" in schema:
|
|
resolved = self.resolve_ref(schema["$ref"])
|
|
return self.generate(resolved, name=name)
|
|
|
|
# Handle explicit values
|
|
if explicit_value := self._get_explicit_value(schema):
|
|
return explicit_value
|
|
|
|
# Handle schema combinations
|
|
if combination_result := self._handle_schema_combinations(schema, name):
|
|
return combination_result
|
|
|
|
# Handle typed schemas
|
|
return self._generate_by_type(schema, name)
|
|
|
|
def _get_explicit_value(self, schema: Mapping[str, Any]) -> Any:
|
|
"""Get explicit values from schema (example, default, enum)."""
|
|
if "example" in schema:
|
|
return copy.deepcopy(schema["example"])
|
|
if "default" in schema and schema["default"] is not None:
|
|
return copy.deepcopy(schema["default"])
|
|
if "enum" in schema and schema["enum"]:
|
|
return copy.deepcopy(schema["enum"][0])
|
|
return None
|
|
|
|
def _handle_schema_combinations(self, schema: Mapping[str, Any], name: str) -> Any:
|
|
"""Handle anyOf, oneOf, allOf schema combinations."""
|
|
if "anyOf" in schema:
|
|
return self._handle_any_of(schema["anyOf"], name)
|
|
if "oneOf" in schema:
|
|
return self.generate(schema["oneOf"][0], name=name)
|
|
if "allOf" in schema:
|
|
return self._handle_all_of(schema["allOf"], name)
|
|
return None
|
|
|
|
def _handle_any_of(self, options: list[Any], name: str) -> Any:
|
|
"""Handle anyOf schema combinations."""
|
|
for option in options:
|
|
candidate = self.generate(option, name=name)
|
|
if candidate is not None:
|
|
return candidate
|
|
return None
|
|
|
|
def _handle_all_of(self, options: list[Any], name: str) -> dict[str, Any]:
|
|
"""Handle allOf schema combinations."""
|
|
result: dict[str, Any] = {}
|
|
for option in options:
|
|
fragment = self.generate(option, name=name)
|
|
if isinstance(fragment, Mapping):
|
|
result |= fragment
|
|
return result
|
|
|
|
def _generate_by_type(self, schema: Mapping[str, Any], name: str) -> Any:
|
|
"""Generate value based on schema type."""
|
|
type_name = schema.get("type")
|
|
|
|
if type_name == "object":
|
|
return self._generate_object(schema, name)
|
|
if type_name == "array":
|
|
return self._generate_array(schema, name)
|
|
if type_name == "string":
|
|
return self._generate_string(schema, name)
|
|
if type_name == "integer":
|
|
return self._generate_integer(schema)
|
|
if type_name == "number":
|
|
return 1.0
|
|
if type_name == "boolean":
|
|
return True
|
|
return None if type_name == "null" else {}
|
|
|
|
def _generate_object(self, schema: Mapping[str, Any], name: str) -> dict[str, Any]:
|
|
"""Generate object from schema."""
|
|
properties = schema.get("properties", {})
|
|
required = schema.get("required", [])
|
|
result: dict[str, Any] = {}
|
|
|
|
keys = list(properties.keys()) or list(required)
|
|
for key in keys:
|
|
prop_schema = properties.get(key, {})
|
|
result[key] = self.generate(prop_schema, name=key)
|
|
|
|
# Handle additional properties
|
|
additional = schema.get("additionalProperties")
|
|
if additional and isinstance(additional, Mapping) and not properties:
|
|
result["key"] = self.generate(additional, name="key")
|
|
|
|
return result
|
|
|
|
def _generate_array(self, schema: Mapping[str, Any], name: str) -> list[Any]:
|
|
"""Generate array from schema."""
|
|
item_schema = schema.get("items", {})
|
|
item = self.generate(item_schema, name=name)
|
|
return [] if item is None else [item]
|
|
|
|
def _generate_string(self, schema: Mapping[str, Any], name: str) -> str:
|
|
"""Generate string with format considerations."""
|
|
format_hint = schema.get("format")
|
|
|
|
if format_hint == "date-time":
|
|
return "2024-01-01T00:00:00+00:00"
|
|
if format_hint == "date":
|
|
return "2024-01-01"
|
|
if format_hint == "uuid":
|
|
return "00000000-0000-4000-8000-000000000000"
|
|
return "https://example.com" if format_hint == "uri" else f"{name}-value"
|
|
|
|
def _generate_integer(self, schema: Mapping[str, Any]) -> int:
|
|
"""Generate integer with minimum consideration."""
|
|
minimum = schema.get("minimum")
|
|
return int(minimum) if minimum is not None else 1
|
|
|
|
def _split_path(self, path: str) -> list[str]:
|
|
if path in {"", "/"}:
|
|
return []
|
|
return [segment for segment in path.strip("/").split("/") if segment]
|
|
|
|
def find_operation(
|
|
self, method: str, path: str
|
|
) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
|
|
method = method.lower()
|
|
normalized = "/" if path in {"", "/"} else "/" + path.strip("/")
|
|
actual_segments = self._split_path(normalized)
|
|
for template, operations in self._paths.items():
|
|
operation = operations.get(method)
|
|
if operation is None:
|
|
continue
|
|
template_path = template or "/"
|
|
template_segments = self._split_path(template_path)
|
|
if len(template_segments) != len(actual_segments):
|
|
continue
|
|
params: dict[str, str] = {}
|
|
matched = True
|
|
for template_part, actual_part in zip(template_segments, actual_segments, strict=False):
|
|
if template_part.startswith("{") and template_part.endswith("}"):
|
|
params[template_part[1:-1]] = actual_part
|
|
elif template_part != actual_part:
|
|
matched = False
|
|
break
|
|
if matched:
|
|
return operation, params
|
|
return None, {}
|
|
|
|
def build_response(
|
|
self,
|
|
operation: Mapping[str, Any],
|
|
status: str | None = None,
|
|
) -> tuple[int, Any]:
|
|
responses = operation.get("responses", {})
|
|
target_status = status or ("200" if "200" in responses else next(iter(responses), "200"))
|
|
status_code = 200
|
|
try:
|
|
status_code = int(target_status)
|
|
except ValueError:
|
|
pass
|
|
response_entry = responses.get(target_status, {})
|
|
content = response_entry.get("content", {})
|
|
media = next(
|
|
(
|
|
content[candidate]
|
|
for candidate in (
|
|
"application/json",
|
|
"application/problem+json",
|
|
"text/plain",
|
|
)
|
|
if candidate in content
|
|
),
|
|
None,
|
|
)
|
|
schema = media.get("schema") if media else None
|
|
payload = self.generate(schema, name="response")
|
|
return status_code, payload
|
|
|
|
def generate_from_ref(
|
|
self,
|
|
ref: str,
|
|
overrides: Mapping[str, Any] | None = None,
|
|
) -> Any:
|
|
base = self.generate({"$ref": ref})
|
|
if overrides is None or not isinstance(base, Mapping):
|
|
return copy.deepcopy(base)
|
|
merged: dict[str, Any] = dict(base)
|
|
for key, value in overrides.items():
|
|
if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
|
|
merged[key] = self.generate_from_mapping(
|
|
merged[key],
|
|
value,
|
|
)
|
|
else:
|
|
merged[key] = copy.deepcopy(value)
|
|
return merged
|
|
|
|
def generate_from_mapping(
|
|
self,
|
|
base: Mapping[str, Any],
|
|
overrides: Mapping[str, Any],
|
|
) -> dict[str, Any]:
|
|
result: dict[str, Any] = dict(base)
|
|
for key, value in overrides.items():
|
|
if isinstance(value, Mapping) and isinstance(result.get(key), Mapping):
|
|
result[key] = self.generate_from_mapping(result[key], value)
|
|
else:
|
|
result[key] = copy.deepcopy(value)
|
|
return result
|
|
|
|
|
|
class OpenAPIMockService:
    """Base class for stateful mock services backed by an OpenAPI spec."""

    def __init__(self, base_url: str, spec: OpenAPISpec):
        # Keep the base URL free of a trailing slash so URL joins stay uniform.
        self.base_url = base_url.rstrip("/")
        self.spec = spec

    @staticmethod
    def _normalize_path(path: str) -> str:
        """Collapse *path* to a canonical leading-slash form ("/" when empty)."""
        if not path or path == "/":
            return "/"
        return "/" + path.strip("/")

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, Any] | None,
        params: Mapping[str, Any] | None,
        files: Any,
    ) -> tuple[int, Any]:
        """Dispatch a request against the spec; return ``(status, payload)``.

        Unmatched method/path pairs produce a 404 with a diagnostic detail.
        """
        operation, _path_params = self.spec.find_operation(method, path)
        if operation is not None:
            return self.spec.build_response(operation)
        return 404, {"detail": f"Unhandled {method.upper()} {path}"}
|
|
|
|
|
|
class OpenWebUIMockService(OpenAPIMockService):
    """Stateful mock capturing OpenWebUI knowledge and file operations.

    Knowledge bases and files live in plain dicts so tests can create,
    attach, list and delete them without a real OpenWebUI server.
    """

    def __init__(self, base_url: str, spec: OpenAPISpec):
        super().__init__(base_url, spec)
        # Live (mutable) entries keyed by identifier.
        self._knowledge: dict[str, dict[str, Any]] = {}
        self._files: dict[str, dict[str, Any]] = {}
        # Counters mint sequential "kb-N" / "file-N" identifiers.
        self._knowledge_counter = 1
        self._file_counter = 1
        self._default_user = "user-1"

    @staticmethod
    def _timestamp() -> int:
        """Current wall-clock time as an integer Unix timestamp."""
        return int(time.time())

    def ensure_knowledge(
        self,
        *,
        name: str,
        description: str = "",
        knowledge_id: str | None = None,
    ) -> dict[str, Any]:
        """Create and store a knowledge entry; return a deep copy of it.

        The skeleton comes from the spec's KnowledgeUserResponse schema with
        identifying fields overwritten.  The counter advances even when an
        explicit ``knowledge_id`` is supplied, so minted ids may skip numbers.
        """
        identifier = knowledge_id or f"kb-{self._knowledge_counter}"
        self._knowledge_counter += 1
        entry = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
        ts = self._timestamp()
        entry.update(
            {
                "id": identifier,
                "user_id": self._default_user,
                "name": name,
                "description": description or entry.get("description") or "",
                "created_at": ts,
                "updated_at": ts,
                "data": entry.get("data") or {},
                "meta": entry.get("meta") or {},
                "access_control": entry.get("access_control") or {},
                "files": [],
            }
        )
        self._knowledge[identifier] = entry
        return copy.deepcopy(entry)

    def create_file(
        self,
        *,
        filename: str,
        user_id: str | None = None,
        file_id: str | None = None,
    ) -> dict[str, Any]:
        """Create and store a file entry; return a deep copy of it."""
        identifier = file_id or f"file-{self._file_counter}"
        self._file_counter += 1
        entry = self.spec.generate_from_ref("#/components/schemas/FileModelResponse")
        ts = self._timestamp()
        entry.update(
            {
                "id": identifier,
                "user_id": user_id or self._default_user,
                "filename": filename,
                "meta": entry.get("meta") or {},
                "created_at": ts,
                "updated_at": ts,
            }
        )
        self._files[identifier] = entry
        return copy.deepcopy(entry)

    def _build_file_metadata(self, file_id: str) -> dict[str, Any]:
        """Build a FileMetadataResponse payload for *file_id*.

        Unknown ids still produce metadata (with empty ``meta``) so callers
        may attach placeholders.
        """
        metadata = self.spec.generate_from_ref("#/components/schemas/FileMetadataResponse")
        source = self._files.get(file_id)
        ts = self._timestamp()
        metadata.update(
            {
                "id": file_id,
                "meta": (source or {}).get("meta", {}),
                "created_at": ts,
                "updated_at": ts,
            }
        )
        return metadata

    def get_knowledge(self, knowledge_id: str) -> dict[str, Any] | None:
        """Return a deep copy of the knowledge entry, or None when absent."""
        entry = self._knowledge.get(knowledge_id)
        return copy.deepcopy(entry) if entry is not None else None

    def find_knowledge_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
        """Return ``(id, deep copy)`` of the first entry named *name*, or None."""
        return next(
            (
                (identifier, copy.deepcopy(entry))
                for identifier, entry in self._knowledge.items()
                if entry.get("name") == name
            ),
            None,
        )

    def _attach_file_metadata(self, knowledge: dict[str, Any], file_id: str) -> None:
        """Insert or replace the metadata entry for *file_id* on *knowledge*.

        Shared by attach_existing_file and _add_file_to_knowledge so the
        replace-then-append (at most one entry per file id) logic lives once.
        """
        metadata = self._build_file_metadata(file_id)
        knowledge.setdefault("files", [])
        knowledge["files"] = [item for item in knowledge["files"] if item.get("id") != file_id]
        knowledge["files"].append(metadata)
        knowledge["updated_at"] = self._timestamp()

    def attach_existing_file(self, knowledge_id: str, file_id: str) -> None:
        """Attach *file_id* to an existing knowledge entry.

        Raises:
            KeyError: if the knowledge entry does not exist.
        """
        if knowledge_id not in self._knowledge:
            raise KeyError(f"Knowledge {knowledge_id} not found")
        self._attach_file_metadata(self._knowledge[knowledge_id], file_id)

    def _knowledge_payload(self, knowledge: Mapping[str, Any], ref: str) -> dict[str, Any]:
        """Build a response payload from schema *ref*, filled from *knowledge*.

        Nested containers are deep-copied so callers cannot mutate the store.
        """
        payload = self.spec.generate_from_ref(ref)
        payload.update(
            {
                "id": knowledge["id"],
                "user_id": knowledge["user_id"],
                "name": knowledge["name"],
                "description": knowledge.get("description", ""),
                "data": copy.deepcopy(knowledge.get("data", {})),
                "meta": copy.deepcopy(knowledge.get("meta", {})),
                "access_control": copy.deepcopy(knowledge.get("access_control", {})),
                "created_at": knowledge.get("created_at", self._timestamp()),
                "updated_at": knowledge.get("updated_at", self._timestamp()),
                "files": copy.deepcopy(knowledge.get("files", [])),
            }
        )
        return payload

    def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
        """Knowledge entry shaped as KnowledgeUserResponse."""
        return self._knowledge_payload(knowledge, "#/components/schemas/KnowledgeUserResponse")

    def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
        """Knowledge entry shaped as KnowledgeFilesResponse."""
        return self._knowledge_payload(knowledge, "#/components/schemas/KnowledgeFilesResponse")

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, Any] | None,
        params: Mapping[str, Any] | None,
        files: Any,
    ) -> tuple[int, Any]:
        """Route knowledge/file requests; defer everything else to the spec."""
        normalized = self._normalize_path(path)
        segments = [segment for segment in normalized.strip("/").split("/") if segment]

        # Handle knowledge endpoints
        if segments[:3] == ["api", "v1", "knowledge"]:
            return self._handle_knowledge_endpoints(method, segments, json)

        # Handle file endpoints
        if segments[:3] == ["api", "v1", "files"]:
            return self._handle_file_endpoints(method, segments, files)

        # Delegate to parent
        return super().handle(method=method, path=path, json=json, params=params, files=files)

    def _handle_knowledge_endpoints(
        self, method: str, segments: list[str], json: Mapping[str, Any] | None
    ) -> tuple[int, Any]:
        """Handle knowledge-related API endpoints."""
        method_upper = method.upper()
        segment_count = len(segments)

        # List knowledge: GET /api/v1/knowledge or /api/v1/knowledge/list
        if (
            method_upper == "GET"
            and segment_count in {3, 4}
            and (segment_count == 3 or segments[3] == "list")
        ):
            return self._list_knowledge()

        # Create knowledge: POST /api/v1/knowledge/create
        if method_upper == "POST" and segment_count == 4 and segments[3] == "create":
            return self._create_knowledge(json or {})

        # Everything else targets a specific knowledge id.
        if segment_count >= 4:
            return self._handle_specific_knowledge(method_upper, segments, json)

        return 404, {"detail": "Knowledge endpoint not found"}

    def _list_knowledge(self) -> tuple[int, Any]:
        """List all knowledge entries."""
        body = [self._knowledge_user_response(entry) for entry in self._knowledge.values()]
        return 200, body

    def _create_knowledge(self, payload: Mapping[str, Any]) -> tuple[int, Any]:
        """Create a new knowledge entry."""
        entry = self.ensure_knowledge(
            name=str(payload.get("name", "knowledge")),
            description=str(payload.get("description", "")),
        )
        return 200, entry

    def _handle_specific_knowledge(
        self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None
    ) -> tuple[int, Any]:
        """Handle endpoints for specific knowledge entries."""
        knowledge_id = segments[3]
        # The live stored dict is used so file operations mutate the store.
        knowledge = self._knowledge.get(knowledge_id)
        if knowledge is None:
            return 404, {"detail": f"Knowledge {knowledge_id} not found"}

        segment_count = len(segments)

        # Get knowledge details
        if method_upper == "GET" and segment_count == 4:
            return 200, self._knowledge_files_response(knowledge)

        # File operations: /api/v1/knowledge/{id}/file/{add|remove}
        if segment_count == 6 and segments[4] == "file":
            return self._handle_knowledge_file_operations(
                method_upper, segments[5], knowledge, json or {}
            )

        # Delete knowledge
        if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete":
            self._knowledge.pop(knowledge_id, None)
            return 200, True

        return 404, {"detail": "Knowledge operation not found"}

    def _handle_knowledge_file_operations(
        self,
        method_upper: str,
        operation: str,
        knowledge: dict[str, Any],
        payload: Mapping[str, Any],
    ) -> tuple[int, Any]:
        """Handle file operations on knowledge entries (POST only)."""
        if method_upper != "POST":
            return 405, {"detail": "Method not allowed"}

        if operation == "add":
            return self._add_file_to_knowledge(knowledge, payload)
        if operation == "remove":
            return self._remove_file_from_knowledge(knowledge, payload)

        return 404, {"detail": "File operation not found"}

    def _add_file_to_knowledge(
        self, knowledge: dict[str, Any], payload: Mapping[str, Any]
    ) -> tuple[int, Any]:
        """Add a file to a knowledge entry, auto-creating unknown files."""
        file_id = str(payload.get("file_id", ""))
        if not file_id:
            return 422, {"detail": "file_id is required"}

        if file_id not in self._files:
            # BUG FIX: the placeholder was previously created under an
            # auto-minted "file-N" id, leaving the requested file_id
            # unregistered; pass file_id so the store matches the request.
            self.create_file(filename=f"{file_id}.txt", file_id=file_id)

        self._attach_file_metadata(knowledge, file_id)

        return 200, self._knowledge_files_response(knowledge)

    def _remove_file_from_knowledge(
        self, knowledge: dict[str, Any], payload: Mapping[str, Any]
    ) -> tuple[int, Any]:
        """Remove a file from a knowledge entry (idempotent)."""
        file_id = str(payload.get("file_id", ""))
        knowledge["files"] = [
            item for item in knowledge.get("files", []) if item.get("id") != file_id
        ]
        knowledge["updated_at"] = self._timestamp()

        return 200, self._knowledge_files_response(knowledge)

    def _handle_file_endpoints(
        self, method: str, segments: list[str], files: Any
    ) -> tuple[int, Any]:
        """Handle file-related API endpoints."""
        method_upper = method.upper()

        # Upload file: POST /api/v1/files
        if method_upper == "POST" and len(segments) == 3:
            return self._upload_file(files)

        # Delete file: DELETE /api/v1/files/{id}
        if method_upper == "DELETE" and len(segments) == 4:
            return self._delete_file(segments[3])

        return 404, {"detail": "File endpoint not found"}

    def _upload_file(self, files: Any) -> tuple[int, Any]:
        """Handle file upload.

        Expects a requests-style mapping ``{"file": (filename, payload, ...)}``;
        anything else falls back to a default filename.
        """
        filename = "uploaded.txt"
        if files and isinstance(files, Mapping):
            file_entry = files.get("file")
            if isinstance(file_entry, tuple) and len(file_entry) >= 1:
                filename = str(file_entry[0]) or filename

        entry = self.create_file(filename=filename)
        return 200, entry

    def _delete_file(self, file_id: str) -> tuple[int, Any]:
        """Delete a file and remove it from all knowledge entries."""
        self._files.pop(file_id, None)
        for knowledge in self._knowledge.values():
            knowledge["files"] = [
                item for item in knowledge.get("files", []) if item.get("id") != file_id
            ]

        return 200, {"deleted": True}
|
|
|
|
|
|
class R2RMockService(OpenAPIMockService):
    """Stateful mock of core R2R collection endpoints."""

    def __init__(self, base_url: str, spec: OpenAPISpec):
        super().__init__(base_url, spec)
        # Live (mutable) entries keyed by collection / document id.
        self._collections: dict[str, dict[str, Any]] = {}
        self._documents: dict[str, dict[str, Any]] = {}

    @staticmethod
    def _iso_now() -> str:
        """Current UTC time as an ISO-8601 string."""
        return datetime.now(tz=UTC).isoformat()

    def create_collection(
        self,
        *,
        name: str,
        description: str = "",
        collection_id: str | None = None,
    ) -> dict[str, Any]:
        """Create and store a collection entry; return a deep copy of it."""
        identifier = collection_id or str(uuid4())
        entry = self.spec.generate_from_ref("#/components/schemas/CollectionResponse")
        timestamp = self._iso_now()
        entry.update(
            {
                "id": identifier,
                "owner_id": entry.get("owner_id") or str(uuid4()),
                "name": name,
                "description": description or entry.get("description") or "",
                "graph_cluster_status": entry.get("graph_cluster_status") or "idle",
                "graph_sync_status": entry.get("graph_sync_status") or "synced",
                "created_at": timestamp,
                "updated_at": timestamp,
                "user_count": entry.get("user_count", 1) or 1,
                "document_count": entry.get("document_count", 0) or 0,
                "documents": entry.get("documents", []),
            }
        )
        self._collections[identifier] = entry
        return copy.deepcopy(entry)

    def get_collection(self, collection_id: str) -> dict[str, Any] | None:
        """Return a deep copy of the collection, or None when absent."""
        entry = self._collections.get(collection_id)
        return copy.deepcopy(entry) if entry is not None else None

    def find_collection_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
        """Return ``(id, deep copy)`` of the first collection named *name*."""
        return next(
            (
                (identifier, copy.deepcopy(entry))
                for identifier, entry in self._collections.items()
                if entry.get("name") == name
            ),
            None,
        )

    def _set_collection_document_ids(
        self, collection_id: str, document_id: str, *, add: bool
    ) -> None:
        """Add/remove *document_id* on a collection, auto-creating it if needed."""
        collection = self._collections.get(collection_id)
        if collection is None:
            # BUG FIX: create_collection returns a deep copy, so mutating its
            # return value never updated the stored entry.  Create first, then
            # re-read the live entry from the store.
            self.create_collection(
                name=f"Collection {collection_id}", collection_id=collection_id
            )
            collection = self._collections[collection_id]
        documents = collection.setdefault("documents", [])
        if add:
            if document_id not in documents:
                documents.append(document_id)
        else:
            # Slice-assign to mutate the stored list in place.
            documents[:] = [doc for doc in documents if doc != document_id]
        collection["document_count"] = len(documents)

    def create_document(
        self,
        *,
        document_id: str,
        content: str,
        metadata: dict[str, Any],
        collection_ids: list[str],
    ) -> dict[str, Any]:
        """Create and store a document; register it on its collections."""
        entry = self.spec.generate_from_ref("#/components/schemas/DocumentResponse")
        timestamp = self._iso_now()
        entry.update(
            {
                "id": document_id,
                "owner_id": entry.get("owner_id") or str(uuid4()),
                "collection_ids": collection_ids,
                "metadata": copy.deepcopy(metadata),
                "document_type": entry.get("document_type") or "text",
                "version": entry.get("version") or "1.0",
                "size_in_bytes": metadata.get("char_count") or len(content.encode("utf-8")),
                "created_at": entry.get("created_at") or timestamp,
                "updated_at": timestamp,
                "summary": entry.get("summary"),
                "summary_embedding": entry.get("summary_embedding") or None,
                "chunks": entry.get("chunks") or [],
                "content": content,
            }
        )
        # Internal stash of the raw content, distinct from the public field.
        entry["__content"] = content
        self._documents[document_id] = entry
        for collection_id in collection_ids:
            self._set_collection_document_ids(collection_id, document_id, add=True)
        return copy.deepcopy(entry)

    def get_document(self, document_id: str) -> dict[str, Any] | None:
        """Return a deep copy of the document, or None when absent."""
        entry = self._documents.get(document_id)
        return None if entry is None else copy.deepcopy(entry)

    def delete_document(self, document_id: str) -> bool:
        """Delete a document and unregister it from its collections."""
        entry = self._documents.pop(document_id, None)
        if entry is None:
            return False
        for collection_id in entry.get("collection_ids", []):
            self._set_collection_document_ids(collection_id, document_id, add=False)
        return True

    def append_document_metadata(
        self, document_id: str, metadata_list: list[dict[str, Any]]
    ) -> dict[str, Any] | None:
        """Merge each mapping in *metadata_list* into the document's metadata."""
        entry = self._documents.get(document_id)
        if entry is None:
            return None
        metadata = entry.setdefault("metadata", {})
        for item in metadata_list:
            for key, value in item.items():
                metadata[key] = value
        entry["updated_at"] = self._iso_now()
        return copy.deepcopy(entry)

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, Any] | None,
        params: Mapping[str, Any] | None,
        files: Any,
    ) -> tuple[int, Any]:
        """Route collection/document requests; defer the rest to the spec."""
        normalized = self._normalize_path(path)
        method_upper = method.upper()

        # Handle collection endpoints
        if normalized == "/v3/collections":
            return self._handle_collections_endpoint(method_upper, json)

        # Handle document endpoints
        segments = [segment for segment in normalized.strip("/").split("/") if segment]
        if segments[:2] == ["v3", "documents"]:
            return self._handle_documents_endpoint(method_upper, segments, json, files)

        # Delegate to parent
        return super().handle(method=method, path=path, json=json, params=params, files=files)

    def _handle_collections_endpoint(
        self, method_upper: str, json: Mapping[str, Any] | None
    ) -> tuple[int, Any]:
        """Handle collection-related endpoints."""
        if method_upper == "GET":
            return self._list_collections()
        if method_upper == "POST":
            return self._create_collection_endpoint(json or {})

        return 405, {"detail": "Method not allowed"}

    def _list_collections(self) -> tuple[int, Any]:
        """List all collections in a paginated response envelope."""
        payload = self.spec.generate_from_ref(
            "#/components/schemas/PaginatedR2RResult_list_CollectionResponse__"
        )
        results = []
        for entry in self._collections.values():
            clone = copy.deepcopy(entry)
            # Recompute the count so it reflects the current document list.
            clone["document_count"] = len(clone.get("documents", []))
            results.append(clone)

        payload.update(
            {
                "results": results,
                "total_entries": len(results),
            }
        )
        return 200, payload

    def _create_collection_endpoint(self, body: Mapping[str, Any]) -> tuple[int, Any]:
        """Create a new collection."""
        entry = self.create_collection(
            name=str(body.get("name", "Collection")),
            description=str(body.get("description", "")),
        )
        payload = self.spec.generate_from_ref("#/components/schemas/R2RResults_CollectionResponse_")
        payload.update({"results": entry})
        return 200, payload

    def _handle_documents_endpoint(
        self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any
    ) -> tuple[int, Any]:
        """Handle document-related endpoints."""
        if method_upper == "POST" and len(segments) == 2:
            return self._create_document_endpoint(files)

        if len(segments) == 3:
            doc_id = segments[2]
            if method_upper == "GET":
                return self._get_document_endpoint(doc_id)
            if method_upper == "DELETE":
                return self._delete_document_endpoint(doc_id)

        if method_upper == "PATCH" and len(segments) == 4 and segments[3] == "metadata":
            return self._update_document_metadata(segments[2], json)

        return 404, {"detail": "Document endpoint not found"}

    def _create_document_endpoint(self, files: Any) -> tuple[int, Any]:
        """Create a new document from multipart *files* parts."""
        doc_data = self._extract_document_data_from_files(files)
        document = self.create_document(**doc_data)

        response_payload = self.spec.generate_from_ref(
            "#/components/schemas/R2RResults_IngestionResponse_"
        )
        ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse")
        ingestion.update(
            {
                "message": ingestion.get("message") or "Ingestion task queued successfully.",
                "document_id": document["id"],
                "task_id": ingestion.get("task_id") or str(uuid4()),
            }
        )
        response_payload["results"] = ingestion
        # 202: ingestion is modelled as an asynchronously queued task.
        return 202, response_payload

    def _extract_document_data_from_files(self, files: Any) -> dict[str, Any]:
        """Extract document creation kwargs from multipart *files* parts."""
        metadata_raw = {}
        content = ""
        doc_id = str(uuid4())
        collection_ids: list[str] = []

        if not isinstance(files, Mapping):
            return {
                "document_id": doc_id,
                "content": content,
                "metadata": metadata_raw,
                "collection_ids": collection_ids,
            }

        # Extract metadata
        if "metadata" in files:
            metadata_raw = self._extract_metadata_from_files(files["metadata"])

        # Extract content
        if "raw_text" in files:
            content = self._extract_content_from_files(files["raw_text"])

        # Extract document ID
        if "id" in files:
            doc_id = self._extract_doc_id_from_files(files["id"]) or doc_id

        # Extract collection IDs
        if "collection_ids" in files:
            collection_ids = self._extract_collection_ids_from_files(files["collection_ids"])

        # Fallback collection resolution
        if not collection_ids:
            collection_ids = self._resolve_collection_ids_from_metadata(metadata_raw)

        return {
            "document_id": doc_id,
            "content": content,
            "metadata": metadata_raw,
            "collection_ids": collection_ids,
        }

    def _extract_metadata_from_files(self, metadata_entry: Any) -> dict[str, Any]:
        """Parse the JSON metadata part; malformed input yields ``{}``."""
        if isinstance(metadata_entry, tuple) and len(metadata_entry) >= 2:
            try:
                return json_module.loads(metadata_entry[1] or "{}")
            except json_module.JSONDecodeError:
                return {}
        return {}

    def _extract_content_from_files(self, raw_text_entry: Any) -> str:
        """Extract the raw-text part; missing/None payload yields ``""``."""
        if (
            isinstance(raw_text_entry, tuple)
            and len(raw_text_entry) >= 2
            and raw_text_entry[1] is not None
        ):
            return str(raw_text_entry[1])
        return ""

    def _extract_doc_id_from_files(self, id_entry: Any) -> str | None:
        """Extract the document-id part, or None when absent/empty."""
        if isinstance(id_entry, tuple) and len(id_entry) >= 2 and id_entry[1]:
            return str(id_entry[1])
        return None

    def _extract_collection_ids_from_files(self, coll_entry: Any) -> list[str]:
        """Parse the collection-ids part (a JSON list); malformed -> ``[]``."""
        if isinstance(coll_entry, tuple) and len(coll_entry) >= 2 and coll_entry[1]:
            try:
                parsed = json_module.loads(coll_entry[1])
                if isinstance(parsed, list):
                    return [str(item) for item in parsed]
            except json_module.JSONDecodeError:
                pass
        return []

    def _resolve_collection_ids_from_metadata(self, metadata_raw: dict[str, Any]) -> list[str]:
        """Resolve collection ids from metadata name hints, else first stored."""
        name = metadata_raw.get("collection_name") or metadata_raw.get("collection")
        if isinstance(name, str):
            if located := self.find_collection_by_name(name):
                return [located[0]]

        return [next(iter(self._collections))] if self._collections else []

    def _get_document_endpoint(self, doc_id: str) -> tuple[int, Any]:
        """Get a document by ID."""
        document = self.get_document(doc_id)
        if document is None:
            return 404, {"detail": f"Document {doc_id} not found"}

        response_payload = self.spec.generate_from_ref(
            "#/components/schemas/R2RResults_DocumentResponse_"
        )
        response_payload["results"] = document
        return 200, response_payload

    def _delete_document_endpoint(self, doc_id: str) -> tuple[int, Any]:
        """Delete a document by ID; 404 (with payload) when absent."""
        success = self.delete_document(doc_id)
        response_payload = self.spec.generate_from_ref(
            "#/components/schemas/R2RResults_GenericBooleanResponse_"
        )

        # Copy before mutating: the generated results may alias spec examples.
        results = response_payload.get("results")
        if isinstance(results, Mapping):
            results = dict(results)
            results.update({"success": success})
        else:
            results = {"success": success}

        response_payload["results"] = results
        return (200 if success else 404), response_payload

    def _update_document_metadata(
        self, doc_id: str, json: Mapping[str, Any] | None
    ) -> tuple[int, Any]:
        """Update document metadata.

        The endpoint body is a JSON *list* of metadata mappings; any other
        shape is treated as an empty update.
        """
        metadata_list = [dict(item) for item in json] if isinstance(json, list) else []
        document = self.append_document_metadata(doc_id, metadata_list)
        if document is None:
            return 404, {"detail": f"Document {doc_id} not found"}

        response_payload = self.spec.generate_from_ref(
            "#/components/schemas/R2RResults_DocumentResponse_"
        )
        response_payload["results"] = document
        return 200, response_payload
|
|
|
|
|
|
class FirecrawlMockService(OpenAPIMockService):
    """In-memory Firecrawl stand-in covering the map and scrape endpoints."""

    def __init__(self, base_url: str, spec: OpenAPISpec) -> None:
        super().__init__(base_url, spec)
        # origin URL -> canned link list served by /v2/map
        self._maps: dict[str, list[str]] = {}
        # page URL -> canned content served by /v2/scrape
        self._pages: dict[str, dict[str, Any]] = {}

    def register_map_result(self, origin: str, links: list[str]) -> None:
        """Record the link list that /v2/map should return for *origin*."""
        self._maps[origin] = list(links)

    def register_page(
        self,
        url: str,
        *,
        markdown: str | None = None,
        html: str | None = None,
        metadata: Mapping[str, Any] | None = None,
        links: list[str] | None = None,
    ) -> None:
        """Register canned scrape content for *url*."""
        self._pages[url] = {
            "markdown": markdown,
            "html": html,
            "metadata": dict(metadata or {}),
            "links": list(links or []),
        }

    def _build_map_payload(self, target_url: str, limit: int | None) -> dict[str, Any]:
        """Assemble a MapResponse body, truncating links when a limit applies."""
        payload = self.spec.generate_from_ref("#/components/schemas/MapResponse")
        # Unknown origins map to themselves so callers always get a link.
        discovered = self._maps.get(target_url, [target_url])
        if limit is not None:
            try:
                cap = int(limit)
            except (TypeError, ValueError):
                pass  # non-numeric limits are silently ignored
            else:
                if cap >= 0:
                    discovered = discovered[:cap]
        payload["success"] = True
        payload["links"] = [{"url": entry} for entry in discovered]
        return payload

    def _default_metadata(self, url: str) -> dict[str, Any]:
        """Produce scrape metadata seeded from the spec, with sane fallbacks."""
        metadata = self.spec.generate_from_ref("#/components/schemas/ScrapeMetadata")
        overrides = {
            "url": url,
            "sourceURL": url,
            "scrapeId": metadata.get("scrapeId") or str(uuid4()),
            "statusCode": metadata.get("statusCode", 200) or 200,
            "contentType": metadata.get("contentType", "text/html") or "text/html",
            "creditsUsed": metadata.get("creditsUsed", 1) or 1,
        }
        metadata.update(overrides)
        return metadata

    def _build_scrape_payload(self, target_url: str, formats: list[str] | None) -> dict[str, Any]:
        """Assemble a ScrapeResponse body for *target_url*.

        ``formats`` is accepted for API parity but does not filter the
        returned fields.
        """
        payload = self.spec.generate_from_ref("#/components/schemas/ScrapeResponse")
        payload["success"] = True
        registered = self._pages.get(target_url, {})
        data = payload.get("data", {})
        body_markdown = registered.get("markdown") or f"# Content for {target_url}\n"
        body_html = registered.get("html") or f"<h1>Content for {target_url}</h1>"
        data["markdown"] = body_markdown
        data["html"] = body_html
        data["rawHtml"] = registered.get("rawHtml", body_html)
        data["links"] = registered.get("links", [])
        combined_metadata = self._default_metadata(target_url)
        # Registered metadata wins over the generated defaults.
        combined_metadata.update(registered.get("metadata", {}))
        data["metadata"] = combined_metadata
        payload["data"] = data
        return payload

    def map_response(self, url: str, limit: int | None = None) -> dict[str, Any]:
        """Return the payload /v2/map would produce for *url*."""
        return self._build_map_payload(url, limit)

    def scrape_response(self, url: str, formats: list[str] | None = None) -> dict[str, Any]:
        """Return the payload /v2/scrape would produce for *url*."""
        return self._build_scrape_payload(url, formats)

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, Any] | None,
        params: Mapping[str, Any] | None,
        files: Any,
    ) -> tuple[int, Any]:
        """Route a request to the map/scrape handlers or defer to the base class."""
        route = self._normalize_path(path)
        verb = method.upper()

        if verb == "POST" and route == "/v2/map":
            body = json or {}
            return 200, self._build_map_payload(str(body.get("url", "")), body.get("limit"))

        if verb == "POST" and route == "/v2/scrape":
            body = json or {}
            return 200, self._build_scrape_payload(str(body.get("url", "")), body.get("formats"))

        return super().handle(method=method, path=path, json=json, params=params, files=files)
|