Files
rag-manager/tests/unit/storage/test_r2r_helpers.py
2025-09-21 10:25:54 +00:00

392 lines
14 KiB
Python

from __future__ import annotations
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, Self
import httpx
import pytest
from ingest_pipeline.core.models import StorageConfig
from ingest_pipeline.storage.r2r.storage import (
R2RStorage,
_as_datetime,
_as_int,
_as_mapping,
_as_sequence,
_extract_id,
)
@pytest.fixture
def r2r_client_stub(
    monkeypatch: pytest.MonkeyPatch,
    r2r_service,
) -> object:
    """Replace the R2R SDK clients with in-memory stubs backed by ``r2r_service``.

    Patches ``AsyncClient``, ``R2RAsyncClient`` and ``R2RException`` inside
    ``ingest_pipeline.storage.r2r.storage`` so storage tests run without a
    live R2R server.  Returns the ``DummyClient`` instance installed as the
    ``R2RAsyncClient`` replacement.
    """

    class DummyR2RException(Exception):
        # Mirrors the real R2RException surface: message plus optional HTTP status.
        def __init__(self, message: str, status_code: int | None = None) -> None:
            super().__init__(message)
            self.status_code = status_code

    class MockResponse:
        # Minimal httpx-style response wrapper around canned JSON data.
        def __init__(self, json_data: dict[str, Any], status_code: int = 200) -> None:
            self._json_data = json_data
            self.status_code = status_code

        def json(self) -> dict[str, Any]:
            return self._json_data

        def raise_for_status(self) -> None:
            if self.status_code >= 400:
                # Create minimal mock request and response for HTTPStatusError
                mock_request = httpx.Request("GET", "http://test.local")
                mock_response = httpx.Response(
                    status_code=self.status_code,
                    request=mock_request,
                )
                raise httpx.HTTPStatusError(
                    "HTTP error", request=mock_request, response=mock_response
                )

    class MockAsyncClient:
        # Stand-in for httpx.AsyncClient; routes /v3 REST calls to r2r_service.
        def __init__(self, service: Any) -> None:
            self._service = service

        async def get(self, url: str) -> MockResponse:
            if "/v3/collections" in url:
                # Return existing collections
                collections = []
                for collection_id, collection_data in self._service._collections.items():
                    collections.append(
                        {
                            "id": collection_id,
                            "name": collection_data["name"],
                            "description": collection_data.get("description", ""),
                        }
                    )
                return MockResponse({"results": collections})
            return MockResponse({})

        async def post(
            self,
            url: str,
            *,
            json: dict[str, Any] | None = None,
            files: dict[str, Any] | None = None,
        ) -> MockResponse:
            if "/v3/collections" in url and json:
                return self._handle_collection_creation(json)
            if "/v3/documents" in url and files:
                return self._handle_document_creation(files)
            return MockResponse({})

        def _handle_collection_creation(self, json: dict[str, Any]) -> MockResponse:
            """Handle collection creation via POST."""
            new_collection_id = f"col-{len(self._service._collections) + 1}"
            self._service.create_collection(
                name=json["name"],
                collection_id=new_collection_id,
                description=json.get("description", ""),
            )
            return MockResponse(
                {
                    "results": {
                        "id": new_collection_id,
                        "name": json["name"],
                        "description": json.get("description", ""),
                    }
                }
            )

        def _handle_document_creation(self, files: dict[str, Any]) -> MockResponse:
            """Handle document creation via POST with files."""
            document_id = self._extract_document_id(files)
            content = self._extract_content(files)
            metadata = self._extract_metadata(files)
            # Store document in mock service
            document_data = {
                "id": document_id,
                "content": content,
                "metadata": metadata,
            }
            self._service._documents[document_id] = document_data
            # Update collection document count if needed
            self._update_collection_document_count(files)
            return MockResponse(
                {
                    "results": {
                        "document_id": document_id,
                        "message": "Document created successfully",
                    }
                }
            )

        def _extract_document_id(self, files: dict[str, Any]) -> str:
            """Extract document ID from files."""
            # Multipart entries look like (filename, value) tuples; index 1 is the value.
            return files.get("id", (None, f"doc-{len(self._service._documents) + 1}"))[1]

        def _extract_content(self, files: dict[str, Any]) -> str:
            """Extract content from files."""
            return files.get("raw_text", (None, ""))[1]

        def _extract_metadata(self, files: dict[str, Any]) -> dict[str, Any]:
            """Extract and parse metadata from files."""
            import json as json_lib

            metadata_str = files.get("metadata", (None, "{}"))[1]
            return json_lib.loads(metadata_str) if metadata_str else {}

        def _update_collection_document_count(self, files: dict[str, Any]) -> None:
            """Update collection document count if collection_ids present."""
            if not (collection_ids := files.get("collection_ids")):
                return
            collection_id = self._parse_collection_ids(collection_ids)
            if collection_id and collection_id in self._service._collections:
                total_docs = len(self._service._documents)
                self._service._collections[collection_id]["document_count"] = total_docs

        def _parse_collection_ids(self, collection_ids: Any) -> str | None:
            """Parse collection IDs from files entry.

            Accepts a (filename, payload) multipart tuple or a bare payload;
            the payload may be a JSON string or an already-decoded list whose
            first element is either a mapping with an ``"id"`` key or a plain
            string.  Returns None for anything unparseable.
            """
            import json as json_lib

            collection_ids_str = (
                collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids
            )
            try:
                collection_list = (
                    json_lib.loads(collection_ids_str)
                    if isinstance(collection_ids_str, str)
                    else collection_ids_str
                )
                if isinstance(collection_list, list) and collection_list:
                    first_collection = collection_list[0]
                    if isinstance(first_collection, dict) and "id" in first_collection:
                        return first_collection["id"]
                    if isinstance(first_collection, str):
                        return first_collection
            except (json_lib.JSONDecodeError, TypeError, KeyError):
                pass
            return None

        async def aclose(self) -> None:
            return None

        async def __aenter__(self) -> Self:
            return self

        async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            return None

    class DocumentsAPI:
        # Stub for the SDK's ``client.documents`` namespace.
        def __init__(self, service: Any) -> None:
            self._service = service

        async def retrieve(self, document_id: str) -> dict[str, Any]:
            document = self._service.get_document(document_id)
            if document is None:
                raise DummyR2RException("Not found", status_code=404)
            return {"results": document}

        async def delete(self, document_id: str) -> dict[str, Any]:
            if not self._service.delete_document(document_id):
                raise DummyR2RException("Not found", status_code=404)
            return {"results": {"success": True}}

        async def append_metadata(self, id: str, metadata: list[dict[str, Any]]) -> dict[str, Any]:
            document = self._service.append_document_metadata(id, metadata)
            if document is None:
                raise DummyR2RException("Not found", status_code=404)
            return {"results": document}

    class RetrievalAPI:
        # Stub for the SDK's ``client.retrieval`` namespace.
        def __init__(self, service: Any) -> None:
            self._service = service

        async def search(self, query: str, search_settings: Mapping[str, Any]) -> dict[str, Any]:
            # Every stored document matches with a fixed score of 1.0.
            results = [{"document_id": doc_id, "score": 1.0} for doc_id in self._service._documents]
            return {"results": results}

    class DummyClient:
        # Top-level stand-in for R2RAsyncClient.
        def __init__(self, service: Any) -> None:
            self.documents = DocumentsAPI(service)
            self.retrieval = RetrievalAPI(service)

        async def aclose(self) -> None:
            return None

        async def close(self) -> None:
            return None

    # Mock the AsyncClient that R2RStorage uses internally
    mock_async_client = MockAsyncClient(r2r_service)
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.AsyncClient",
        lambda **kwargs: mock_async_client,
    )
    client = DummyClient(r2r_service)
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.R2RAsyncClient",
        lambda endpoint: client,
    )
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.R2RException",
        DummyR2RException,
    )
    return client
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        pytest.param({"a": 1}, {"a": 1}, id="mapping"),
        pytest.param(SimpleNamespace(a=2), {"a": 2}, id="namespace"),
        pytest.param(5, {}, id="other"),
    ],
)
def test_as_mapping_normalizes(value, expected) -> None:
    """Mappings and attribute objects become dicts; everything else is empty."""
    normalized = _as_mapping(value)
    assert normalized == expected
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        pytest.param([1, 2], (1, 2), id="list"),
        pytest.param((3, 4), (3, 4), id="tuple"),
        pytest.param("ab", ("a", "b"), id="string"),
        pytest.param(7, (), id="non-iterable"),
    ],
)
def test_as_sequence_coerces_iterables(value, expected) -> None:
    """Iterables are materialized as tuples; non-iterables collapse to ()."""
    coerced = _as_sequence(value)
    assert coerced == expected
@pytest.mark.parametrize(
    ("source", "fallback", "expected"),
    [
        pytest.param({"id": "abc"}, "x", "abc", id="mapping"),
        pytest.param(SimpleNamespace(id=123), "x", "123", id="attribute"),
        pytest.param({}, "fallback", "fallback", id="fallback"),
    ],
)
def test_extract_id_falls_back(source, fallback, expected) -> None:
    """An embedded identifier wins; otherwise the supplied fallback is used."""
    extracted = _extract_id(source, fallback)
    assert extracted == expected
@pytest.mark.parametrize(
    ("value", "expected_year"),
    [
        pytest.param(datetime(2024, 1, 1, tzinfo=UTC), 2024, id="datetime"),
        pytest.param("2024-02-01T00:00:00+00:00", 2024, id="iso"),
        # None means "expect the current year", resolved inside the test body.
        # The original computed datetime.now(UTC).year at collection time,
        # which could go stale (and fail) across a year boundary.
        pytest.param("invalid", None, id="fallback"),
    ],
)
def test_as_datetime_recognizes_formats(value, expected_year) -> None:
    """Produce timezone-aware datetime objects.

    Unparseable inputs fall back to "now"; the fallback case therefore
    resolves its expected year at assertion time rather than relying on a
    value frozen when the parametrize decorator was evaluated.
    """
    if expected_year is None:
        expected_year = datetime.now(UTC).year
    assert _as_datetime(value).year == expected_year
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        pytest.param(True, 1, id="bool"),
        pytest.param(5, 5, id="int"),
        pytest.param(3.9, 3, id="float"),
        pytest.param("7", 7, id="string"),
        pytest.param("8.2", 8, id="float-string"),
        pytest.param("bad", 2, id="default"),
    ],
)
def test_as_int_handles_numeric_coercions(value, expected) -> None:
    """Numeric-ish inputs coerce to int; unparseable ones yield the default."""
    coerced = _as_int(value, default=2)
    assert coerced == expected
@pytest.mark.asyncio
async def test_ensure_collection_finds_existing(
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Return collection identifier when already present.

    Pre-seeds the mock service with a collection matching the configured
    name and checks that ``_ensure_collection`` resolves it instead of
    creating a new one.
    """
    r2r_service.create_collection(
        name=r2r_storage_config.collection_name,
        collection_id="col-1",
    )
    storage = R2RStorage(r2r_storage_config)
    try:
        collection_id = await storage._ensure_collection(r2r_storage_config.collection_name)
        assert collection_id == "col-1"
        assert storage.default_collection_id == "col-1"
    finally:
        # Close the client even when an assertion fails; the original only
        # closed it on success, leaking the stub client on test failure.
        await storage.client.aclose()  # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_ensure_collection_creates_when_missing(
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Create collection via POST when absent.

    The mock service starts without an "alternate" collection; after
    ``_ensure_collection`` runs, the collection must exist and carry the
    returned identifier.
    """
    storage = R2RStorage(r2r_storage_config)
    try:
        collection_id = await storage._ensure_collection("alternate")
        assert collection_id is not None
        located = r2r_service.find_collection_by_name("alternate")
        assert located is not None
        identifier, _ = located
        assert identifier == collection_id
    finally:
        # Close the client even when an assertion fails; the original only
        # closed it on success, leaking the stub client on test failure.
        await storage.client.aclose()  # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_store_batch_creates_documents(
    document_factory,
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Store documents and persist them via the R2R mock service.

    Verifies that each stored id maps back to a document in the mock
    service with its metadata intact, and that the collection's document
    count reflects both uploads.
    """
    storage = R2RStorage(r2r_storage_config)
    documents = [
        document_factory(content="first document", metadata_updates={"title": "First"}),
        document_factory(content="second document", metadata_updates={"title": "Second"}),
    ]
    try:
        stored_ids = await storage.store_batch(documents)
        assert len(stored_ids) == 2
        # strict=True: ids and inputs must pair one-to-one; the original's
        # strict=False would silently truncate on a length mismatch.
        for doc_id, original in zip(stored_ids, documents, strict=True):
            stored = r2r_service.get_document(doc_id)
            assert stored is not None
            assert stored["metadata"]["source_url"] == original.metadata["source_url"]
        collection = r2r_service.find_collection_by_name(r2r_storage_config.collection_name)
        assert collection is not None
        _, collection_payload = collection
        assert collection_payload["document_count"] == 2
    finally:
        # Close the client even when an assertion fails; the original only
        # closed it on success, leaking the stub client on test failure.
        await storage.client.aclose()  # type: ignore[attr-defined]