Files
rag-manager/tests/conftest.py
2025-09-19 06:56:19 +00:00

546 lines
15 KiB
Python

from __future__ import annotations
from collections import deque
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Protocol, TypedDict, cast
from urllib.parse import urlparse
import httpx
import pytest
from pydantic import HttpUrl
from ingest_pipeline.core.models import (
Document,
DocumentMetadata,
IngestionSource,
StorageBackend,
StorageConfig,
VectorConfig,
)
from typings import EmbeddingData, EmbeddingResponse
from .openapi_mocks import (
FirecrawlMockService,
OpenAPISpec,
OpenWebUIMockService,
R2RMockService,
)
# Type alias for loosely typed JSON-object payloads; used below when casting
# the values returned by the Firecrawl mock service.
MockResponseData = dict[str, object]
# Directory two levels above this file. The OpenAPI spec fixtures below load
# chat.json, r2r.json and firecrawl.json from it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
class RequestRecord(TypedDict):
    """Snapshot of a single HTTP call captured by the stubbed client."""

    # HTTP verb of the call, e.g. "POST", "GET" or "DELETE".
    method: str
    # Request URL as logged by the stub (after joining with its base URL).
    url: str
    # JSON body supplied by the caller, or None when there was none.
    json_body: dict[str, object] | None
    # Query parameters supplied by the caller, or None.
    params: dict[str, object] | None
    # File payload supplied by the caller, or None.
    files: object | None
@dataclass(slots=True)
class StubbedResponse:
"""In-memory HTTP response for httpx client mocking."""
payload: object
status_code: int = 200
def json(self) -> object:
return self.payload
def raise_for_status(self) -> None:
if self.status_code < 400:
return
# Create a minimal exception for testing - we don't need full httpx objects for tests
class TestHTTPError(Exception):
request: object | None
response: object | None
def __init__(self, message: str) -> None:
super().__init__(message)
self.request = None
self.response = None
raise TestHTTPError(f"Stubbed HTTP error with status {self.status_code}")
@dataclass(slots=True)
class AsyncClientStub:
"""Replacement for httpx.AsyncClient used in tests."""
responses: deque[StubbedResponse]
requests: list[RequestRecord]
owner: HttpxStub
timeout: object | None = None
headers: dict[str, str] = field(default_factory=dict)
base_url: str = ""
def __init__(
self,
*,
responses: deque[StubbedResponse],
requests: list[RequestRecord],
timeout: object | None = None,
headers: dict[str, str] | None = None,
base_url: str | None = None,
owner: HttpxStub,
**_: object,
) -> None:
self.responses = responses
self.requests = requests
self.timeout = timeout
self.headers = dict(headers or {})
self.base_url = str(base_url or "")
self.owner = owner
def _normalize_url(self, url: str) -> str:
if url.startswith("http://") or url.startswith("https://"):
return url
if not self.base_url:
return url
prefix = self.base_url.rstrip("/")
suffix = url.lstrip("/")
return f"{prefix}/{suffix}" if suffix else prefix
def _consume(
self,
*,
method: str,
url: str,
json: dict[str, object] | None,
params: dict[str, object] | None,
files: object | None,
) -> StubbedResponse:
if self.responses:
return self.responses.popleft()
dispatched = self.owner.dispatch(
method=method,
url=url,
json=json,
params=params,
files=files,
)
if dispatched is not None:
return dispatched
raise AssertionError(f"No stubbed response for {method} {url}")
def _record(
self,
*,
method: str,
url: str,
json: dict[str, object] | None,
params: dict[str, object] | None,
files: object | None,
) -> str:
normalized = self._normalize_url(url)
record: RequestRecord = {
"method": method,
"url": normalized,
"json_body": json,
"params": params,
"files": files,
}
self.requests.append(record)
return normalized
async def post(
self,
url: str,
*,
json: dict[str, object] | None = None,
files: object | None = None,
params: dict[str, object] | None = None,
) -> StubbedResponse:
normalized = self._record(
method="POST",
url=url,
json=json,
params=params,
files=files,
)
return self._consume(
method="POST",
url=normalized,
json=json,
params=params,
files=files,
)
async def get(
self,
url: str,
*,
params: dict[str, object] | None = None,
) -> StubbedResponse:
normalized = self._record(
method="GET",
url=url,
json=None,
params=params,
files=None,
)
return self._consume(
method="GET",
url=normalized,
json=None,
params=params,
files=None,
)
async def delete(
self,
url: str,
*,
params: dict[str, object] | None = None,
json: dict[str, object] | None = None,
) -> StubbedResponse:
normalized = self._record(
method="DELETE",
url=url,
json=json,
params=params,
files=None,
)
return self._consume(
method="DELETE",
url=normalized,
json=json,
params=params,
files=None,
)
async def aclose(self) -> None:
return None
async def __aenter__(self) -> AsyncClientStub:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object | None,
) -> None:
return None
@dataclass(slots=True)
class HttpxStub:
"""Helper exposing queued responses and captured requests."""
responses: deque[StubbedResponse] = field(default_factory=deque)
requests: list[RequestRecord] = field(default_factory=list)
clients: list[AsyncClientStub] = field(default_factory=list)
services: dict[str, MockService] = field(default_factory=dict)
def queue_json(self, payload: object, status_code: int = 200) -> None:
self.responses.append(StubbedResponse(payload=payload, status_code=status_code))
def register_service(self, base_url: str, service: MockService) -> None:
normalized = base_url.rstrip("/") or base_url
self.services[normalized] = service
def dispatch(
self,
*,
method: str,
url: str,
json: Mapping[str, object] | None,
params: Mapping[str, object] | None,
files: object | None,
) -> StubbedResponse | None:
parsed = urlparse(url)
base = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else ""
base = base.rstrip("/") if base else base
path = parsed.path or "/"
service = self.services.get(base)
if service is None:
return None
status, payload = service.handle(
method=method,
path=path,
json=json,
params=params,
files=files,
)
return StubbedResponse(payload=payload, status_code=status)
class DocumentFactory(Protocol):
    """Structural type of the callable returned by the ``document_factory`` fixture."""

    def __call__(
        self,
        *,
        content: str,
        metadata_updates: dict[str, object] | None = None,
    ) -> Document:
        """Build a ``Document`` for a test.

        Args:
            content: Text content of the document.
            metadata_updates: Optional metadata entries for the document.
        """
        ...
class VectorConfigFactory(Protocol):
    """Structural type of the callable returned by the ``vector_config_factory`` fixture."""

    def __call__(self, *, model: str, dimension: int, endpoint: str) -> VectorConfig:
        """Build a ``VectorConfig`` from a model name, dimension and endpoint URL."""
        ...
class EmbeddingPayloadFactory(Protocol):
    """Structural type of the callable returned by ``embedding_payload_factory``."""

    def __call__(self, *, dimension: int) -> EmbeddingResponse:
        """Build an embedding response for the requested ``dimension``."""
        ...
class MockService(Protocol):
    """Structural type for mock services that ``HttpxStub.dispatch`` can route to."""

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, object] | None,
        params: Mapping[str, object] | None,
        files: object | None,
    ) -> tuple[int, object]:
        """Answer one HTTP request.

        Args:
            method: HTTP verb of the request.
            path: Path component of the request URL.
            json: JSON body of the request, if any.
            params: Query parameters of the request, if any.
            files: File payload of the request, if any.

        Returns:
            A ``(status_code, payload)`` pair.
        """
        ...
@pytest.fixture(scope="session")
def base_metadata() -> DocumentMetadata:
    """Return the default metadata mapping shared by document fixtures.

    The timestamp is taken from the current UTC time when the fixture is
    first evaluated.
    """
    defaults: dict[str, object] = dict(
        source_url="https://example.com/article",
        timestamp=datetime.now(UTC),
        content_type="text/plain",
        word_count=100,
        char_count=500,
    )
    # Cast through ``object`` to present the plain dict as DocumentMetadata.
    return cast(DocumentMetadata, cast(object, defaults))
@pytest.fixture(scope="module")
def document_factory(base_metadata: DocumentMetadata) -> DocumentFactory:
    """Return a callable that builds ``Document`` models from ``base_metadata``."""

    def _make_document(
        *,
        content: str,
        metadata_updates: dict[str, object] | None = None,
    ) -> Document:
        """Create a WEB-sourced document; ``metadata_updates`` override the base values."""
        overrides = metadata_updates or {}
        # Work on a copy so the base metadata is never mutated.
        merged = dict(base_metadata)
        merged.update(overrides)
        metadata = cast(DocumentMetadata, cast(object, merged))
        return Document(content=content, metadata=metadata, source=IngestionSource.WEB)

    return _make_document
@pytest.fixture(scope="session")
def vector_config_factory() -> VectorConfigFactory:
    """Return a callable that builds ``VectorConfig`` instances for tests."""

    def _build(*, model: str, dimension: int, endpoint: str) -> VectorConfig:
        """Wrap ``endpoint`` in ``HttpUrl`` and assemble the config."""
        embedding_url = HttpUrl(endpoint)
        return VectorConfig(
            model=model,
            dimension=dimension,
            embedding_endpoint=embedding_url,
        )

    return _build
@pytest.fixture(scope="session")
def storage_config() -> StorageConfig:
    """Return the default storage configuration (WEAVIATE backend at storage.local)."""
    endpoint = HttpUrl("http://storage.local")
    return StorageConfig(
        backend=StorageBackend.WEAVIATE,
        endpoint=endpoint,
        collection_name="documents",
    )
@pytest.fixture(scope="session")
def r2r_storage_config() -> StorageConfig:
    """Return the storage configuration for R2R tests (R2R backend at r2r.local)."""
    endpoint = HttpUrl("http://r2r.local")
    return StorageConfig(
        backend=StorageBackend.R2R,
        endpoint=endpoint,
        collection_name="documents",
    )
@pytest.fixture(scope="session")
def openwebui_spec() -> OpenAPISpec:
    """Load the OpenWebUI OpenAPI specification from ``chat.json`` under PROJECT_ROOT."""
    spec_path = PROJECT_ROOT / "chat.json"
    return OpenAPISpec.from_file(spec_path)
@pytest.fixture(scope="session")
def r2r_spec() -> OpenAPISpec:
    """Load the R2R OpenAPI specification from ``r2r.json`` under PROJECT_ROOT."""
    spec_path = PROJECT_ROOT / "r2r.json"
    return OpenAPISpec.from_file(spec_path)
@pytest.fixture(scope="session")
def firecrawl_spec() -> OpenAPISpec:
    """Load the Firecrawl OpenAPI specification from ``firecrawl.json`` under PROJECT_ROOT."""
    spec_path = PROJECT_ROOT / "firecrawl.json"
    return OpenAPISpec.from_file(spec_path)
@pytest.fixture(scope="function")
def httpx_stub(monkeypatch: pytest.MonkeyPatch) -> HttpxStub:
    """Swap ``httpx.AsyncClient`` for a factory producing ``AsyncClientStub`` objects.

    Every client created through the patched name shares the returned stub's
    response queue and request log, and is appended to ``stub.clients``.
    ``LLM_API_KEY`` and ``OPENAI_API_KEY`` are also removed from the
    environment for the duration of the test.
    """
    stub = HttpxStub()

    def _client_factory(**kwargs: object) -> AsyncClientStub:
        """Build a stub client from the keyword arguments given to ``AsyncClient``."""
        header_arg = cast(dict[str, str] | None, kwargs.get("headers"))
        base_url_arg = cast(str | None, kwargs.get("base_url"))
        fake_client = AsyncClientStub(
            responses=stub.responses,
            requests=stub.requests,
            timeout=kwargs.get("timeout"),
            headers=header_arg,
            base_url=base_url_arg,
            owner=stub,
        )
        stub.clients.append(fake_client)
        return fake_client

    monkeypatch.setattr(httpx, "AsyncClient", _client_factory)
    for env_name in ("LLM_API_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(env_name, raising=False)
    return stub
@pytest.fixture(scope="function")
def openwebui_service(
    httpx_stub: HttpxStub,
    openwebui_spec: OpenAPISpec,
    storage_config: StorageConfig,
) -> OpenWebUIMockService:
    """Create an OpenWebUI mock at the ``storage_config`` endpoint and register it."""
    endpoint = str(storage_config.endpoint)
    mock = OpenWebUIMockService(base_url=endpoint, spec=openwebui_spec)
    # Registration lets the HTTP stub route matching requests to the mock.
    httpx_stub.register_service(mock.base_url, mock)
    return mock
@pytest.fixture(scope="function")
def r2r_service(
    httpx_stub: HttpxStub,
    r2r_spec: OpenAPISpec,
    r2r_storage_config: StorageConfig,
) -> R2RMockService:
    """Create an R2R mock at the ``r2r_storage_config`` endpoint and register it."""
    endpoint = str(r2r_storage_config.endpoint)
    mock = R2RMockService(base_url=endpoint, spec=r2r_spec)
    # Registration lets the HTTP stub route matching requests to the mock.
    httpx_stub.register_service(mock.base_url, mock)
    return mock
@pytest.fixture(scope="function")
def firecrawl_service(
    httpx_stub: HttpxStub,
    firecrawl_spec: OpenAPISpec,
) -> FirecrawlMockService:
    """Create a Firecrawl mock at ``http://crawl.lab:30002`` and register it."""
    mock = FirecrawlMockService(base_url="http://crawl.lab:30002", spec=firecrawl_spec)
    # Registration lets the HTTP stub route matching requests to the mock.
    httpx_stub.register_service(mock.base_url, mock)
    return mock
@pytest.fixture(scope="function")
def firecrawl_client_stub(
    monkeypatch: pytest.MonkeyPatch,
    firecrawl_service: FirecrawlMockService,
) -> object:
    """Replace ``AsyncFirecrawl`` in the ingestor module with a mock-backed stub.

    Returns:
        The stub class itself (not an instance).
    """

    class AsyncFirecrawlStub:
        """Stand-in client that answers from the ``firecrawl_service`` mock."""

        _service: FirecrawlMockService

        def __init__(self, *args: object, **kwargs: object) -> None:
            # Constructor arguments are accepted and ignored.
            self._service = firecrawl_service

        async def map(self, url: str, limit: int | None = None, **_: object) -> SimpleNamespace:
            """Return a namespace with ``success`` and a list of ``links`` (each with ``url``)."""
            response = cast(MockResponseData, self._service.map_response(url, limit))
            entries = cast(list[MockResponseData], response.get("links", []))
            link_objects = [
                SimpleNamespace(url=cast(str, entry.get("url", ""))) for entry in entries
            ]
            succeeded = response.get("success", True)
            return SimpleNamespace(success=succeeded, links=link_objects)

        async def scrape(
            self,
            url: str,
            formats: list[str] | None = None,
            **_: object,
        ) -> SimpleNamespace:
            """Return a namespace exposing the ``data`` section of the mock scrape response."""
            response = cast(MockResponseData, self._service.scrape_response(url, formats))
            body = cast(MockResponseData, response.get("data", {}))
            meta_fields = cast(MockResponseData, body.get("metadata", {}))
            # Metadata keys become attributes of a nested namespace.
            meta_ns = SimpleNamespace(**meta_fields)
            return SimpleNamespace(
                markdown=body.get("markdown"),
                html=body.get("html"),
                rawHtml=body.get("rawHtml"),
                links=body.get("links", []),
                metadata=meta_ns,
            )

        async def close(self) -> None:
            """Do nothing; there is nothing to release."""
            return None

        async def __aenter__(self) -> AsyncFirecrawlStub:
            return self

        async def __aexit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: object | None,
        ) -> None:
            return None

    patch_target = "ingest_pipeline.ingestors.firecrawl.AsyncFirecrawl"
    monkeypatch.setattr(patch_target, AsyncFirecrawlStub)
    return AsyncFirecrawlStub
@pytest.fixture(scope="function")
def embedding_payload_factory() -> EmbeddingPayloadFactory:
    """Return a callable that builds an embedding response of a given dimension."""

    def _build(*, dimension: int) -> EmbeddingResponse:
        """Produce a single embedding whose values are 0.0, 1.0, ..., dimension - 1."""
        values = list(map(float, range(dimension)))
        entry: EmbeddingData = {"embedding": values}
        return {"data": [entry]}

    return _build