390 lines
13 KiB
Python
390 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Mapping
|
|
from datetime import UTC, datetime
|
|
from types import SimpleNamespace
|
|
from typing import Any, Self
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from ingest_pipeline.core.models import StorageConfig
|
|
from ingest_pipeline.storage.r2r.storage import (
|
|
R2RStorage,
|
|
_as_datetime,
|
|
_as_int,
|
|
_as_mapping,
|
|
_as_sequence,
|
|
_extract_id,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def r2r_client_stub(
    monkeypatch: pytest.MonkeyPatch,
    r2r_service,
) -> object:
    """Replace R2RStorage's network clients with in-memory fakes.

    Patches ``AsyncClient``, ``R2RAsyncClient`` and ``R2RException`` in
    ``ingest_pipeline.storage.r2r.storage`` so that both REST calls and SDK
    calls are served from the ``r2r_service`` fixture's ``_collections`` and
    ``_documents`` state.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to install the fakes.
        r2r_service: In-memory R2R stand-in shared by both fake clients.

    Returns:
        The fake SDK client installed in place of ``R2RAsyncClient``.
    """
    # Aliased because the fake ``post`` mirrors httpx and takes a ``json``
    # keyword argument, which would shadow the module inside those methods.
    import json as json_lib

    class DummyR2RException(Exception):
        """Stand-in for ``R2RException`` carrying an optional HTTP status code."""

        def __init__(self, message: str, status_code: int | None = None) -> None:
            super().__init__(message)
            self.status_code = status_code

    class MockResponse:
        """Minimal response exposing ``status_code``, ``json()`` and ``raise_for_status()``."""

        def __init__(self, json_data: dict[str, Any], status_code: int = 200) -> None:
            self._json_data = json_data
            self.status_code = status_code

        def json(self) -> dict[str, Any]:
            return self._json_data

        def raise_for_status(self) -> None:
            if self.status_code >= 400:
                # HTTPStatusError needs concrete request/response objects.
                mock_request = httpx.Request("GET", "http://test.local")
                mock_response = httpx.Response(
                    status_code=self.status_code,
                    request=mock_request,
                )
                raise httpx.HTTPStatusError("HTTP error", request=mock_request, response=mock_response)

    class MockAsyncClient:
        """Fake ``httpx.AsyncClient`` that routes R2R REST calls to ``r2r_service``."""

        def __init__(self, service: Any) -> None:
            self._service = service
            # collection id -> ids of documents uploaded into it through this client
            self._collection_members: dict[str, set[str]] = {}

        async def get(self, url: str, **_kwargs: Any) -> MockResponse:
            """List collections for ``/v3/collections`` URLs; empty payload otherwise.

            Extra keyword arguments the real client accepts (``params``,
            ``headers``, ...) are ignored.
            """
            if "/v3/collections" in url:
                collections = [
                    {
                        "id": collection_id,
                        "name": collection_data["name"],
                        "description": collection_data.get("description", ""),
                    }
                    for collection_id, collection_data in self._service._collections.items()
                ]
                return MockResponse({"results": collections})
            return MockResponse({})

        async def post(
            self,
            url: str,
            *,
            json: dict[str, Any] | None = None,
            files: dict[str, Any] | None = None,
            **_kwargs: Any,
        ) -> MockResponse:
            """Create collections (JSON body) or documents (multipart ``files``).

            Any other POST returns an empty payload. Extra keyword arguments the
            real client accepts are ignored.
            """
            if "/v3/collections" in url and json:
                return self._handle_collection_creation(json)
            if "/v3/documents" in url and files:
                return self._handle_document_creation(files)
            return MockResponse({})

        def _next_collection_id(self) -> str:
            """Return the first free ``col-N`` id, starting at N = len(collections) + 1."""
            index = len(self._service._collections) + 1
            # Skip ids a test may already have registered by hand.
            while f"col-{index}" in self._service._collections:
                index += 1
            return f"col-{index}"

        def _handle_collection_creation(self, json: dict[str, Any]) -> MockResponse:
            """Create a collection in the service and echo it back."""
            new_collection_id = self._next_collection_id()
            description = json.get("description", "")
            self._service.create_collection(
                name=json["name"],
                collection_id=new_collection_id,
                description=description,
            )
            return MockResponse(
                {
                    "results": {
                        "id": new_collection_id,
                        "name": json["name"],
                        "description": description,
                    }
                }
            )

        def _handle_document_creation(self, files: dict[str, Any]) -> MockResponse:
            """Store an uploaded document and refresh its collection's document count."""
            document_id = self._extract_document_id(files)
            self._service._documents[document_id] = {
                "id": document_id,
                "content": self._extract_content(files),
                "metadata": self._extract_metadata(files),
            }
            self._update_collection_document_count(files, document_id)
            return MockResponse(
                {
                    "results": {
                        "document_id": document_id,
                        "message": "Document created successfully",
                    }
                }
            )

        def _extract_document_id(self, files: dict[str, Any]) -> str:
            """Return the uploaded ``id`` field, or generate ``doc-N`` when absent."""
            return files.get("id", (None, f"doc-{len(self._service._documents) + 1}"))[1]

        def _extract_content(self, files: dict[str, Any]) -> str:
            """Return the uploaded ``raw_text`` field, or ``""`` when absent."""
            return files.get("raw_text", (None, ""))[1]

        def _extract_metadata(self, files: dict[str, Any]) -> dict[str, Any]:
            """Decode the uploaded JSON ``metadata`` field (``{}`` when absent or empty)."""
            metadata_str = files.get("metadata", (None, "{}"))[1]
            return json_lib.loads(metadata_str) if metadata_str else {}

        def _update_collection_document_count(self, files: dict[str, Any], document_id: str) -> None:
            """Record collection membership and set ``document_count`` to its live documents.

            Only documents uploaded into that collection are counted, not every
            document held by the service.
            """
            if not (collection_ids := files.get("collection_ids")):
                return

            collection_id = self._parse_collection_ids(collection_ids)
            if not collection_id or collection_id not in self._service._collections:
                return

            members = self._collection_members.setdefault(collection_id, set())
            members.add(document_id)
            # Ignore members that have since been deleted from the service.
            live_count = sum(1 for member in members if member in self._service._documents)
            self._service._collections[collection_id]["document_count"] = live_count

        def _parse_collection_ids(self, collection_ids: Any) -> str | None:
            """Return the first collection id from a ``collection_ids`` files entry.

            Accepts a multipart ``(filename, value)`` tuple or a bare value. The
            value may be a JSON string or an already-decoded list of ids or
            ``{"id": ...}`` dicts. Returns ``None`` when nothing usable is found.
            """
            collection_ids_str = (
                collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids
            )
            try:
                collection_list = (
                    json_lib.loads(collection_ids_str)
                    if isinstance(collection_ids_str, str)
                    else collection_ids_str
                )
                if isinstance(collection_list, list) and collection_list:
                    first_collection = collection_list[0]
                    if isinstance(first_collection, dict) and "id" in first_collection:
                        return first_collection["id"]
                    if isinstance(first_collection, str):
                        return first_collection
            except (json_lib.JSONDecodeError, TypeError, KeyError):
                pass
            return None

        async def aclose(self) -> None:
            return None

        async def __aenter__(self) -> Self:
            return self

        async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            return None

    class DocumentsAPI:
        """Fake ``client.documents`` namespace backed by ``r2r_service``."""

        def __init__(self, service: Any) -> None:
            self._service = service

        async def retrieve(self, document_id: str) -> dict[str, Any]:
            document = self._service.get_document(document_id)
            if document is None:
                raise DummyR2RException("Not found", status_code=404)
            return {"results": document}

        async def delete(self, document_id: str) -> dict[str, Any]:
            if not self._service.delete_document(document_id):
                raise DummyR2RException("Not found", status_code=404)
            return {"results": {"success": True}}

        # ``id`` mirrors the SDK's keyword name, so it is kept despite shadowing the builtin.
        async def append_metadata(self, id: str, metadata: list[dict[str, Any]]) -> dict[str, Any]:
            document = self._service.append_document_metadata(id, metadata)
            if document is None:
                raise DummyR2RException("Not found", status_code=404)
            return {"results": document}

    class RetrievalAPI:
        """Fake ``client.retrieval`` namespace: every stored document matches with score 1.0."""

        def __init__(self, service: Any) -> None:
            self._service = service

        async def search(self, query: str, search_settings: Mapping[str, Any]) -> dict[str, Any]:
            results = [{"document_id": doc_id, "score": 1.0} for doc_id in self._service._documents]
            return {"results": results}

    class DummyClient:
        """Fake SDK client exposing ``documents`` and ``retrieval``."""

        def __init__(self, service: Any) -> None:
            self.documents = DocumentsAPI(service)
            self.retrieval = RetrievalAPI(service)

        async def aclose(self) -> None:
            return None

        async def close(self) -> None:
            return None

    # Mock the AsyncClient that R2RStorage uses internally
    mock_async_client = MockAsyncClient(r2r_service)
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.AsyncClient",
        lambda **kwargs: mock_async_client,
    )

    client = DummyClient(r2r_service)
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.R2RAsyncClient",
        lambda endpoint: client,
    )
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.R2RException",
        DummyR2RException,
    )
    return client
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        ({"a": 1}, {"a": 1}),
        (SimpleNamespace(a=2), {"a": 2}),
        (5, {}),
    ],
    ids=["mapping", "namespace", "other"],
)
def test_as_mapping_normalizes(value, expected) -> None:
    """Mappings pass through, attribute objects become dicts, anything else yields ``{}``."""
    normalized = _as_mapping(value)
    assert normalized == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        ([1, 2], (1, 2)),
        ((3, 4), (3, 4)),
        ("ab", ("a", "b")),
        (7, ()),
    ],
    ids=["list", "tuple", "string", "non-iterable"],
)
def test_as_sequence_coerces_iterables(value, expected) -> None:
    """Iterables (strings included) become tuples; non-iterables become ``()``."""
    coerced = _as_sequence(value)
    assert coerced == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("source", "fallback", "expected"),
    [
        ({"id": "abc"}, "x", "abc"),
        (SimpleNamespace(id=123), "x", "123"),
        ({}, "fallback", "fallback"),
    ],
    ids=["mapping", "attribute", "fallback"],
)
def test_extract_id_falls_back(source, fallback, expected) -> None:
    """Read ``id`` from a key or attribute (stringified); use the fallback when absent."""
    identifier = _extract_id(source, fallback)
    assert identifier == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("value", "expected_year"),
    [
        pytest.param(datetime(2024, 1, 1, tzinfo=UTC), 2024, id="datetime"),
        pytest.param("2024-02-01T00:00:00+00:00", 2024, id="iso"),
        # ``None`` means "the current UTC year", resolved when the test runs
        # rather than when the module is collected.
        pytest.param("invalid", None, id="fallback"),
    ],
)
def test_as_datetime_recognizes_formats(value, expected_year: int | None) -> None:
    """Produce timezone-aware datetime objects.

    Parsable inputs keep their own year. Unparsable input is expected to fall
    back to the current UTC time, so its year is checked against clock readings
    taken immediately before and after the call.
    """
    year_before = datetime.now(UTC).year
    result = _as_datetime(value)
    year_after = datetime.now(UTC).year

    assert result.tzinfo is not None
    if expected_year is None:
        # Bracketing the call keeps a run that straddles New Year's midnight from flaking.
        assert year_before <= result.year <= year_after
    else:
        assert result.year == expected_year
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (True, 1),
        (5, 5),
        (3.9, 3),
        ("7", 7),
        ("8.2", 8),
        ("bad", 2),
    ],
    ids=["bool", "int", "float", "string", "float-string", "default"],
)
def test_as_int_handles_numeric_coercions(value, expected) -> None:
    """Truncate numbers and numeric strings to int; unparsable input yields the default."""
    converted = _as_int(value, default=2)
    assert converted == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ensure_collection_finds_existing(
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Return collection identifier when already present.

    Also checks that the found id is recorded as the storage's default collection.
    """
    r2r_service.create_collection(
        name=r2r_storage_config.collection_name,
        collection_id="col-1",
    )
    storage = R2RStorage(r2r_storage_config)
    try:
        collection_id = await storage._ensure_collection(r2r_storage_config.collection_name)

        assert collection_id == "col-1"
        assert storage.default_collection_id == "col-1"
    finally:
        # Close the client even when an assertion above fails.
        await storage.client.aclose()  # type: ignore[attr-defined]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ensure_collection_creates_when_missing(
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Create collection via POST when absent.

    The returned id must match the collection the service now holds under that name.
    """
    storage = R2RStorage(r2r_storage_config)
    try:
        collection_id = await storage._ensure_collection("alternate")

        assert collection_id is not None
        located = r2r_service.find_collection_by_name("alternate")
        assert located is not None
        identifier, _ = located
        assert identifier == collection_id
    finally:
        # Close the client even when an assertion above fails.
        await storage.client.aclose()  # type: ignore[attr-defined]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_store_batch_creates_documents(
    document_factory,
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Store documents and persist them via the R2R mock service.

    Each returned id must resolve to a stored document with the original
    ``source_url``. The configured collection must report both documents.
    """
    storage = R2RStorage(r2r_storage_config)
    documents = [
        document_factory(content="first document", metadata_updates={"title": "First"}),
        document_factory(content="second document", metadata_updates={"title": "Second"}),
    ]
    try:
        stored_ids = await storage.store_batch(documents)

        assert len(stored_ids) == 2
        # strict=True: the lengths were just asserted equal, so any mismatch is a bug.
        for doc_id, original in zip(stored_ids, documents, strict=True):
            stored = r2r_service.get_document(doc_id)
            assert stored is not None
            assert stored["metadata"]["source_url"] == original.metadata["source_url"]

        collection = r2r_service.find_collection_by_name(r2r_storage_config.collection_name)
        assert collection is not None
        _, collection_payload = collection
        assert collection_payload["document_count"] == 2
    finally:
        # Close the client even when an assertion above fails.
        await storage.client.aclose()  # type: ignore[attr-defined]
|