588 lines
17 KiB
Python
588 lines
17 KiB
Python
from __future__ import annotations
|
|
|
|
from collections import deque
|
|
from collections.abc import Mapping
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from typing import Protocol, TypedDict, cast
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
import pytest
|
|
from pydantic import HttpUrl
|
|
|
|
from ingest_pipeline.core.models import (
|
|
Document,
|
|
DocumentMetadata,
|
|
IngestionSource,
|
|
StorageBackend,
|
|
StorageConfig,
|
|
VectorConfig,
|
|
)
|
|
from typings import EmbeddingData, EmbeddingResponse
|
|
|
|
from .openapi_mocks import (
|
|
FirecrawlMockService,
|
|
OpenAPISpec,
|
|
OpenWebUIMockService,
|
|
R2RMockService,
|
|
)
|
|
|
|
# JSON-like payload shape used when unpacking mock service responses.
MockResponseData = dict[str, object]

# Directory holding this module; its parent is where the fixtures below
# look for the OpenAPI specification files.
_THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = _THIS_DIR.parent
|
|
|
|
|
|
class RequestRecord(TypedDict):
    """One HTTP call as logged by ``AsyncClientStub``.

    Every key is always present; the optional parts of a request are stored
    as ``None`` when the caller did not supply them.
    """

    # Upper-case HTTP verb, e.g. "GET" or "POST".
    method: str
    # Request URL after base-URL normalization.
    url: str
    # JSON body passed by the caller, if any.
    json_body: dict[str, object] | None
    # Query parameters passed by the caller, if any.
    params: dict[str, object] | None
    # Upload payload passed by the caller, if any.
    files: object | None
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class StubbedResponse:
|
|
"""In-memory HTTP response for httpx client mocking."""
|
|
|
|
payload: object
|
|
status_code: int = 200
|
|
|
|
def json(self) -> object:
|
|
return self.payload
|
|
|
|
def raise_for_status(self) -> None:
|
|
if self.status_code < 400:
|
|
return
|
|
|
|
# Create a minimal exception for testing - we don't need full httpx objects for tests
|
|
class TestHTTPError(Exception):
|
|
request: object | None
|
|
response: object | None
|
|
|
|
def __init__(self, message: str) -> None:
|
|
super().__init__(message)
|
|
self.request = None
|
|
self.response = None
|
|
|
|
raise TestHTTPError(f"Stubbed HTTP error with status {self.status_code}")
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AsyncClientStub:
|
|
"""Replacement for httpx.AsyncClient used in tests."""
|
|
|
|
responses: deque[StubbedResponse]
|
|
requests: list[RequestRecord]
|
|
owner: HttpxStub
|
|
timeout: object | None = None
|
|
headers: dict[str, str] = field(default_factory=dict)
|
|
base_url: str = ""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
responses: deque[StubbedResponse],
|
|
requests: list[RequestRecord],
|
|
timeout: object | None = None,
|
|
headers: dict[str, str] | None = None,
|
|
base_url: str | None = None,
|
|
owner: HttpxStub,
|
|
**_: object,
|
|
) -> None:
|
|
self.responses = responses
|
|
self.requests = requests
|
|
self.timeout = timeout
|
|
self.headers = dict(headers or {})
|
|
self.base_url = str(base_url or "")
|
|
self.owner = owner
|
|
|
|
def _normalize_url(self, url: str) -> str:
|
|
if url.startswith("http://") or url.startswith("https://"):
|
|
return url
|
|
if not self.base_url:
|
|
return url
|
|
prefix = self.base_url.rstrip("/")
|
|
suffix = url.lstrip("/")
|
|
return f"{prefix}/{suffix}" if suffix else prefix
|
|
|
|
def _consume(
|
|
self,
|
|
*,
|
|
method: str,
|
|
url: str,
|
|
json: dict[str, object] | None,
|
|
params: dict[str, object] | None,
|
|
files: object | None,
|
|
) -> StubbedResponse:
|
|
if self.responses:
|
|
return self.responses.popleft()
|
|
dispatched = self.owner.dispatch(
|
|
method=method,
|
|
url=url,
|
|
json=json,
|
|
params=params,
|
|
files=files,
|
|
)
|
|
if dispatched is not None:
|
|
return dispatched
|
|
raise AssertionError(f"No stubbed response for {method} {url}")
|
|
|
|
def _record(
|
|
self,
|
|
*,
|
|
method: str,
|
|
url: str,
|
|
json: dict[str, object] | None,
|
|
params: dict[str, object] | None,
|
|
files: object | None,
|
|
) -> str:
|
|
normalized = self._normalize_url(url)
|
|
record: RequestRecord = {
|
|
"method": method,
|
|
"url": normalized,
|
|
"json_body": json,
|
|
"params": params,
|
|
"files": files,
|
|
}
|
|
self.requests.append(record)
|
|
return normalized
|
|
|
|
async def post(
|
|
self,
|
|
url: str,
|
|
*,
|
|
json: dict[str, object] | None = None,
|
|
files: object | None = None,
|
|
params: dict[str, object] | None = None,
|
|
) -> StubbedResponse:
|
|
normalized = self._record(
|
|
method="POST",
|
|
url=url,
|
|
json=json,
|
|
params=params,
|
|
files=files,
|
|
)
|
|
return self._consume(
|
|
method="POST",
|
|
url=normalized,
|
|
json=json,
|
|
params=params,
|
|
files=files,
|
|
)
|
|
|
|
async def get(
|
|
self,
|
|
url: str,
|
|
*,
|
|
params: dict[str, object] | None = None,
|
|
) -> StubbedResponse:
|
|
normalized = self._record(
|
|
method="GET",
|
|
url=url,
|
|
json=None,
|
|
params=params,
|
|
files=None,
|
|
)
|
|
return self._consume(
|
|
method="GET",
|
|
url=normalized,
|
|
json=None,
|
|
params=params,
|
|
files=None,
|
|
)
|
|
|
|
async def delete(
|
|
self,
|
|
url: str,
|
|
*,
|
|
params: dict[str, object] | None = None,
|
|
json: dict[str, object] | None = None,
|
|
) -> StubbedResponse:
|
|
normalized = self._record(
|
|
method="DELETE",
|
|
url=url,
|
|
json=json,
|
|
params=params,
|
|
files=None,
|
|
)
|
|
return self._consume(
|
|
method="DELETE",
|
|
url=normalized,
|
|
json=json,
|
|
params=params,
|
|
files=None,
|
|
)
|
|
|
|
async def request(
|
|
self,
|
|
method: str,
|
|
url: str,
|
|
*,
|
|
json: dict[str, object] | None = None,
|
|
data: dict[str, object] | None = None,
|
|
files: dict[str, tuple[str, bytes, str]] | None = None,
|
|
params: dict[str, str | bool] | None = None,
|
|
) -> StubbedResponse:
|
|
"""Generic request method that delegates to specific HTTP methods."""
|
|
# Convert params to the format expected by other methods
|
|
converted_params: dict[str, object] | None = None
|
|
if params:
|
|
converted_params = dict(params)
|
|
|
|
method_upper = method.upper()
|
|
if method_upper == "GET":
|
|
return await self.get(url, params=converted_params)
|
|
elif method_upper == "POST":
|
|
return await self.post(url, json=json, files=files, params=converted_params)
|
|
elif method_upper == "DELETE":
|
|
return await self.delete(url, json=json, params=converted_params)
|
|
else:
|
|
# For other methods, use the consume/record pattern directly
|
|
normalized = self._record(
|
|
method=method_upper,
|
|
url=url,
|
|
json=json or data,
|
|
params=converted_params,
|
|
files=files,
|
|
)
|
|
return self._consume(
|
|
method=method_upper,
|
|
url=normalized,
|
|
json=json or data,
|
|
params=converted_params,
|
|
files=files,
|
|
)
|
|
|
|
async def aclose(self) -> None:
|
|
return None
|
|
|
|
async def __aenter__(self) -> AsyncClientStub:
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: object | None,
|
|
) -> None:
|
|
return None
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class HttpxStub:
|
|
"""Helper exposing queued responses and captured requests."""
|
|
|
|
responses: deque[StubbedResponse] = field(default_factory=deque)
|
|
requests: list[RequestRecord] = field(default_factory=list)
|
|
clients: list[AsyncClientStub] = field(default_factory=list)
|
|
services: dict[str, MockService] = field(default_factory=dict)
|
|
|
|
def queue_json(self, payload: object, status_code: int = 200) -> None:
|
|
self.responses.append(StubbedResponse(payload=payload, status_code=status_code))
|
|
|
|
def register_service(self, base_url: str, service: MockService) -> None:
|
|
normalized = base_url.rstrip("/") or base_url
|
|
self.services[normalized] = service
|
|
|
|
def dispatch(
|
|
self,
|
|
*,
|
|
method: str,
|
|
url: str,
|
|
json: Mapping[str, object] | None,
|
|
params: Mapping[str, object] | None,
|
|
files: object | None,
|
|
) -> StubbedResponse | None:
|
|
parsed = urlparse(url)
|
|
base = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else ""
|
|
base = base.rstrip("/") if base else base
|
|
path = parsed.path or "/"
|
|
service = self.services.get(base)
|
|
if service is None:
|
|
return None
|
|
status, payload = service.handle(
|
|
method=method,
|
|
path=path,
|
|
json=json,
|
|
params=params,
|
|
files=files,
|
|
)
|
|
return StubbedResponse(payload=payload, status_code=status)
|
|
|
|
|
|
class DocumentFactory(Protocol):
    """Structural type of the callable returned by ``document_factory``."""

    def __call__(
        self,
        *,
        content: str,
        metadata_updates: dict[str, object] | None = None,
    ) -> Document:
        """Build a test Document from ``content`` and optional metadata overrides."""
|
|
|
|
|
|
class VectorConfigFactory(Protocol):
    """Structural type of the callable returned by ``vector_config_factory``."""

    def __call__(self, *, model: str, dimension: int, endpoint: str) -> VectorConfig:
        """Build a test VectorConfig from keyword-only settings."""
|
|
|
|
|
|
class EmbeddingPayloadFactory(Protocol):
    """Structural type of the callable returned by ``embedding_payload_factory``."""

    def __call__(self, *, dimension: int) -> EmbeddingResponse:
        """Build an embedding response for the given vector ``dimension``."""
|
|
|
|
|
|
class MockService(Protocol):
    """Structural type of a mock backend that ``HttpxStub`` can dispatch to."""

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, object] | None,
        params: Mapping[str, object] | None,
        files: object | None,
    ) -> tuple[int, object]:
        """Answer one request with a ``(status_code, payload)`` pair."""
|
|
|
|
|
|
@pytest.fixture(scope="session")
def base_metadata() -> DocumentMetadata:
    """Session-wide baseline metadata shared by the document fixtures.

    The timestamp is captured once, when the fixture is first requested.
    """
    captured_at = datetime.now(UTC)
    baseline = {
        "source_url": "https://example.com/article",
        "timestamp": captured_at,
        "content_type": "text/plain",
        "word_count": 100,
        "char_count": 500,
    }
    # Double cast: go through ``object`` so the plain dict is accepted as
    # the DocumentMetadata type by the type checker.
    return cast(DocumentMetadata, cast(object, baseline))
|
|
|
|
|
|
@pytest.fixture(scope="module")
def document_factory(base_metadata: DocumentMetadata) -> DocumentFactory:
    """Return a builder for WEB-sourced Documents based on ``base_metadata``."""

    def _build(
        *,
        content: str,
        metadata_updates: dict[str, object] | None = None,
    ) -> Document:
        # Work on a copy so the shared base metadata is never mutated.
        merged = dict(base_metadata)
        merged |= metadata_updates or {}
        typed_metadata = cast(DocumentMetadata, cast(object, merged))
        return Document(
            content=content,
            metadata=typed_metadata,
            source=IngestionSource.WEB,
        )

    return _build
|
|
|
|
|
|
@pytest.fixture(scope="session")
def vector_config_factory() -> VectorConfigFactory:
    """Return a builder that wraps the endpoint in ``HttpUrl`` for VectorConfig."""

    def _build(*, model: str, dimension: int, endpoint: str) -> VectorConfig:
        endpoint_url = HttpUrl(endpoint)
        return VectorConfig(
            model=model,
            dimension=dimension,
            embedding_endpoint=endpoint_url,
        )

    return _build
|
|
|
|
|
|
@pytest.fixture(scope="session")
def storage_config() -> StorageConfig:
    """Weaviate-backed storage settings pointing at ``http://storage.local``."""
    endpoint = HttpUrl("http://storage.local")
    return StorageConfig(
        backend=StorageBackend.WEAVIATE,
        endpoint=endpoint,
        collection_name="documents",
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def r2r_storage_config() -> StorageConfig:
    """R2R-backed storage settings pointing at ``http://r2r.local``."""
    endpoint = HttpUrl("http://r2r.local")
    return StorageConfig(
        backend=StorageBackend.R2R,
        endpoint=endpoint,
        collection_name="documents",
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def openwebui_spec() -> OpenAPISpec:
    """Load the OpenWebUI OpenAPI spec from ``chat.json`` under PROJECT_ROOT."""
    spec_path = PROJECT_ROOT / "chat.json"
    return OpenAPISpec.from_file(spec_path)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def r2r_spec() -> OpenAPISpec:
    """Load the R2R OpenAPI spec from ``r2r.json`` under PROJECT_ROOT."""
    spec_path = PROJECT_ROOT / "r2r.json"
    return OpenAPISpec.from_file(spec_path)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def firecrawl_spec() -> OpenAPISpec:
    """Load the Firecrawl OpenAPI spec from ``firecrawl.json`` under PROJECT_ROOT."""
    spec_path = PROJECT_ROOT / "firecrawl.json"
    return OpenAPISpec.from_file(spec_path)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def httpx_stub(monkeypatch: pytest.MonkeyPatch) -> HttpxStub:
    """Patch ``httpx.AsyncClient`` with an in-memory stub and return its controller.

    Also removes ``LLM_API_KEY`` and ``OPENAI_API_KEY`` from the environment
    for the duration of the test.
    """
    stub = HttpxStub()

    def _client_factory(**kwargs: object) -> AsyncClientStub:
        # Only timeout, headers and base_url are forwarded; every client
        # shares the controller's response queue and request log.
        header_arg = cast(dict[str, str] | None, kwargs.get("headers"))
        base_arg = cast(str | None, kwargs.get("base_url"))
        new_client = AsyncClientStub(
            responses=stub.responses,
            requests=stub.requests,
            timeout=kwargs.get("timeout"),
            headers=header_arg,
            base_url=base_arg,
            owner=stub,
        )
        stub.clients.append(new_client)
        return new_client

    monkeypatch.setattr(httpx, "AsyncClient", _client_factory)
    for env_name in ("LLM_API_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(env_name, raising=False)
    return stub
|
|
|
|
|
|
@pytest.fixture(scope="function")
def openwebui_service(
    httpx_stub: HttpxStub,
    openwebui_spec: OpenAPISpec,
    storage_config: StorageConfig,
) -> OpenWebUIMockService:
    """Create an OpenWebUI mock at the storage endpoint and register it with the stub."""
    endpoint = str(storage_config.endpoint)
    mock = OpenWebUIMockService(base_url=endpoint, spec=openwebui_spec)
    httpx_stub.register_service(mock.base_url, mock)
    return mock
|
|
|
|
|
|
@pytest.fixture(scope="function")
def r2r_service(
    httpx_stub: HttpxStub,
    r2r_spec: OpenAPISpec,
    r2r_storage_config: StorageConfig,
) -> R2RMockService:
    """Create an R2R mock at the R2R storage endpoint and register it with the stub."""
    endpoint = str(r2r_storage_config.endpoint)
    mock = R2RMockService(base_url=endpoint, spec=r2r_spec)
    httpx_stub.register_service(mock.base_url, mock)
    return mock
|
|
|
|
|
|
@pytest.fixture(scope="function")
def firecrawl_service(
    httpx_stub: HttpxStub,
    firecrawl_spec: OpenAPISpec,
) -> FirecrawlMockService:
    """Create a Firecrawl mock at ``http://crawl.lab:30002`` and register it with the stub."""
    mock = FirecrawlMockService(base_url="http://crawl.lab:30002", spec=firecrawl_spec)
    httpx_stub.register_service(mock.base_url, mock)
    return mock
|
|
|
|
|
|
@pytest.fixture(scope="function")
def firecrawl_client_stub(
    monkeypatch: pytest.MonkeyPatch,
    firecrawl_service: FirecrawlMockService,
) -> object:
    """Swap ``AsyncFirecrawl`` in the firecrawl ingestor for a mock-backed stub.

    Returns the stub class (not an instance).
    """

    class AsyncFirecrawlStub:
        """Stand-in client answering ``map``/``scrape`` from the mock service."""

        _service: FirecrawlMockService

        def __init__(self, *args: object, **kwargs: object) -> None:
            # Constructor arguments are accepted and ignored.
            self._service = firecrawl_service

        async def map(self, url: str, limit: int | None = None, **_: object) -> SimpleNamespace:
            """Return the mock's map response as attribute-style objects."""
            raw = cast(MockResponseData, self._service.map_response(url, limit))
            entries = cast(list[MockResponseData], raw.get("links", []))
            links: list[SimpleNamespace] = []
            for entry in entries:
                links.append(SimpleNamespace(url=cast(str, entry.get("url", ""))))
            return SimpleNamespace(success=raw.get("success", True), links=links)

        async def scrape(
            self, url: str, formats: list[str] | None = None, **_: object
        ) -> SimpleNamespace:
            """Return the mock's scrape response as attribute-style objects."""
            raw = cast(MockResponseData, self._service.scrape_response(url, formats))
            body = cast(MockResponseData, raw.get("data", {}))
            meta_fields = cast(MockResponseData, body.get("metadata", {}))
            result = SimpleNamespace(
                markdown=body.get("markdown"),
                html=body.get("html"),
                rawHtml=body.get("rawHtml"),
                links=body.get("links", []),
                metadata=SimpleNamespace(**meta_fields),
            )
            return result

        async def close(self) -> None:
            """No-op; nothing to release."""
            return None

        async def __aenter__(self) -> AsyncFirecrawlStub:
            return self

        async def __aexit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: object | None,
        ) -> None:
            return None

    patch_target = "ingest_pipeline.ingestors.firecrawl.AsyncFirecrawl"
    monkeypatch.setattr(patch_target, AsyncFirecrawlStub)
    return AsyncFirecrawlStub
|
|
|
|
|
|
@pytest.fixture(scope="function")
def embedding_payload_factory() -> EmbeddingPayloadFactory:
    """Return a builder for single-entry embedding responses.

    The vector produced is ``[0.0, 1.0, ..., dimension - 1]``.
    """

    def _build(*, dimension: int) -> EmbeddingResponse:
        values = list(map(float, range(dimension)))
        entry: EmbeddingData = {"embedding": values}
        return {"data": [entry]}

    return _build
|