commit 2d01228f24 (parent 37d1e434af)
2025-09-19 06:56:19 +00:00
120 changed files with 30801 additions and 1255 deletions

View File

@@ -7,7 +7,8 @@
"WebSearch",
"Bash(cat:*)",
"mcp__firecrawl__firecrawl_search",
"mcp__firecrawl__firecrawl_scrape"
"mcp__firecrawl__firecrawl_scrape",
"Bash(python:*)"
],
"deny": [],
"ask": []

100 .gitignore vendored Normal file
View File

@@ -0,0 +1,100 @@
# Environment files
.env
.env.local
.env.*.local
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Virtual environments
.venv/
.env/
venv/
ENV/
env/
# Package managers
.uv_cache/
.uv-cache/
.pip-cache/
pip-log.txt
pip-delete-this-directory.txt
# Testing
.tox/
.coverage
.coverage.*
.cache
.pytest_cache/
htmlcov/
.nox/
coverage.xml
*.cover
.hypothesis/
# Type checking
.mypy_cache/
.dmypy.json
dmypy.json
.pyre/
.pytype/
# Linting and formatting
.ruff_cache/
.flake8_cache/
.black_cache/
# Jupyter Notebook
.ipynb_checkpoints
# IDE and editors
.vscode/
.idea/
*.swp
*.swo
*~
# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Logs
logs/
*.log
log/
# Temporary files
*.tmp
*.temp
.tmp/
.temp/
# Project specific
repomix-output.xml
*.json.bak
chat.json

36 .vscode/settings.json vendored
View File

@@ -1,7 +1,37 @@
{
"chatgpt.openOnStartup": true,
"python.languageServer": "None",
"python.analysis.typeCheckingMode": "off",
"python.defaultInterpreterPath": "./.venv/bin/python",
"python.terminal.activateEnvironment": true
"python.terminal.activateEnvironment": true,
"python.linting.enabled": true,
"python.linting.mypyEnabled": true,
"python.linting.mypyPath": "./.venv/bin/mypy",
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": false,
"python.analysis.typeCheckingMode": "basic",
"python.analysis.autoImportCompletions": true,
"python.analysis.stubPath": "./.venv/lib/python3.12/site-packages",
"basedpyright.analysis.typeCheckingMode": "standard",
"basedpyright.analysis.autoSearchPaths": true,
"basedpyright.analysis.autoImportCompletions": true,
"basedpyright.analysis.diagnosticMode": "workspace",
"basedpyright.analysis.stubPath": "./.venv/lib/python3.12/site-packages",
"basedpyright.analysis.extraPaths": [
"./ingest_pipeline",
"./.venv/lib/python3.12/site-packages"
],
"pyright.analysis.typeCheckingMode": "standard",
"pyright.analysis.autoSearchPaths": true,
"pyright.analysis.autoImportCompletions": true,
"pyright.analysis.diagnosticMode": "workspace",
"pyright.analysis.stubPath": "./.venv/lib/python3.12/site-packages",
"pyright.analysis.extraPaths": [
"./ingest_pipeline",
"./.venv/lib/python3.12/site-packages"
],
"files.exclude": {
"**/__pycache__": true,
"**/.pytest_cache": true,
"**/node_modules": true,
".mypy_cache": true
}
}

View File

@@ -6,11 +6,26 @@
"**/__pycache__",
"**/.pytest_cache",
"**/node_modules",
".venv"
".venv",
"build",
"dist"
],
"pythonPath": "./.venv/bin/python",
"pythonVersion": "3.12",
"venvPath": ".",
"venv": ".venv",
"typeCheckingMode": "standard",
"useLibraryCodeForTypes": true,
"stubPath": "./.venv/lib/python3.12/site-packages",
"executionEnvironments": [
{
"root": ".",
"pythonVersion": "3.12",
"extraPaths": [
"./ingest_pipeline",
"./.venv/lib/python3.12/site-packages"
]
}
],
"reportCallInDefaultInitializer": "none",
"reportUnknownVariableType": "warning",
"reportUnknownMemberType": "warning",
@@ -21,9 +36,12 @@
"reportUnannotatedClassAttribute": "warning",
"reportMissingTypeStubs": "none",
"reportMissingModuleSource": "none",
"reportImportCycles": "none",
"reportAttributeAccessIssue": "warning",
"reportAny": "warning",
"reportUnusedCallResult": "none",
"reportUnnecessaryIsInstance": "none",
"reportImplicitOverride": "none",
"reportDeprecated": "warning"
"reportDeprecated": "warning",
"analyzeUnannotatedFunctions": true
}

23483 chat.json

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,62 @@
"""Prefect Automations for ingestion pipeline monitoring and management."""
# Automation configurations as YAML-ready dictionaries
AUTOMATION_TEMPLATES = {
"cancel_long_running": """
name: Cancel Long Running Ingestion Flows
description: Cancels ingestion flows running longer than 30 minutes
trigger:
type: event
posture: Proactive
expect: [prefect.flow-run.Running]
match_related:
prefect.resource.role: flow
prefect.resource.name: ingestion_pipeline
threshold: 1
within: 1800
actions:
- type: cancel-flow-run
source: inferred
enabled: true
""",
"retry_failed": """
name: Retry Failed Ingestion Flows
description: Retries failed ingestion flows with original parameters
trigger:
type: event
posture: Reactive
expect: [prefect.flow-run.Failed]
match_related:
prefect.resource.role: flow
prefect.resource.name: ingestion_pipeline
threshold: 1
within: 0
actions:
- type: run-deployment
source: inferred
parameters:
validate_first: false
enabled: true
""",
"resource_monitoring": """
name: Manage Work Pool Based on Resources
description: Pauses work pool when system resources are constrained
trigger:
type: event
posture: Reactive
expect: [system.resource.high_usage]
threshold: 1
within: 120
actions:
- type: pause-work-pool
work_pool_name: default
enabled: true
""",
}
def get_automation_yaml_templates() -> dict[str, str]:
"""Get automation templates as YAML strings."""
return AUTOMATION_TEMPLATES.copy()
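The templates above are plain YAML strings, so a consumer has to parse them before doing anything programmatic with them. A minimal sketch, assuming PyYAML is installed and that this runs in the same module as AUTOMATION_TEMPLATES (the helper name is illustrative, not part of the commit):

```python
import yaml  # PyYAML, assumed to be available in the environment

def load_automation(name: str) -> dict[str, object]:
    """Parse one of the AUTOMATION_TEMPLATES strings into a plain mapping."""
    parsed = yaml.safe_load(AUTOMATION_TEMPLATES[name])
    if not isinstance(parsed, dict):
        raise ValueError(f"Template {name!r} did not parse to a mapping")
    return parsed

cancel = load_automation("cancel_long_running")
print(cancel["name"], cancel["enabled"])  # Cancel Long Running Ingestion Flows True
```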

View File

@@ -4,13 +4,19 @@ import asyncio
from typing import Annotated
import typer
from pydantic import SecretStr
from rich.console import Console
from rich.panel import Panel
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
from rich.table import Table
from ..config import configure_prefect, get_settings
from ..core.models import IngestionResult, IngestionSource, StorageBackend
from ..core.models import (
IngestionResult,
IngestionSource,
StorageBackend,
StorageConfig,
)
from ..flows.ingestion import create_ingestion_flow
from ..flows.scheduler import create_scheduled_deployment, serve_deployments
@@ -439,6 +445,22 @@ def search(
asyncio.run(run_search(query, collection, backend.value, limit))
@app.command(name="blocks")
def blocks_command() -> None:
"""🧩 List and manage Prefect Blocks."""
console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks")
console.print("Use 'prefect block ls' to list available blocks")
@app.command(name="variables")
def variables_command() -> None:
"""📊 Manage Prefect Variables."""
console.print("[bold cyan]📊 Prefect Variables Management[/bold cyan]")
console.print("Use 'prefect variable set VARIABLE_NAME value' to set variables")
console.print("Use 'prefect variable ls' to list variables")
async def run_ingestion(
url: str,
source_type: IngestionSource,
@@ -470,8 +492,8 @@ async def run_list_collections() -> None:
"""
List collections across storage backends.
"""
from ..config import configure_prefect, get_settings
from ..core.models import StorageBackend, StorageConfig
from ..config import get_settings
from ..core.models import StorageBackend
from ..storage.openwebui import OpenWebUIStorage
from ..storage.weaviate import WeaviateStorage
@@ -485,7 +507,7 @@ async def run_list_collections() -> None:
weaviate_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=settings.weaviate_endpoint,
api_key=settings.weaviate_api_key,
api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
collection_name="default",
)
weaviate = WeaviateStorage(weaviate_config)
@@ -494,7 +516,8 @@ async def run_list_collections() -> None:
overview = await weaviate.describe_collections()
for item in overview:
name = str(item.get("name", "Unknown"))
count = int(item.get("count", 0))
count_val = item.get("count", 0)
count = int(count_val) if isinstance(count_val, (int, str)) else 0
weaviate_collections.append((name, count))
except Exception as e:
console.print(f"❌ [red]Weaviate connection failed: {e}[/red]")
@@ -505,7 +528,7 @@ async def run_list_collections() -> None:
openwebui_config = StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint=settings.openwebui_endpoint,
api_key=settings.openwebui_api_key,
api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None,
collection_name="default",
)
openwebui = OpenWebUIStorage(openwebui_config)
@@ -514,7 +537,8 @@ async def run_list_collections() -> None:
overview = await openwebui.describe_collections()
for item in overview:
name = str(item.get("name", "Unknown"))
count = int(item.get("count", 0))
count_val = item.get("count", 0)
count = int(count_val) if isinstance(count_val, (int, str)) else 0
openwebui_collections.append((name, count))
except Exception as e:
console.print(f"❌ [red]OpenWebUI connection failed: {e}[/red]")
@@ -551,8 +575,8 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in
"""
Search across collections.
"""
from ..config import configure_prefect, get_settings
from ..core.models import StorageBackend, StorageConfig
from ..config import get_settings
from ..core.models import StorageBackend
from ..storage.weaviate import WeaviateStorage
settings = get_settings()
@@ -569,7 +593,7 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in
weaviate_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=settings.weaviate_endpoint,
api_key=settings.weaviate_api_key,
api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
collection_name=collection or "default",
)
weaviate = WeaviateStorage(weaviate_config)
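The CLI hunks above repeat two small conversions: coercing loosely typed count values to int before display, and wrapping optional API keys in SecretStr before building a StorageConfig. A hypothetical helper pair that would centralize both conversions (the names are mine, not part of the CLI):

```python
from pydantic import SecretStr

def to_count(value: object) -> int:
    """Coerce a backend-reported count to int, falling back to 0 for unexpected types."""
    try:
        return int(value) if isinstance(value, (int, str)) else 0
    except ValueError:
        return 0

def to_secret(key: str | None) -> SecretStr | None:
    """Wrap an optional API key so it never leaks into logs or reprs."""
    return SecretStr(key) if key is not None else None

print(to_count("42"), to_count(None), to_secret("abc"))  # 42 0 **********
```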

View File

@@ -17,7 +17,8 @@ from textual.timer import Timer
from ...storage.base import BaseStorage
from ...storage.openwebui import OpenWebUIStorage
from ...storage.weaviate import WeaviateStorage
from .screens import CollectionOverviewScreen, HelpScreen
from .screens.dashboard import CollectionOverviewScreen
from .screens.help import HelpScreen
from .styles import TUI_CSS
from .utils.storage_manager import StorageManager
@@ -35,6 +36,8 @@ class CollectionManagementApp(App[None]):
"""Enhanced modern Textual application with comprehensive keyboard navigation."""
CSS: ClassVar[str] = TUI_CSS
TITLE = "Collection Management"
SUB_TITLE = "Document Ingestion Pipeline"
def safe_notify(
self,
@@ -89,8 +92,8 @@ class CollectionManagementApp(App[None]):
self.weaviate = weaviate
self.openwebui = openwebui
self.r2r = r2r
self.title: str = ""
self.sub_title: str = ""
# Remove direct assignment to read-only title properties
# These should be set through class attributes or overridden methods
self.log_queue = log_queue
self._log_formatter = log_formatter or logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
@@ -146,7 +149,7 @@ class CollectionManagementApp(App[None]):
if drained and self._log_viewer is not None:
self._log_viewer.append_logs(drained)
def attach_log_viewer(self, viewer: "LogViewerScreen") -> None:
def attach_log_viewer(self, viewer: LogViewerScreen) -> None:
"""Register an active log viewer and hydrate it with existing entries."""
self._log_viewer = viewer
viewer.replace_logs(list(self._log_buffer))
@@ -154,7 +157,7 @@ class CollectionManagementApp(App[None]):
# Drain once more to deliver any entries gathered between instantiation and mount
self._drain_log_queue()
def detach_log_viewer(self, viewer: "LogViewerScreen") -> None:
def detach_log_viewer(self, viewer: LogViewerScreen) -> None:
"""Remove the current log viewer when it is dismissed."""
if self._log_viewer is viewer:
self._log_viewer = None
@@ -273,7 +276,7 @@ class CollectionManagementApp(App[None]):
if len(self.screen_stack) > 1: # Don't close the main screen
_ = self.pop_screen()
else:
_ = self.notify("Cannot close main screen. Use Q to quit.", severity="warning")
self.notify("Cannot close main screen. Use Q to quit.", severity="warning")
def action_dashboard_tab(self) -> None:
"""Switch to dashboard tab in current screen."""
@@ -305,7 +308,7 @@ class CollectionManagementApp(App[None]):
_ = event.prevent_default()
elif event.key == "ctrl+alt+r":
# Force refresh all connections
_ = self.notify("🔄 Refreshing all connections...", severity="information")
self.notify("🔄 Refreshing all connections...", severity="information")
# This could trigger a full reinit if needed
_ = event.prevent_default()
# No else clause needed - just handle our events

View File

@@ -163,12 +163,12 @@ class CollapsibleSidebar(Container):
else:
_ = self.remove_class("collapsed")
def expand(self) -> None:
def expand_sidebar(self) -> None:
"""Expand sidebar."""
if self.collapsed:
self.toggle()
def collapse(self) -> None:
def collapse_sidebar(self) -> None:
"""Collapse sidebar."""
if not self.collapsed:
self.toggle()

View File

@@ -2,14 +2,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar
from textual import work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import ModalScreen, Screen
from textual.widget import Widget
from textual.widgets import Button, DataTable, LoadingIndicator, Static
from typing_extensions import override
@@ -19,16 +19,25 @@ if TYPE_CHECKING:
T = TypeVar("T")
class BaseScreen(Screen, ABC):
class BaseScreen(Screen[object]):
"""Base screen with common functionality."""
def __init__(self, storage_manager: StorageManager, **kwargs: Any) -> None:
def __init__(
self,
storage_manager: StorageManager,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
**kwargs: object
) -> None:
"""Initialize base screen."""
super().__init__(**kwargs)
super().__init__(name=name, id=id, classes=classes)
# Ignore any additional kwargs to avoid type issues
self.storage_manager = storage_manager
class CRUDScreen(BaseScreen, Generic[T], ABC):
class CRUDScreen(BaseScreen, Generic[T]):
"""Base class for Create/Read/Update/Delete operations."""
BINDINGS = [
@@ -39,9 +48,17 @@ class CRUDScreen(BaseScreen, Generic[T], ABC):
Binding("escape", "app.pop_screen", "Back"),
]
def __init__(self, storage_manager: StorageManager, **kwargs: Any) -> None:
def __init__(
self,
storage_manager: StorageManager,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
**kwargs: object
) -> None:
"""Initialize CRUD screen."""
super().__init__(storage_manager, **kwargs)
super().__init__(storage_manager, name=name, id=id, classes=classes)
self.items: list[T] = []
self.selected_item: T | None = None
self.loading = False
@@ -77,35 +94,29 @@ class CRUDScreen(BaseScreen, Generic[T], ABC):
table.add_columns(*self.get_table_columns())
return table
@abstractmethod
def get_table_columns(self) -> list[str]:
"""Get table column headers."""
pass
raise NotImplementedError("Subclasses must implement get_table_columns")
@abstractmethod
async def load_items(self) -> list[T]:
"""Load items from storage."""
pass
raise NotImplementedError("Subclasses must implement load_items")
@abstractmethod
def item_to_row(self, item: T) -> list[str]:
"""Convert item to table row."""
pass
raise NotImplementedError("Subclasses must implement item_to_row")
@abstractmethod
async def create_item_dialog(self) -> T | None:
"""Show create item dialog."""
pass
raise NotImplementedError("Subclasses must implement create_item_dialog")
@abstractmethod
async def edit_item_dialog(self, item: T) -> T | None:
"""Show edit item dialog."""
pass
raise NotImplementedError("Subclasses must implement edit_item_dialog")
@abstractmethod
async def delete_item(self, item: T) -> bool:
"""Delete item."""
pass
raise NotImplementedError("Subclasses must implement delete_item")
def on_mount(self) -> None:
"""Initialize screen."""
@@ -176,17 +187,21 @@ class CRUDScreen(BaseScreen, Generic[T], ABC):
self.refresh_items()
class ListScreen(BaseScreen, Generic[T], ABC):
class ListScreen(BaseScreen, Generic[T]):
"""Base for paginated list views."""
def __init__(
self,
storage_manager: StorageManager,
page_size: int = 20,
**kwargs: Any,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
**kwargs: object,
) -> None:
"""Initialize list screen."""
super().__init__(storage_manager, **kwargs)
super().__init__(storage_manager, name=name, id=id, classes=classes)
self.page_size = page_size
self.current_page = 0
self.total_items = 0
@@ -204,25 +219,21 @@ class ListScreen(BaseScreen, Generic[T], ABC):
classes="list-container",
)
@abstractmethod
def get_title(self) -> str:
"""Get screen title."""
pass
raise NotImplementedError("Subclasses must implement get_title")
@abstractmethod
def create_filters(self) -> Container:
"""Create filter widgets."""
pass
raise NotImplementedError("Subclasses must implement create_filters")
@abstractmethod
def create_list_view(self) -> Any:
def create_list_view(self) -> Widget:
"""Create list view widget."""
pass
raise NotImplementedError("Subclasses must implement create_list_view")
@abstractmethod
async def load_page(self, page: int, page_size: int) -> tuple[list[T], int]:
"""Load page of items."""
pass
raise NotImplementedError("Subclasses must implement load_page")
def create_pagination(self) -> Container:
"""Create pagination controls."""
@@ -249,10 +260,9 @@ class ListScreen(BaseScreen, Generic[T], ABC):
finally:
self.set_loading(False)
@abstractmethod
async def update_list_view(self) -> None:
"""Update list view with current items."""
pass
raise NotImplementedError("Subclasses must implement update_list_view")
def update_pagination_info(self) -> None:
"""Update pagination information."""
@@ -285,7 +295,7 @@ class ListScreen(BaseScreen, Generic[T], ABC):
self.load_current_page()
class FormScreen(ModalScreen[T], Generic[T], ABC):
class FormScreen(ModalScreen[T], Generic[T]):
"""Base for input forms with validation."""
BINDINGS = [
@@ -294,9 +304,18 @@ class FormScreen(ModalScreen[T], Generic[T], ABC):
Binding("enter", "save", "Save"),
]
def __init__(self, item: T | None = None, **kwargs: Any) -> None:
def __init__(
self,
item: T | None = None,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
**kwargs: object
) -> None:
"""Initialize form screen."""
super().__init__(**kwargs)
super().__init__(name=name, id=id, classes=classes)
# Ignore any additional kwargs to avoid type issues
self.item = item
self.is_edit_mode = item is not None
@@ -315,35 +334,30 @@ class FormScreen(ModalScreen[T], Generic[T], ABC):
classes="form-container",
)
@abstractmethod
def get_item_type(self) -> str:
"""Get item type name for title."""
pass
raise NotImplementedError("Subclasses must implement get_item_type")
@abstractmethod
def create_form_fields(self) -> Container:
"""Create form input fields."""
pass
raise NotImplementedError("Subclasses must implement create_form_fields")
@abstractmethod
def validate_form(self) -> tuple[bool, list[str]]:
"""Validate form data."""
pass
raise NotImplementedError("Subclasses must implement validate_form")
@abstractmethod
def get_form_data(self) -> T:
"""Get item from form data."""
pass
raise NotImplementedError("Subclasses must implement get_form_data")
def on_mount(self) -> None:
"""Initialize form."""
if self.is_edit_mode and self.item:
self.populate_form(self.item)
@abstractmethod
def populate_form(self, item: T) -> None:
"""Populate form with item data."""
pass
raise NotImplementedError("Subclasses must implement populate_form")
def action_save(self) -> None:
"""Save form data."""

View File

@@ -212,8 +212,9 @@ class CollectionOverviewScreen(Screen[None]):
"""Update the metrics cards display."""
try:
dashboard_tab = self.query_one("#dashboard")
metrics_cards = dashboard_tab.query(MetricsCard)
if len(metrics_cards) >= 4:
metrics_cards_query = dashboard_tab.query(MetricsCard)
if len(metrics_cards_query) >= 4:
metrics_cards = list(metrics_cards_query)
self._update_card_values(metrics_cards)
self._update_status_card(metrics_cards[3])
except NoMatches:
@@ -221,13 +222,13 @@ class CollectionOverviewScreen(Screen[None]):
except Exception as exc:
LOGGER.exception("Failed to update dashboard metrics", exc_info=exc)
def _update_card_values(self, metrics_cards: list) -> None:
def _update_card_values(self, metrics_cards: list[MetricsCard]) -> None:
"""Update individual metric card values."""
metrics_cards[0].query_one(".metrics-value", Static).update(f"{self.total_collections:,}")
metrics_cards[1].query_one(".metrics-value", Static).update(f"{self.total_documents:,}")
metrics_cards[2].query_one(".metrics-value", Static).update(str(self.active_backends))
def _update_status_card(self, status_card: object) -> None:
def _update_status_card(self, status_card: MetricsCard) -> None:
"""Update the system status card."""
if self.active_backends > 0 and self.total_collections > 0:
status_text, status_class = "🟢 Healthy", "status-active"
@@ -362,8 +363,10 @@ class CollectionOverviewScreen(Screen[None]):
collections: list[CollectionInfo] = []
for item in overview:
count_val = int(item.get("count", 0))
size_mb_val = float(item.get("size_mb", 0.0))
count_raw = item.get("count", 0)
count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0
size_mb_raw = item.get("size_mb", 0.0)
size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0
collections.append(
CollectionInfo(
name=str(item.get("name", "Unknown")),
@@ -393,8 +396,10 @@ class CollectionOverviewScreen(Screen[None]):
collections: list[CollectionInfo] = []
for item in overview:
count_val = int(item.get("count", 0))
size_mb_val = float(item.get("size_mb", 0.0))
count_raw = item.get("count", 0)
count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0
size_mb_raw = item.get("size_mb", 0.0)
size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0
collection_name = str(item.get("name", "Unknown"))
collections.append(
CollectionInfo(
@@ -504,9 +509,7 @@ class CollectionOverviewScreen(Screen[None]):
def action_manage(self) -> None:
"""Manage documents in selected collection."""
if selected := self.get_selected_collection():
# Get the appropriate storage backend for the collection
storage_backend = self._get_storage_for_collection(selected)
if storage_backend:
if storage_backend := self._get_storage_for_collection(selected):
from .documents import DocumentManagementScreen
self.app.push_screen(DocumentManagementScreen(selected, storage_backend))

View File

@@ -1,10 +1,10 @@
"""Dialog screens for confirmations and user interactions."""
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING
from textual.app import ComposeResult
from textual.binding import Binding, BindingType
from textual.binding import Binding
from textual.containers import Container, Horizontal
from textual.screen import ModalScreen, Screen
from textual.widgets import Button, Footer, Header, LoadingIndicator, RichLog, Static
@@ -23,7 +23,7 @@ class ConfirmDeleteScreen(Screen[None]):
collection: CollectionInfo
parent_screen: "CollectionOverviewScreen"
BINDINGS: ClassVar[list[BindingType]] = [
BINDINGS = [
Binding("escape", "app.pop_screen", "Cancel"),
Binding("y", "confirm_delete", "Yes"),
Binding("n", "app.pop_screen", "No"),
@@ -145,7 +145,7 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
collection: CollectionInfo
parent_screen: "DocumentManagementScreen"
BINDINGS: ClassVar[list[BindingType]] = [
BINDINGS = [
Binding("escape", "app.pop_screen", "Cancel"),
Binding("y", "confirm_delete", "Yes"),
Binding("n", "app.pop_screen", "No"),
@@ -205,9 +205,12 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
loading.display = True
try:
if self.parent_screen.weaviate:
# Delete documents
results = await self.parent_screen.weaviate.delete_documents(
if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage:
# Delete documents via the storage backend
# The storage backend should expose a delete_documents method (the Weaviate backend does)
storage = self.parent_screen.storage
if hasattr(storage, 'delete_documents'):
results = await storage.delete_documents(
self.doc_ids,
collection_name=self.collection["name"],
)
@@ -238,7 +241,7 @@ class LogViewerScreen(ModalScreen[None]):
_log_widget: RichLog | None
_log_file: Path | None
BINDINGS: ClassVar[list[BindingType]] = [
BINDINGS = [
Binding("escape", "close", "Close"),
Binding("ctrl+l", "close", "Close"),
Binding("s", "show_path", "Log File"),
@@ -264,21 +267,21 @@ class LogViewerScreen(ModalScreen[None]):
def on_mount(self) -> None:
"""Attach this viewer to the parent application once mounted."""
self._log_widget = self.query_one(RichLog)
from ..app import CollectionManagementApp
if isinstance(self.app, CollectionManagementApp):
if hasattr(self.app, 'attach_log_viewer'):
self.app.attach_log_viewer(self)
def on_unmount(self) -> None:
"""Detach from the parent application when closed."""
from ..app import CollectionManagementApp
if isinstance(self.app, CollectionManagementApp):
if hasattr(self.app, 'detach_log_viewer'):
self.app.detach_log_viewer(self)
def _get_log_widget(self) -> RichLog:
if self._log_widget is None:
self._log_widget = self.query_one(RichLog)
if self._log_widget is None:
raise RuntimeError("RichLog widget not found")
return self._log_widget
def replace_logs(self, lines: list[str]) -> None:

View File

@@ -10,7 +10,6 @@ from textual.widgets import Button, Footer, Header, Label, LoadingIndicator, Sta
from typing_extensions import override
from ....storage.base import BaseStorage
from ....storage.weaviate import WeaviateStorage
from ..models import CollectionInfo, DocumentInfo
from ..widgets import EnhancedDataTable
@@ -115,9 +114,9 @@ class DocumentManagementScreen(Screen[None]):
description=str(doc.get("description", "")),
content_type=str(doc.get("content_type", "text/plain")),
content_preview=str(doc.get("content_preview", "")),
word_count=int(doc.get("word_count", 0))
if str(doc.get("word_count", 0)).isdigit()
else 0,
word_count=(
lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0
)(doc.get("word_count", 0)),
timestamp=str(doc.get("timestamp", "")),
)
for i, doc in enumerate(raw_docs)

View File

@@ -383,12 +383,12 @@ class IngestionScreen(ModalScreen[None]):
# Run the Prefect flow for this backend using asyncio.run with timeout
import asyncio
async def run_flow_with_timeout() -> IngestionResult:
async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult:
return await asyncio.wait_for(
create_ingestion_flow(
source_url=source_url,
source_type=self.selected_type,
storage_backend=backend,
storage_backend=current_backend,
collection_name=final_collection_name,
progress_callback=progress_reporter,
),
@@ -403,7 +403,7 @@ class IngestionScreen(ModalScreen[None]):
if result.error_messages:
flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages])
except asyncio.TimeoutError:
except TimeoutError:
error_msg = f"{backend.value}: Timeout after 10 minutes"
flow_errors.append(error_msg)
progress_reporter(0, f"{backend.value} timed out")
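Worth spelling out why the inner coroutine gains a current_backend: StorageBackend = backend parameter: default arguments are evaluated when the function is defined, so each loop iteration captures its own backend, whereas a plain closure would read the loop variable only when the coroutine finally runs. A minimal, standalone illustration of the difference:

```python
backends = ("weaviate", "open_webui", "r2r")

late = [lambda: b for b in backends]                   # closure: reads b after the loop has finished
bound = [lambda current=b: current for b in backends]  # default arg: value captured per iteration

print([f() for f in late])   # ['r2r', 'r2r', 'r2r']
print([f() for f in bound])  # ['weaviate', 'open_webui', 'r2r']
```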

View File

@@ -112,7 +112,7 @@ class SearchScreen(Screen[None]):
async def search_collection(self, query: str) -> None:
"""Search the collection."""
loading = self.query_one("#loading")
loading = self.query_one("#loading", LoadingIndicator)
table = self.query_one("#results_table", EnhancedDataTable)
status = self.query_one("#search_status", Static)
@@ -127,7 +127,7 @@ class SearchScreen(Screen[None]):
finally:
loading.display = False
def _setup_search_ui(self, loading, table, status, query: str) -> None:
def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None:
"""Setup the search UI elements."""
loading.display = True
status.update(f"🔍 Searching for '{query}'...")
@@ -142,7 +142,7 @@ class SearchScreen(Screen[None]):
return await self.search_openwebui(query)
return []
def _populate_results_table(self, table, results: list[dict[str, str | float]]) -> None:
def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None:
"""Populate the results table with search results."""
for result in results:
row_data = self._format_result_row(result)
@@ -183,7 +183,7 @@ class SearchScreen(Screen[None]):
content = str(content) if content is not None else ""
return f"{content[:60]}..." if len(content) > 60 else content
def _format_score(self, score) -> str:
def _format_score(self, score: object) -> str:
"""Format search score for display."""
if isinstance(score, (int, float)):
return f"{score:.3f}"
@@ -193,7 +193,7 @@ class SearchScreen(Screen[None]):
return str(score)
def _update_search_status(
self, status, query: str, results: list[dict[str, str | float]], table
self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable
) -> None:
"""Update search status and notifications based on results."""
if not results:

View File

@@ -1106,7 +1106,7 @@ def get_css_for_theme(theme_type: ThemeType) -> str:
return css
def apply_theme_to_app(app, theme_type: ThemeType) -> None:
def apply_theme_to_app(app: object, theme_type: ThemeType) -> None:
"""Apply a theme to a Textual app instance."""
try:
css = set_theme(theme_type)
@@ -1114,8 +1114,8 @@ def apply_theme_to_app(app, theme_type: ThemeType) -> None:
app.stylesheet.clear()
app.stylesheet.parse(css)
elif hasattr(app, "CSS"):
app.CSS = css
else:
setattr(app, "CSS", css)
elif hasattr(app, "refresh"):
# Fallback: try to refresh the app with new CSS
app.refresh()
except Exception as e:
@@ -1127,7 +1127,7 @@ def apply_theme_to_app(app, theme_type: ThemeType) -> None:
class ThemeSwitcher:
"""Helper class for managing theme switching in TUI applications."""
def __init__(self, app=None):
def __init__(self, app: object | None = None) -> None:
self.app = app
self.theme_history = [ThemeType.DARK]

View File

@@ -8,12 +8,10 @@ from logging import Logger
from logging.handlers import QueueHandler, RotatingFileHandler
from pathlib import Path
from queue import Queue
from typing import NamedTuple, cast
from typing import NamedTuple
from ....config import configure_prefect, get_settings
from ....core.models import StorageBackend
from ....storage.openwebui import OpenWebUIStorage
from ....storage.weaviate import WeaviateStorage
from .storage_manager import StorageManager
@@ -113,11 +111,16 @@ async def run_textual_tui() -> None:
)
# Get individual storage instances for backward compatibility
weaviate = cast(WeaviateStorage | None, storage_manager.get_backend(StorageBackend.WEAVIATE))
openwebui = cast(
OpenWebUIStorage | None, storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
)
r2r = storage_manager.get_backend(StorageBackend.R2R)
from ....storage.openwebui import OpenWebUIStorage
from ....storage.weaviate import WeaviateStorage
weaviate_backend = storage_manager.get_backend(StorageBackend.WEAVIATE)
openwebui_backend = storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
r2r_backend = storage_manager.get_backend(StorageBackend.R2R)
# Type-safe narrowing to specific storage types via isinstance
weaviate = weaviate_backend if isinstance(weaviate_backend, WeaviateStorage) else None
openwebui = openwebui_backend if isinstance(openwebui_backend, OpenWebUIStorage) else None
# Import here to avoid circular import
from ..app import CollectionManagementApp
@@ -125,7 +128,7 @@ async def run_textual_tui() -> None:
storage_manager,
weaviate,
openwebui,
r2r,
r2r_backend,
log_queue=logging_context.queue,
log_formatter=logging_context.formatter,
log_file=logging_context.log_file,

View File

@@ -11,12 +11,13 @@ from ....core.exceptions import StorageError
from ....core.models import Document, StorageBackend, StorageConfig
from ..models import CollectionInfo, StorageCapabilities
from ....storage.base import BaseStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.weaviate import WeaviateStorage
if TYPE_CHECKING:
from ....config.settings import Settings
from ....storage.weaviate import WeaviateStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.base import BaseStorage
class StorageBackendProtocol(Protocol):

View File

@@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, cast
import json
from typing import cast
from textual.app import ComposeResult
from textual.containers import Container, Horizontal
@@ -61,9 +62,17 @@ class ScrapeOptionsForm(Container):
}
"""
def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
disabled: bool = False,
markup: bool = True,
) -> None:
"""Initialize scrape options form."""
super().__init__(**kwargs)
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
@override
def compose(self) -> ComposeResult:
@@ -136,10 +145,8 @@ class ScrapeOptionsForm(Container):
classes="form-section",
)
def get_scrape_options(self) -> dict[str, Any]:
def get_scrape_options(self) -> dict[str, object]:
"""Get scraping options from form."""
options: dict[str, Any] = {}
# Collect formats
formats = []
if self.query_one("#format_markdown", Checkbox).value:
@@ -148,11 +155,12 @@ class ScrapeOptionsForm(Container):
formats.append("html")
if self.query_one("#format_screenshot", Checkbox).value:
formats.append("screenshot")
options["formats"] = formats
# Content filtering
options["only_main_content"] = self.query_one("#only_main_content", Switch).value
options: dict[str, object] = {
"formats": formats,
"only_main_content": self.query_one(
"#only_main_content", Switch
).value,
}
include_tags_input = self.query_one("#include_tags", Input).value
if include_tags_input.strip():
options["include_tags"] = [tag.strip() for tag in include_tags_input.split(",")]
@@ -171,22 +179,26 @@ class ScrapeOptionsForm(Container):
return options
def set_scrape_options(self, options: dict[str, Any]) -> None:
def set_scrape_options(self, options: dict[str, object]) -> None:
"""Set form values from options."""
# Set formats
formats = options.get("formats", ["markdown"])
self.query_one("#format_markdown", Checkbox).value = "markdown" in formats
self.query_one("#format_html", Checkbox).value = "html" in formats
self.query_one("#format_screenshot", Checkbox).value = "screenshot" in formats
formats_list = formats if isinstance(formats, list) else []
self.query_one("#format_markdown", Checkbox).value = "markdown" in formats_list
self.query_one("#format_html", Checkbox).value = "html" in formats_list
self.query_one("#format_screenshot", Checkbox).value = "screenshot" in formats_list
# Set content filtering
self.query_one("#only_main_content", Switch).value = options.get("only_main_content", True)
main_content_val = options.get("only_main_content", True)
self.query_one("#only_main_content", Switch).value = bool(main_content_val)
if include_tags := options.get("include_tags", []):
self.query_one("#include_tags", Input).value = ", ".join(include_tags)
include_list = include_tags if isinstance(include_tags, list) else []
self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list)
if exclude_tags := options.get("exclude_tags", []):
self.query_one("#exclude_tags", Input).value = ", ".join(exclude_tags)
exclude_list = exclude_tags if isinstance(exclude_tags, list) else []
self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list)
# Set performance
wait_for = options.get("wait_for")
@@ -231,9 +243,17 @@ class MapOptionsForm(Container):
}
"""
def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
disabled: bool = False,
markup: bool = True,
) -> None:
"""Initialize map options form."""
super().__init__(**kwargs)
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
@override
def compose(self) -> ComposeResult:
@@ -286,9 +306,9 @@ class MapOptionsForm(Container):
classes="form-section",
)
def get_map_options(self) -> dict[str, Any]:
def get_map_options(self) -> dict[str, object]:
"""Get mapping options from form."""
options: dict[str, Any] = {}
options: dict[str, object] = {}
# Discovery settings
search_pattern = self.query_one("#search_pattern", Input).value
@@ -314,12 +334,13 @@ class MapOptionsForm(Container):
return options
def set_map_options(self, options: dict[str, Any]) -> None:
def set_map_options(self, options: dict[str, object]) -> None:
"""Set form values from options."""
if search := options.get("search"):
self.query_one("#search_pattern", Input).value = str(search)
self.query_one("#include_subdomains", Switch).value = options.get("include_subdomains", False)
subdomains_val = options.get("include_subdomains", False)
self.query_one("#include_subdomains", Switch).value = bool(subdomains_val)
# Set limits
limit = options.get("limit")
@@ -373,9 +394,17 @@ class ExtractOptionsForm(Container):
}
"""
def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
disabled: bool = False,
markup: bool = True,
) -> None:
"""Initialize extract options form."""
super().__init__(**kwargs)
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
@override
def compose(self) -> ComposeResult:
@@ -429,9 +458,9 @@ class ExtractOptionsForm(Container):
classes="form-section",
)
def get_extract_options(self) -> dict[str, Any]:
def get_extract_options(self) -> dict[str, object]:
"""Get extraction options from form."""
options: dict[str, Any] = {}
options: dict[str, object] = {}
# Extract prompt
prompt = self.query_one("#extract_prompt", TextArea).text
@@ -442,8 +471,6 @@ class ExtractOptionsForm(Container):
schema_text = self.query_one("#extract_schema", TextArea).text
if schema_text.strip():
try:
import json
schema = json.loads(schema_text)
options["extract_schema"] = schema
except json.JSONDecodeError:
@@ -452,7 +479,7 @@ class ExtractOptionsForm(Container):
return options
def set_extract_options(self, options: dict[str, Any]) -> None:
def set_extract_options(self, options: dict[str, object]) -> None:
"""Set form values from options."""
if prompt := options.get("extract_prompt"):
self.query_one("#extract_prompt", TextArea).text = str(prompt)
@@ -551,9 +578,17 @@ class FirecrawlConfigWidget(Container):
}
"""
def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
*,
name: str | None = None,
id: str | None = None,
classes: str | None = None,
disabled: bool = False,
markup: bool = True,
) -> None:
"""Initialize Firecrawl config widget."""
super().__init__(**kwargs)
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
self.current_tab = "scrape"
@override
@@ -604,7 +639,7 @@ class FirecrawlConfigWidget(Container):
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses."""
if event.button.id.startswith("tab_"):
if event.button.id and event.button.id.startswith("tab_"):
tab_name = event.button.id[4:] # Remove "tab_" prefix
self.show_tab(tab_name)
@@ -622,15 +657,15 @@ class FirecrawlConfigWidget(Container):
pass
elif self.current_tab == "map":
try:
form = self.query_one("#map_form", MapOptionsForm)
map_opts = form.get_map_options()
map_form = self.query_one("#map_form", MapOptionsForm)
map_opts = map_form.get_map_options()
options.update(cast(FirecrawlOptions, map_opts))
except Exception:
pass
elif self.current_tab == "extract":
try:
form = self.query_one("#extract_form", ExtractOptionsForm)
extract_opts = form.get_extract_options()
extract_form = self.query_one("#extract_form", ExtractOptionsForm)
extract_opts = extract_form.get_extract_options()
options.update(cast(FirecrawlOptions, extract_opts))
except Exception:
pass

View File

@@ -98,9 +98,11 @@ class ChunkViewer(Container):
"id": str(chunk_data.get("id", "")),
"document_id": self.document_id,
"content": str(chunk_data.get("text", "")),
"start_index": int(chunk_data.get("start_index", 0)),
"end_index": int(chunk_data.get("end_index", 0)),
"metadata": dict(chunk_data.get("metadata", {})),
"start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)),
"end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)),
"metadata": (
dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {}
),
}
self.chunks.append(chunk_info)
@@ -217,6 +219,8 @@ class EntityGraph(Container):
# Parse entities from R2R response
entities_list = entities_data.get("entities", [])
if not isinstance(entities_list, list):
entities_list = []
for entity_data in entities_list:
entity_info: EntityInfo = {
"id": str(entity_data.get("id", "")),
@@ -260,7 +264,7 @@ class EntityGraph(Container):
tree.root.expand()
def on_tree_node_selected(self, event: Tree.NodeSelected) -> None:
def on_tree_node_selected(self, event: Tree.NodeSelected[EntityInfo]) -> None:
"""Handle entity selection."""
if hasattr(event.node, "data") and event.node.data:
entity = event.node.data
@@ -488,7 +492,8 @@ class DocumentOverview(Container):
doc_table = self.query_one("#doc_info_table", DataTable)
doc_table.add_columns("Property", "Value")
document_info = overview_data.get("document", {})
document_info_raw = overview_data.get("document", {})
document_info = document_info_raw if isinstance(document_info_raw, dict) else {}
doc_table.add_row("ID", str(document_info.get("id", "N/A")))
doc_table.add_row("Title", str(document_info.get("title", "N/A")))
doc_table.add_row("Created", str(document_info.get("created_at", "N/A")))

View File

@@ -4,15 +4,43 @@ from __future__ import annotations
from contextlib import ExitStack
from prefect.settings import (
PREFECT_API_KEY,
PREFECT_API_URL,
PREFECT_DEFAULT_WORK_POOL_NAME,
Setting,
temporary_settings,
)
from prefect.settings import Setting, temporary_settings
from .settings import Settings, get_settings
# Import Prefect settings with version compatibility - avoid static analysis issues
def _setup_prefect_settings() -> tuple[object, object, object]:
"""Setup Prefect settings with proper fallbacks."""
try:
import prefect.settings as ps
# Try to get the settings directly
api_key = getattr(ps, "PREFECT_API_KEY", None)
api_url = getattr(ps, "PREFECT_API_URL", None)
work_pool = getattr(ps, "PREFECT_DEFAULT_WORK_POOL_NAME", None)
if api_key is not None:
return api_key, api_url, work_pool
# Fallback to registry-based approach
registry = getattr(ps, "PREFECT_SETTING_REGISTRY", None)
if registry is not None:
Setting = getattr(ps, "Setting", None)
if Setting is not None:
api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None)
api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None)
work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None)
return api_key, api_url, work_pool
except ImportError:
pass
# Ultimate fallback
return None, None, None
PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()
# Import after Prefect settings setup to avoid circular dependencies
from .settings import Settings, get_settings # noqa: E402
__all__ = ["Settings", "get_settings", "configure_prefect"]
@@ -25,11 +53,11 @@ def configure_prefect(settings: Settings) -> None:
overrides: dict[Setting, str] = {}
if settings.prefect_api_url is not None:
if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting):
overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
if settings.prefect_api_key:
if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting):
overrides[PREFECT_API_KEY] = settings.prefect_api_key
if settings.prefect_work_pool:
if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting):
overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool
if not overrides:
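For context, the overrides dict that configure_prefect builds (dict[Setting, str]) is meant to be handed to Prefect's temporary_settings context manager, mirroring the ExitStack and temporary_settings imports already at the top of this file. A minimal sketch of that usage, assuming the PREFECT_API_URL object resolved above is a real Setting on the installed Prefect version, which is exactly what _setup_prefect_settings guards against:

```python
from contextlib import ExitStack
from prefect.settings import temporary_settings

_stack = ExitStack()

def apply_overrides(overrides: dict) -> None:
    """Keep Prefect setting overrides active until the ExitStack is closed."""
    if overrides:
        # temporary_settings is a context manager; parking it on an ExitStack lets the
        # overrides outlive this function (e.g. for the lifetime of the CLI process).
        _stack.enter_context(temporary_settings(updates=overrides))

# Hypothetical call site, reusing the Setting objects resolved above:
# apply_overrides({PREFECT_API_URL: "http://127.0.0.1:4200/api"})
```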

View File

@@ -3,6 +3,7 @@
from functools import lru_cache
from typing import Annotated, Literal
from prefect.variables import Variable
from pydantic import Field, HttpUrl, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -147,3 +148,77 @@ def get_settings() -> Settings:
Settings instance
"""
return Settings()
class PrefectVariableConfig:
"""Helper class for managing Prefect variables with fallbacks to settings."""
def __init__(self) -> None:
self._settings = get_settings()
self._variable_names = [
"default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages",
"default_storage_backend", "default_collection_prefix", "max_concurrent_tasks",
"request_timeout", "default_schedule_interval"
]
def _get_fallback_value(self, name: str, default_value: object = None) -> object:
"""Get fallback value from settings or default."""
return default_value or getattr(self._settings, name, default_value)
def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
"""Get variable value with fallback synchronously."""
fallback = self._get_fallback_value(name, default_value)
# Ensure fallback is a type that Variable expects
variable_fallback = str(fallback) if fallback is not None else None
try:
result = Variable.get(name, default=variable_fallback)
# Variable can return various types, convert to our expected types
if isinstance(result, (str, int, float)):
return result
elif result is None:
return None
else:
# Convert other types to string
return str(result)
except Exception:
# Return fallback with proper type
if isinstance(fallback, (str, int, float)) or fallback is None:
return fallback
return str(fallback) if fallback is not None else None
async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
"""Get variable value with fallback asynchronously."""
fallback = self._get_fallback_value(name, default_value)
variable_fallback = str(fallback) if fallback is not None else None
try:
result = await Variable.aget(name, default=variable_fallback)
# Variable can return various types, convert to our expected types
if isinstance(result, (str, int, float)):
return result
elif result is None:
return None
else:
# Convert other types to string
return str(result)
except Exception:
# Return fallback with proper type
if isinstance(fallback, (str, int, float)) or fallback is None:
return fallback
return str(fallback) if fallback is not None else None
def get_ingestion_config(self) -> dict[str, str | int | float | None]:
"""Get all ingestion-related configuration variables synchronously."""
return {name: self.get_with_fallback(name) for name in self._variable_names}
async def get_ingestion_config_async(self) -> dict[str, str | int | float | None]:
"""Get all ingestion-related configuration variables asynchronously."""
result = {}
for name in self._variable_names:
result[name] = await self.get_with_fallback_async(name)
return result
@lru_cache
def get_prefect_config() -> PrefectVariableConfig:
"""Get cached Prefect variable configuration helper."""
return PrefectVariableConfig()
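A brief usage sketch for the helper above, assuming it is imported from this settings module; the values shown are only the declared fallbacks, and a live Prefect workspace may override them via Variables:

```python
config = get_prefect_config()

# Single value: the Prefect Variable "default_batch_size" if one is set, otherwise the
# Settings attribute of the same name, otherwise the literal default passed here.
batch_size = config.get_with_fallback("default_batch_size", 50)

# All ingestion-related values at once; keys mirror _variable_names.
ingestion_cfg = config.get_ingestion_config()
print(batch_size, sorted(ingestion_cfg))
```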

View File

@@ -5,7 +5,8 @@ from enum import Enum
from typing import Annotated, TypedDict
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, HttpUrl
from prefect.blocks.core import Block
from pydantic import BaseModel, Field, HttpUrl, SecretStr
class IngestionStatus(str, Enum):
@@ -44,19 +45,27 @@ class VectorConfig(BaseModel):
batch_size: Annotated[int, Field(gt=0, le=1000)] = 100
class StorageConfig(BaseModel):
class StorageConfig(Block):
"""Configuration for storage backend."""
_block_type_name = "Storage Configuration"
_block_type_slug = "storage-config"
_description = "Configures storage backend connections and settings for document ingestion"
backend: StorageBackend
endpoint: HttpUrl
api_key: str | None = Field(default=None)
api_key: SecretStr | None = Field(default=None)
collection_name: str = Field(default="documents")
batch_size: Annotated[int, Field(gt=0, le=1000)] = 100
class FirecrawlConfig(BaseModel):
class FirecrawlConfig(Block):
"""Configuration for Firecrawl ingestion (operational parameters only)."""
_block_type_name = "Firecrawl Configuration"
_block_type_slug = "firecrawl-config"
_description = "Configures Firecrawl web scraping and crawling parameters"
formats: list[str] = Field(default_factory=lambda: ["markdown", "html"])
max_depth: Annotated[int, Field(ge=1, le=20)] = 5
limit: Annotated[int, Field(ge=1, le=1000)] = 100
@@ -64,9 +73,13 @@ class FirecrawlConfig(BaseModel):
include_subdomains: bool = Field(default=False)
class RepomixConfig(BaseModel):
class RepomixConfig(Block):
"""Configuration for Repomix ingestion."""
_block_type_name = "Repomix Configuration"
_block_type_slug = "repomix-config"
_description = "Configures repository ingestion patterns and file processing settings"
include_patterns: list[str] = Field(
default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"]
)
@@ -77,9 +90,13 @@ class RepomixConfig(BaseModel):
respect_gitignore: bool = Field(default=True)
class R2RConfig(BaseModel):
class R2RConfig(Block):
"""Configuration for R2R ingestion."""
_block_type_name = "R2R Configuration"
_block_type_slug = "r2r-config"
_description = "Configures R2R-specific ingestion settings including chunking and graph enrichment"
chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000
chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200
enable_graph_enrichment: bool = Field(default=False)
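Because these config models are now Prefect Blocks, they can be registered once and then loaded by name inside flows (as the ingestion tasks do with Block.aload). A minimal registration sketch, assuming a reachable Prefect API and that the models above are importable in scope; the block document name "weaviate-default" is made up for the example:

```python
from prefect.blocks.core import Block
from pydantic import SecretStr

cfg = StorageConfig(
    backend=StorageBackend.WEAVIATE,
    endpoint="http://localhost:8080",   # coerced to HttpUrl by pydantic
    api_key=SecretStr("example-key"),   # stored as a secret field on the block
    collection_name="documents",
)
cfg.save("weaviate-default", overwrite=True)             # persists a block document

loaded = Block.load("storage-config/weaviate-default")   # "<slug>/<name>", as used in the flows
assert isinstance(loaded, StorageConfig)
print(loaded.collection_name)  # documents
```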

View File

@@ -3,12 +3,13 @@
from __future__ import annotations
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Literal, assert_never, cast
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Literal, TypeAlias, assert_never, cast
from prefect import flow, get_run_logger, task
from prefect.cache_policies import NO_CACHE
from prefect.futures import wait
from prefect.blocks.core import Block
from prefect.variables import Variable
from pydantic.types import SecretStr
from ..config.settings import Settings
from ..core.exceptions import IngestionError
@@ -31,8 +32,14 @@ from ..utils.metadata_tagger import MetadataTagger
SourceTypeLiteral = Literal["web", "repository", "documentation"]
StorageBackendLiteral = Literal["weaviate", "open_webui", "r2r"]
SourceTypeLike = IngestionSource | SourceTypeLiteral
StorageBackendLike = StorageBackend | StorageBackendLiteral
SourceTypeLike: TypeAlias = IngestionSource | SourceTypeLiteral
StorageBackendLike: TypeAlias = StorageBackend | StorageBackendLiteral
def _safe_cache_key(prefix: str, params: dict[str, object], key: str) -> str:
"""Create a type-safe cache key from task parameters."""
value = params.get(key, "")
return f"{prefix}_{hash(str(value))}"
if TYPE_CHECKING:
@@ -65,16 +72,22 @@ async def validate_source_task(source_url: str, source_type: IngestionSource) ->
@task(name="initialize_storage", retries=3, retry_delay_seconds=5, tags=["storage"])
async def initialize_storage_task(config: StorageConfig) -> BaseStorage:
async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
"""
Initialize storage backend.
Args:
config: Storage configuration
config: Storage configuration block or block name
Returns:
Initialized storage adapter
"""
# Load block if string provided
if isinstance(config, str):
# Use Block.aload with type slug for better type inference
loaded_block = await Block.aload(f"storage-config/{config}")
config = cast(StorageConfig, loaded_block)
if config.backend == StorageBackend.WEAVIATE:
storage = WeaviateStorage(config)
elif config.backend == StorageBackend.OPEN_WEBUI:
@@ -90,38 +103,48 @@ async def initialize_storage_task(config: StorageConfig) -> BaseStorage:
return storage
@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"])
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig) -> list[str]:
@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"],
cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"))
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
"""Map a site using Firecrawl and return discovered URLs."""
# Load block if string provided
if isinstance(config, str):
# Use Block.aload with type slug for better type inference
loaded_block = await Block.aload(f"firecrawl-config/{config}")
config = cast(FirecrawlConfig, loaded_block)
ingestor = FirecrawlIngestor(config)
mapped = await ingestor.map_site(source_url)
return mapped or [source_url]
@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["r2r", "dedup"], cache_policy=NO_CACHE)
@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"],
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list
async def filter_existing_documents_task(
urls: list[str],
storage_client: R2RStorageType,
storage_client: BaseStorage,
stale_after_days: int = 30,
*,
collection_name: str | None = None,
) -> list[str]:
"""Filter URLs whose documents are missing or stale in R2R."""
"""Filter URLs to only those that need scraping (missing or stale in storage)."""
logger = get_run_logger()
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
eligible: list[str] = []
for url in urls:
document_id = str(FirecrawlIngestor.compute_document_id(url))
existing: Document | None = await storage_client.retrieve(document_id)
if existing is None:
eligible.append(url)
continue
exists = await storage_client.check_exists(
document_id,
collection_name=collection_name,
stale_after_days=stale_after_days
)
timestamp = existing.metadata["timestamp"]
if timestamp < cutoff:
if not exists:
eligible.append(url)
if skipped := len(urls) - len(eligible):
logger.info("Skipping %s up-to-date pages", skipped)
skipped = len(urls) - len(eligible)
if skipped > 0:
logger.info("Skipping %s up-to-date documents in %s", skipped, storage_client.display_name)
return eligible
@@ -134,7 +157,8 @@ async def scrape_firecrawl_batch_task(
) -> list[FirecrawlPage]:
"""Scrape a batch of URLs via Firecrawl."""
ingestor = FirecrawlIngestor(config)
return await ingestor.scrape_pages(batch_urls)
result: list[FirecrawlPage] = await ingestor.scrape_pages(batch_urls)
return result
@task(name="annotate_firecrawl_metadata", retries=1, retry_delay_seconds=10, tags=["metadata"])
@@ -153,7 +177,8 @@ async def annotate_firecrawl_metadata_task(
settings = get_settings()
async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
return await tagger.tag_batch(documents)
tagged_documents: list[Document] = await tagger.tag_batch(documents)
return tagged_documents
except IngestionError as exc: # pragma: no cover - logging side effect
logger = get_run_logger()
logger.warning("Metadata tagging failed: %s", exc)
@@ -164,7 +189,7 @@ async def annotate_firecrawl_metadata_task(
return documents
@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"], cache_policy=NO_CACHE)
@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"])
async def upsert_r2r_documents_task(
storage_client: R2RStorageType,
documents: list[Document],
@@ -187,12 +212,14 @@ async def upsert_r2r_documents_task(
return processed, failed
@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"], cache_policy=NO_CACHE)
@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"])
async def ingest_documents_task(
job: IngestionJob,
collection_name: str | None = None,
batch_size: int = 50,
batch_size: int | None = None,
storage_client: BaseStorage | None = None,
storage_block_name: str | None = None,
ingestor_config_block_name: str | None = None,
progress_callback: Callable[[int, str], None] | None = None,
) -> tuple[int, int]:
"""
@@ -201,8 +228,10 @@ async def ingest_documents_task(
Args:
job: Ingestion job configuration
collection_name: Target collection name
batch_size: Number of documents per batch
batch_size: Number of documents per batch (uses Variable if None)
storage_client: Optional pre-initialized storage client
storage_block_name: Optional storage block name to load
ingestor_config_block_name: Optional ingestor config block name to load
progress_callback: Optional callback for progress updates
Returns:
@@ -211,8 +240,22 @@ async def ingest_documents_task(
if progress_callback:
progress_callback(35, "Creating ingestor and storage clients...")
ingestor = _create_ingestor(job)
storage = storage_client or await _create_storage(job, collection_name)
# Use Variable for batch size if not provided
if batch_size is None:
try:
batch_size_var = await Variable.aget("default_batch_size", default="50")
# Convert Variable result to int, handling various types
if isinstance(batch_size_var, int):
batch_size = batch_size_var
elif isinstance(batch_size_var, (str, float)):
batch_size = int(float(str(batch_size_var)))
else:
batch_size = 50
except Exception:
batch_size = 50
ingestor = await _create_ingestor(job, ingestor_config_block_name)
storage = storage_client or await _create_storage(job, collection_name, storage_block_name)
if progress_callback:
progress_callback(40, "Starting document processing...")
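The Variable-to-int coercion above reappears almost verbatim in the deployment helper later in this diff; a hypothetical shared helper (not present in the codebase) would capture the same rule:

def _coerce_int(value: object, fallback: int) -> int:
    # Prefect Variables may come back as int, float, or str depending on how
    # they were stored; normalise them exactly as the task above does.
    if isinstance(value, int):
        return value
    if isinstance(value, (str, float)):
        try:
            return int(float(str(value)))
        except ValueError:
            return fallback
    return fallback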
@@ -220,30 +263,50 @@ async def ingest_documents_task(
return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback)
def _create_ingestor(job: IngestionJob) -> BaseIngestor:
async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
"""Create appropriate ingestor based on job source type."""
if job.source_type == IngestionSource.WEB:
config = FirecrawlConfig()
if config_block_name:
# Use Block.aload with type slug for better type inference
loaded_block = await Block.aload(f"firecrawl-config/{config_block_name}")
config = cast(FirecrawlConfig, loaded_block)
else:
# Fallback to default configuration
config = FirecrawlConfig()
return FirecrawlIngestor(config)
elif job.source_type == IngestionSource.REPOSITORY:
config = RepomixConfig()
if config_block_name:
# Use Block.aload with type slug for better type inference
loaded_block = await Block.aload(f"repomix-config/{config_block_name}")
config = cast(RepomixConfig, loaded_block)
else:
# Fallback to default configuration
config = RepomixConfig()
return RepomixIngestor(config)
else:
raise ValueError(f"Unsupported source: {job.source_type}")
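Assuming FirecrawlConfig is registered as a Prefect block under the "firecrawl-config" type slug used above, a one-off setup might look like this (block name is illustrative; Block.save and its overwrite keyword are assumed to be available in the installed Prefect version):

from prefect.blocks.core import Block

async def register_firecrawl_block() -> None:
    config = FirecrawlConfig()                      # import path omitted; assumed available
    await config.save("docs-site", overwrite=True)  # Block.save is awaitable in async code
    loaded = await Block.aload("firecrawl-config/docs-site")
    assert isinstance(loaded, FirecrawlConfig)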
async def _create_storage(job: IngestionJob, collection_name: str | None) -> BaseStorage:
async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage:
"""Create and initialize storage client."""
if collection_name is None:
collection_name = f"docs_{job.source_type.value}"
# Use variable for default collection prefix
prefix = await Variable.aget("default_collection_prefix", default="docs")
collection_name = f"{prefix}_{job.source_type.value}"
from ..config import get_settings
if storage_block_name:
# Load storage config from block
loaded_block = await Block.aload(f"storage-config/{storage_block_name}")
storage_config = cast(StorageConfig, loaded_block)
# Override collection name if provided
storage_config.collection_name = collection_name
else:
# Fallback to building config from settings
from ..config import get_settings
settings = get_settings()
storage_config = _build_storage_config(job, settings, collection_name)
settings = get_settings()
storage_config = _build_storage_config(job, settings, collection_name)
storage = _instantiate_storage(job.storage_backend, storage_config)
await storage.initialize()
return storage
@@ -257,16 +320,19 @@ def _build_storage_config(
StorageBackend.OPEN_WEBUI: settings.openwebui_endpoint,
StorageBackend.R2R: settings.get_storage_endpoint("r2r"),
}
storage_api_keys = {
storage_api_keys: dict[StorageBackend, str | None] = {
StorageBackend.WEAVIATE: settings.get_api_key("weaviate"),
StorageBackend.OPEN_WEBUI: settings.get_api_key("openwebui"),
StorageBackend.R2R: None, # R2R is self-hosted, no API key needed
}
api_key_raw: str | None = storage_api_keys[job.storage_backend]
api_key: SecretStr | None = SecretStr(api_key_raw) if api_key_raw is not None else None
return StorageConfig(
backend=job.storage_backend,
endpoint=storage_endpoints[job.storage_backend],
api_key=storage_api_keys[job.storage_backend],
api_key=api_key,
collection_name=collection_name,
)
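The SecretStr wrapper keeps API keys out of reprs and log output; a minimal illustration (the token value is made up):

from pydantic import SecretStr

api_key = SecretStr("owui-example-token")
print(api_key)                     # **********
print(api_key.get_secret_value())  # owui-example-token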
@@ -324,7 +390,20 @@ async def _process_documents(
if progress_callback:
progress_callback(45, "Ingesting documents from source...")
async for document in ingestor.ingest(job):
# Use smart ingestion with deduplication if storage supports it
if hasattr(storage, 'check_exists'):
try:
# Try to use the smart ingestion method
document_generator = ingestor.ingest_with_dedup(
job, storage, collection_name=collection_name
)
except Exception:
# Fall back to regular ingestion if smart method fails
document_generator = ingestor.ingest(job)
else:
document_generator = ingestor.ingest(job)
async for document in document_generator:
batch.append(document)
total_documents += 1
@@ -425,16 +504,35 @@ async def firecrawl_to_r2r_flow(
r2r_storage = cast("R2RStorageType", storage_client)
if progress_callback:
progress_callback(45, "Discovering pages with Firecrawl...")
progress_callback(45, "Checking for existing content before mapping...")
discovered_urls = await map_firecrawl_site_task(str(job.source_url), firecrawl_config)
# Smart mapping: try single URL first to avoid expensive map operation
base_url = str(job.source_url)
single_url_id = str(FirecrawlIngestor.compute_document_id(base_url))
base_exists = await r2r_storage.check_exists(
single_url_id, collection_name=resolved_collection, stale_after_days=30
)
if base_exists:
# Check if this is a recent single-page update
logger.info("Base URL %s exists and is fresh, skipping expensive mapping", base_url)
if progress_callback:
progress_callback(100, "Content is up to date, no processing needed")
return 0, 0
if progress_callback:
progress_callback(50, "Discovering pages with Firecrawl...")
discovered_urls = await map_firecrawl_site_task(base_url, firecrawl_config)
unique_urls = _deduplicate_urls(discovered_urls)
logger.info("Discovered %s unique URLs from Firecrawl map", len(unique_urls))
if progress_callback:
progress_callback(55, f"Found {len(unique_urls)} pages, filtering existing content...")
progress_callback(60, f"Found {len(unique_urls)} pages, filtering existing content...")
eligible_urls = await filter_existing_documents_task(unique_urls, r2r_storage)
eligible_urls = await filter_existing_documents_task(
unique_urls, r2r_storage, collection_name=resolved_collection
)
if not eligible_urls:
logger.info("All Firecrawl pages are up to date for %s", job.source_url)
@@ -443,7 +541,7 @@ async def firecrawl_to_r2r_flow(
return 0, 0
if progress_callback:
progress_callback(65, f"Scraping {len(eligible_urls)} new/updated pages...")
progress_callback(70, f"Scraping {len(eligible_urls)} new/updated pages...")
batch_size = min(settings.default_batch_size, firecrawl_config.limit)
url_batches = _chunk_urls(eligible_urls, batch_size)
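_chunk_urls itself is not shown in this diff; presumably it is a simple slicer along these lines (assumption):

def _chunk_urls(urls: list[str], batch_size: int) -> list[list[str]]:
    # Split the eligible URLs into scrape batches of at most batch_size items.
    step = max(1, batch_size)
    return [urls[i : i + step] for i in range(0, len(urls), step)]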
@@ -462,7 +560,7 @@ async def firecrawl_to_r2r_flow(
scraped_pages.extend(batch_pages)
if progress_callback:
progress_callback(75, f"Processing {len(scraped_pages)} scraped pages...")
progress_callback(80, f"Processing {len(scraped_pages)} scraped pages...")
documents = await annotate_firecrawl_metadata_task(scraped_pages, job)
@@ -471,7 +569,7 @@ async def firecrawl_to_r2r_flow(
return 0, len(eligible_urls)
if progress_callback:
progress_callback(85, f"Storing {len(documents)} documents in R2R...")
progress_callback(90, f"Storing {len(documents)} documents in R2R...")
processed, failed = await upsert_r2r_documents_task(r2r_storage, documents, resolved_collection)
@@ -481,7 +579,7 @@ async def firecrawl_to_r2r_flow(
@task(name="update_job_status", tags=["tracking"])
def update_job_status_task(
async def update_job_status_task(
job: IngestionJob,
status: IngestionStatus,
processed: int = 0,
@@ -575,7 +673,7 @@ async def create_ingestion_flow(
# Update status to in progress
if progress_callback:
progress_callback(20, "Initializing storage...")
job = update_job_status_task(job, IngestionStatus.IN_PROGRESS)
job = await update_job_status_task(job, IngestionStatus.IN_PROGRESS)
# Run ingestion
if progress_callback:
@@ -601,7 +699,7 @@ async def create_ingestion_flow(
else:
final_status = IngestionStatus.COMPLETED
job = update_job_status_task(job, final_status, processed=processed, _failed=failed)
job = await update_job_status_task(job, final_status, processed=processed, _failed=failed)
print(f"Ingestion completed: {processed} processed, {failed} failed")
@@ -610,7 +708,7 @@ async def create_ingestion_flow(
error_messages.append(str(e))
# Don't reset counts - keep whatever was processed before the error
job = update_job_status_task(
job = await update_job_status_task(
job, IngestionStatus.FAILED, processed=processed, _failed=failed, error=str(e)
)


@@ -6,6 +6,7 @@ from typing import Literal, Protocol, cast
from prefect import serve
from prefect.deployments.runner import RunnerDeployment
from prefect.schedules import Cron, Interval
from prefect.variables import Variable
from ..core.models import IngestionSource, StorageBackend
from .ingestion import SourceTypeLike, StorageBackendLike, create_ingestion_flow
@@ -30,11 +31,13 @@ def create_scheduled_deployment(
storage_backend: StorageBackendLike = StorageBackend.WEAVIATE,
schedule_type: Literal["cron", "interval"] = "interval",
cron_expression: str | None = None,
interval_minutes: int = 60,
interval_minutes: int | None = None,
tags: list[str] | None = None,
storage_block_name: str | None = None,
ingestor_config_block_name: str | None = None,
) -> RunnerDeployment:
"""
Create a scheduled deployment for ingestion.
Create a scheduled deployment for ingestion with block support.
Args:
name: Deployment name
@@ -43,12 +46,28 @@ def create_scheduled_deployment(
storage_backend: Storage backend
schedule_type: Type of schedule
cron_expression: Cron expression if using cron
interval_minutes: Interval in minutes if using interval
interval_minutes: Interval in minutes (uses Variable if None)
tags: Optional tags for deployment
storage_block_name: Optional storage block name
ingestor_config_block_name: Optional ingestor config block name
Returns:
Deployment configuration
"""
# Use Variable for interval if not provided
if interval_minutes is None:
try:
interval_var = Variable.get("default_schedule_interval", default="60")
# Convert Variable result to int, handling various types
if isinstance(interval_var, int):
interval_minutes = interval_var
elif isinstance(interval_var, (str, float)):
interval_minutes = int(float(str(interval_var)))
else:
interval_minutes = 60
except Exception:
interval_minutes = 60
# Create schedule
if schedule_type == "cron" and cron_expression:
schedule = Cron(cron_expression, timezone="UTC")
@@ -62,18 +81,27 @@ def create_scheduled_deployment(
if tags is None:
tags = [source_enum.value, backend_enum.value]
# Create deployment parameters with block support
parameters = {
"source_url": source_url,
"source_type": source_enum.value,
"storage_backend": backend_enum.value,
"validate_first": True,
}
# Add block names if provided
if storage_block_name:
parameters["storage_block_name"] = storage_block_name
if ingestor_config_block_name:
parameters["ingestor_config_block_name"] = ingestor_config_block_name
# Create deployment
# The flow decorator adds the to_deployment method at runtime
to_deployment = create_ingestion_flow.to_deployment
deployment = to_deployment(
name=name,
schedule=schedule,
parameters={
"source_url": source_url,
"source_type": source_enum.value,
"storage_backend": backend_enum.value,
"validate_first": True,
},
parameters=parameters,
tags=tags,
description=f"Scheduled ingestion from {source_url}",
)
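A hedged usage sketch, assuming the named blocks already exist on the Prefect server; argument names follow the parameters shown above and all values are illustrative:

deployment = create_scheduled_deployment(
    name="docs-weaviate-hourly",
    source_url="https://docs.example.com",
    source_type=IngestionSource.WEB,
    storage_backend=StorageBackend.WEAVIATE,
    schedule_type="interval",
    interval_minutes=None,                     # fall back to the default_schedule_interval Variable
    storage_block_name="weaviate-prod",
    ingestor_config_block_name="docs-site",
)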


@@ -1,10 +1,16 @@
"""Base ingestor interface."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from ..core.models import Document, IngestionJob
if TYPE_CHECKING:
from ..storage.base import BaseStorage
class BaseIngestor(ABC):
"""Abstract base class for all ingestors."""
@@ -47,3 +53,30 @@ class BaseIngestor(ABC):
Estimated number of documents
"""
pass # pragma: no cover
async def ingest_with_dedup(
self,
job: IngestionJob,
storage_client: BaseStorage,
*,
collection_name: str | None = None,
stale_after_days: int = 30,
) -> AsyncGenerator[Document, None]:
"""
Ingest documents with duplicate detection (optional optimization).
Default implementation falls back to regular ingestion.
Subclasses can override to provide optimized deduplication.
Args:
job: The ingestion job configuration
storage_client: Storage client to check for existing documents
collection_name: Collection to check for duplicates
stale_after_days: Consider documents stale after this many days
Yields:
Documents from the source (with deduplication if implemented)
"""
# Default implementation: fall back to regular ingestion
async for document in self.ingest(job):
yield document
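Consumption is identical to the plain ingest generator; a minimal sketch (the collection name is illustrative):

async def count_new_documents(
    ingestor: BaseIngestor, storage: BaseStorage, job: IngestionJob
) -> int:
    seen = 0
    async for _document in ingestor.ingest_with_dedup(job, storage, collection_name="docs_web"):
        # Only new or stale documents arrive here; fresh ones are filtered upstream.
        seen += 1
    return seen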


@@ -6,6 +6,7 @@ import re
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from uuid import NAMESPACE_URL, UUID, uuid5
@@ -23,8 +24,11 @@ from ..core.models import (
)
from .base import BaseIngestor
if TYPE_CHECKING:
from ..storage.base import BaseStorage
class FirecrawlError(IngestionError): # type: ignore[misc]
class FirecrawlError(IngestionError):
"""Base exception for Firecrawl-related errors."""
def __init__(self, message: str, status_code: int | None = None) -> None:
@@ -161,6 +165,55 @@ class FirecrawlIngestor(BaseIngestor):
for page in pages:
yield self.create_document(page, job)
async def ingest_with_dedup(
self,
job: IngestionJob,
storage_client: "BaseStorage",
*,
collection_name: str | None = None,
stale_after_days: int = 30,
) -> AsyncGenerator[Document, None]:
"""
Ingest documents with duplicate detection to avoid unnecessary scraping.
Args:
job: The ingestion job configuration
storage_client: Storage client to check for existing documents
collection_name: Collection to check for duplicates
stale_after_days: Consider documents stale after this many days
Yields:
Documents from the web source (only new/stale ones)
"""
url = str(job.source_url)
# First, map the site to understand its structure
site_map = await self.map_site(url) or [url]
# Filter out URLs that already exist in storage and are fresh
eligible_urls: list[str] = []
for check_url in site_map:
document_id = str(self.compute_document_id(check_url))
exists = await storage_client.check_exists(
document_id,
collection_name=collection_name,
stale_after_days=stale_after_days
)
if not exists:
eligible_urls.append(check_url)
if not eligible_urls:
return # No new documents to scrape
# Process eligible pages in batches
batch_size = 10
for i in range(0, len(eligible_urls), batch_size):
batch_urls = eligible_urls[i : i + batch_size]
pages = await self.scrape_pages(batch_urls)
for page in pages:
yield self.create_document(page, job)
async def map_site(self, url: str) -> list[str]:
"""Public wrapper for mapping a site."""
@@ -391,7 +444,7 @@ class FirecrawlIngestor(BaseIngestor):
else:
avg_sentence_length = words / sentences
# Simplified readability score (0-100, higher is more readable)
readability_score = max(0, min(100, 100 - (avg_sentence_length - 15) * 2))
readability_score = max(0.0, min(100.0, 100.0 - (avg_sentence_length - 15.0) * 2.0))
# Completeness score based on structure
completeness_factors = 0
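As a worked example of the readability formula above: a page averaging 25 words per sentence scores 100 - (25 - 15) * 2 = 80; at 15 or fewer words per sentence the score clamps to 100, and at 65 or more it clamps to 0.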
@@ -478,11 +531,15 @@ class FirecrawlIngestor(BaseIngestor):
"domain": domain_info["domain"],
"site_name": domain_info["site_name"],
# Document structure
"heading_hierarchy": structure_info["heading_hierarchy"],
"section_depth": structure_info["section_depth"],
"has_code_blocks": structure_info["has_code_blocks"],
"has_images": structure_info["has_images"],
"has_links": structure_info["has_links"],
"heading_hierarchy": (
list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else []
),
"section_depth": (
int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0
),
"has_code_blocks": bool(structure_info.get("has_code_blocks", False)),
"has_images": bool(structure_info.get("has_images", False)),
"has_links": bool(structure_info.get("has_links", False)),
# Processing metadata
"extraction_method": "firecrawl",
"last_modified": datetime.fromisoformat(page.sitemap_last_modified)


@@ -1,9 +1,9 @@
"""Repomix ingestor for Git repositories."""
import asyncio
import re
import subprocess
import tempfile
import re
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from pathlib import Path
@@ -400,16 +400,13 @@ class RepomixIngestor(BaseIngestor):
# Handle different URL formats
if 'github.com' in repo_url or 'gitlab.com' in repo_url:
# Extract from URLs like https://github.com/org/repo.git
path_match = re.search(r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url)
if path_match:
org_name = path_match.group(1)
repo_name = path_match.group(2)
else:
# Try to extract from generic git URLs
path_match = re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url)
if path_match:
repo_name = path_match.group(1)
if path_match := re.search(
r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url
):
org_name = path_match[1]
repo_name = path_match[2]
elif path_match := re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url):
repo_name = path_match[1]
return {
'repository_name': repo_name or 'unknown',
@@ -485,28 +482,18 @@ class RepomixIngestor(BaseIngestor):
# Build rich metadata
metadata: DocumentMetadata = {
# Core required fields
"source_url": str(job.source_url),
"timestamp": datetime.now(UTC),
"content_type": content_type,
"word_count": len(content.split()),
"char_count": len(content),
# Basic fields
"title": f"{file_path}" + (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
"title": f"{file_path}"
+ (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
"description": f"Repository file: {file_path}",
# Content categorization
"category": "source_code" if programming_language else "documentation",
"language": programming_language or "text",
# Document structure from code analysis
"has_code_blocks": True if programming_language else False,
# Processing metadata
"has_code_blocks": bool(programming_language),
"extraction_method": "repomix",
# Repository-specific fields
"file_path": file_path,
"programming_language": programming_language,
}
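A quick, illustrative check of the URL-parsing regex earlier in this file:

import re

REPO_PATTERN = re.compile(r'/([^/]+)/([^/]+?)(?:\.git)?/?$')

match = REPO_PATTERN.search("https://github.com/acme/widgets.git")
assert match is not None and match[1] == "acme" and match[2] == "widgets"

match = REPO_PATTERN.search("https://gitlab.com/platform/infra/")
assert match is not None and match[1] == "platform" and match[2] == "infra"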

ingest_pipeline/py.typed (new, empty file)


@@ -75,6 +75,40 @@ class BaseStorage(ABC):
"""
raise NotImplementedError(f"{self.__class__.__name__} doesn't support document retrieval")
async def check_exists(
self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
) -> bool:
"""
Check if a document exists and is not stale.
Args:
document_id: Document ID to check
collection_name: Collection to check in
stale_after_days: Consider document stale after this many days
Returns:
True if document exists and is not stale, False otherwise
"""
try:
document = await self.retrieve(document_id, collection_name=collection_name)
if document is None:
return False
# Check staleness if timestamp is available
if "timestamp" in document.metadata:
from datetime import UTC, datetime, timedelta
timestamp_obj = document.metadata["timestamp"]
if isinstance(timestamp_obj, datetime):
timestamp = timestamp_obj
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
return timestamp >= cutoff
# If no timestamp, assume it exists and is valid
return True
except Exception:
# Backend doesn't support retrieval, assume doesn't exist
return False
def search(
self,
query: str,
@@ -130,6 +164,15 @@ class BaseStorage(ABC):
"""
return []
async def describe_collections(self) -> list[dict[str, object]]:
"""
Describe available collections with metadata (if supported by backend).
Returns:
List of collection metadata dictionaries, empty list if not supported
"""
return []
async def list_documents(
self,
limit: int = 100,
@@ -159,4 +202,5 @@ class BaseStorage(ABC):
Default implementation does nothing.
"""
pass
# Default implementation - storage backends can override to clean up connections
return None


@@ -274,6 +274,25 @@ class OpenWebUIStorage(BaseStorage):
except Exception as e:
raise StorageError(f"Failed to store batch: {e}") from e
@override
async def retrieve(
self, document_id: str, *, collection_name: str | None = None
) -> Document | None:
"""
OpenWebUI doesn't support document retrieval by ID.
Args:
document_id: File ID (not supported)
collection_name: Collection name (not used)
Raises:
NotImplementedError: OpenWebUI does not support retrieval by ID
"""
# OpenWebUI uses file-based storage without direct document retrieval
# This will cause the base check_exists method to return False,
# which means documents will always be re-scraped for OpenWebUI
raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID")
@override
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
"""
@@ -414,69 +433,57 @@ class OpenWebUIStorage(BaseStorage):
count: int
size_mb: float
async def describe_collections(self) -> list[CollectionSummary]:
async def _get_knowledge_base_count(self, kb: dict[str, object]) -> int:
"""Get the file count for a knowledge base."""
kb_id = kb.get("id")
name = kb.get("name", "Unknown")
if not kb_id:
return self._count_files_from_basic_info(kb)
return await self._count_files_from_detailed_info(str(kb_id), str(name), kb)
def _count_files_from_basic_info(self, kb: dict[str, object]) -> int:
"""Count files from basic knowledge base info."""
files = kb.get("files", [])
return len(files) if isinstance(files, list) and files is not None else 0
async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: dict[str, object]) -> int:
"""Count files by fetching detailed knowledge base info."""
try:
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}")
detail_response.raise_for_status()
detailed_kb = detail_response.json()
files = detailed_kb.get("files", [])
count = len(files) if isinstance(files, list) and files is not None else 0
LOGGER.info(f"Knowledge base '{name}' (ID: {kb_id}): found {count} files")
return count
except Exception as e:
LOGGER.warning(f"Failed to get detailed info for KB '{name}' (ID: {kb_id}): {e}")
return self._count_files_from_basic_info(kb)
async def describe_collections(self) -> list[dict[str, object]]:
"""Return metadata about each knowledge base."""
try:
# First get the list of knowledge bases
response = await self.client.get("/api/v1/knowledge/")
response.raise_for_status()
knowledge_bases = response.json()
knowledge_bases = await self._fetch_knowledge_bases()
collections: list[dict[str, object]] = []
LOGGER.info(f"OpenWebUI returned {len(knowledge_bases)} knowledge bases")
LOGGER.debug(f"Knowledge bases structure: {knowledge_bases}")
collections: list[OpenWebUIStorage.CollectionSummary] = []
for kb in knowledge_bases:
if not isinstance(kb, dict):
continue
kb_id = kb.get("id")
count = await self._get_knowledge_base_count(kb)
name = kb.get("name", "Unknown")
LOGGER.info(f"Processing knowledge base: '{name}' (ID: {kb_id})")
LOGGER.debug(f"KB structure: {kb}")
if not kb_id:
# If no ID, fall back to basic count from list response
files = kb.get("files", [])
if files is None:
files = []
count = len(files) if isinstance(files, list) else 0
else:
# Get detailed knowledge base information using the correct endpoint
try:
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}")
detail_response.raise_for_status()
detailed_kb = detail_response.json()
LOGGER.debug(f"Detailed KB response: {detailed_kb}")
files = detailed_kb.get("files", [])
if files is None:
files = []
count = len(files) if isinstance(files, list) else 0
# Debug logging
LOGGER.info(f"Knowledge base '{name}' (ID: {kb_id}): found {count} files")
if count > 0 and len(files) > 0:
LOGGER.debug(f"First file structure: {files[0] if files else 'No files'}")
elif count == 0:
LOGGER.warning(f"Knowledge base '{name}' has 0 files. Files field type: {type(files)}, value: {files}")
except Exception as e:
LOGGER.warning(f"Failed to get detailed info for KB '{name}' (ID: {kb_id}): {e}")
# Fallback to basic files list if detailed fetch fails
files = kb.get("files", [])
if files is None:
files = []
count = len(files) if isinstance(files, list) else 0
LOGGER.info(f"Fallback count for KB '{name}': {count}")
size_mb = count * 0.5 # rough heuristic
summary: OpenWebUIStorage.CollectionSummary = {
summary: dict[str, object] = {
"name": str(name),
"count": int(count),
"count": count,
"size_mb": float(size_mb),
}
collections.append(summary)
@@ -500,7 +507,10 @@ class OpenWebUIStorage(BaseStorage):
# If no collection name provided, return total across all collections
try:
collections = await self.describe_collections()
return sum(collection["count"] for collection in collections)
return sum(
int(collection["count"]) if isinstance(collection["count"], (int, str)) else 0
for collection in collections
)
except Exception:
return 0


@@ -62,7 +62,8 @@ class R2RCollections:
name=name,
description=description,
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(f"Failed to create collection '{name}': {e}") from e
@@ -80,7 +81,8 @@ class R2RCollections:
"""
try:
response = await self.client.collections.retrieve(str(collection_id))
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(f"Failed to retrieve collection {collection_id}: {e}") from e
@@ -109,7 +111,8 @@ class R2RCollections:
name=name,
description=description,
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(f"Failed to update collection {collection_id}: {e}") from e
@@ -153,7 +156,8 @@ class R2RCollections:
limit=limit,
owner_only=owner_only,
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(f"Failed to list collections: {e}") from e
@@ -202,7 +206,8 @@ class R2RCollections:
id=str(collection_id),
document_id=str(document_id),
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(
f"Failed to add document {document_id} to collection {collection_id}: {e}"
@@ -254,7 +259,8 @@ class R2RCollections:
offset=offset,
limit=limit,
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(
f"Failed to list documents in collection {collection_id}: {e}"
@@ -278,7 +284,8 @@ class R2RCollections:
id=str(collection_id),
user_id=str(user_id),
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(
f"Failed to add user {user_id} to collection {collection_id}: {e}"
@@ -330,7 +337,8 @@ class R2RCollections:
offset=offset,
limit=limit,
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(f"Failed to list users for collection {collection_id}: {e}") from e
@@ -359,7 +367,8 @@ class R2RCollections:
run_with_orchestration=run_with_orchestration,
settings=cast(dict[str, object], settings or {}),
)
return cast(JsonData, response.results.model_dump())
# response.results is a list, not a model with model_dump()
return cast(JsonData, response.results if isinstance(response.results, list) else [])
except Exception as e:
raise StorageError(
f"Failed to extract entities from collection {collection_id}: {e}"


@@ -9,10 +9,15 @@ from datetime import UTC, datetime
from typing import Self, TypeVar, cast
from uuid import UUID, uuid4
import httpx
from r2r import R2RAsyncClient, R2RException
from r2r import R2RAsyncClient
from typing_extensions import override
# Direct imports for runtime and type checking
# Note: Some type checkers (basedpyright/Pyrefly) may report import issues
# but these work correctly at runtime and with mypy
from httpx import AsyncClient, HTTPStatusError
from r2r import R2RException
from ...core.exceptions import StorageError
from ...core.models import Document, DocumentMetadata, IngestionSource, StorageConfig
from ..base import BaseStorage
@@ -83,7 +88,7 @@ class R2RStorage(BaseStorage):
try:
# Ensure we have an event loop
try:
asyncio.get_running_loop()
_ = asyncio.get_running_loop()
except RuntimeError:
# No event loop running, this should not happen in async context
# but let's be defensive
@@ -93,7 +98,7 @@ class R2RStorage(BaseStorage):
# Test connection using direct HTTP call to v3 API
endpoint = self.endpoint
client = httpx.AsyncClient()
client = AsyncClient()
try:
response = await client.get(f"{endpoint}/v3/collections")
response.raise_for_status()
@@ -107,7 +112,7 @@ class R2RStorage(BaseStorage):
"""Get or create collection by name."""
try:
endpoint = self.endpoint
client = httpx.AsyncClient()
client = AsyncClient()
try:
# List collections and find by name
response = await client.get(f"{endpoint}/v3/collections")
@@ -203,7 +208,7 @@ class R2RStorage(BaseStorage):
for attempt in range(max_retries):
try:
async with httpx.AsyncClient() as http_client:
async with AsyncClient() as http_client:
# Use files parameter but with string values for multipart/form-data
# This matches the cURL -F behavior more closely
metadata = self._build_metadata(document)
@@ -264,7 +269,7 @@ class R2RStorage(BaseStorage):
doc_response = response.json()
break # Success - exit retry loop
except httpx.TimeoutException:
except (OSError, asyncio.TimeoutError):
if attempt < max_retries - 1:
print(
f"Timeout for document {requested_id}, retrying in {retry_delay}s..."
@@ -274,7 +279,7 @@ class R2RStorage(BaseStorage):
continue
else:
raise
except httpx.HTTPStatusError as e:
except HTTPStatusError as e:
if e.response.status_code >= 500 and attempt < max_retries - 1:
print(
f"Server error {e.response.status_code} for document {requested_id}, retrying in {retry_delay}s..."
@@ -470,7 +475,7 @@ class R2RStorage(BaseStorage):
elif description := metadata_map.get("description"):
metadata["description"] = cast(str | None, description)
if tags := metadata_map.get("tags"):
metadata["tags"] = _as_sequence(tags) if isinstance(tags, list) else []
metadata["tags"] = [str(tag) for tag in tags] if isinstance(tags, list) else []
if category := metadata_map.get("category"):
metadata["category"] = str(category)
if section := metadata_map.get("section"):
@@ -502,13 +507,15 @@ class R2RStorage(BaseStorage):
if last_modified := metadata_map.get("last_modified"):
metadata["last_modified"] = _as_datetime(last_modified)
if readability_score := metadata_map.get("readability_score"):
metadata["readability_score"] = (
float(readability_score) if readability_score is not None else None
)
try:
metadata["readability_score"] = float(str(readability_score))
except (ValueError, TypeError):
metadata["readability_score"] = None
if completeness_score := metadata_map.get("completeness_score"):
metadata["completeness_score"] = (
float(completeness_score) if completeness_score is not None else None
)
try:
metadata["completeness_score"] = float(str(completeness_score))
except (ValueError, TypeError):
metadata["completeness_score"] = None
source_value = str(metadata_map.get("ingestion_source", IngestionSource.WEB.value))
try:
@@ -609,7 +616,7 @@ class R2RStorage(BaseStorage):
"""Get document count in collection."""
try:
endpoint = self.endpoint
client = httpx.AsyncClient()
client = AsyncClient()
try:
# Get collections and find the count for the specific collection
response = await client.get(f"{endpoint}/v3/collections")
@@ -677,7 +684,7 @@ class R2RStorage(BaseStorage):
"""List all available collections."""
try:
endpoint = self.endpoint
client = httpx.AsyncClient()
client = AsyncClient()
try:
response = await client.get(f"{endpoint}/v3/collections")
response.raise_for_status()
@@ -777,7 +784,7 @@ class R2RStorage(BaseStorage):
collection_id = await self._ensure_collection(collection_name)
# Use the collections API to list documents in a specific collection
endpoint = self.endpoint
client = httpx.AsyncClient()
client = AsyncClient()
try:
params = {"offset": offset, "limit": limit}
response = await client.get(


@@ -22,7 +22,6 @@ from ..core.models import Document, DocumentMetadata, IngestionSource, StorageCo
from ..utils.vectorizer import Vectorizer
from .base import BaseStorage
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
@@ -591,13 +590,13 @@ class WeaviateStorage(BaseStorage):
except Exception as e:
raise StorageError(f"Failed to list collections: {e}") from e
async def describe_collections(self) -> list[dict[str, str | int | float]]:
async def describe_collections(self) -> list[dict[str, object]]:
"""Return metadata for each Weaviate collection."""
if not self.client:
raise StorageError("Weaviate client not initialized")
try:
collections: list[dict[str, str | int | float]] = []
collections: list[dict[str, object]] = []
for name in self.client.collections.list_all():
collection_obj = self.client.collections.get(name)
if not collection_obj:
@@ -808,7 +807,7 @@ class WeaviateStorage(BaseStorage):
offset: int = 0,
*,
collection_name: str | None = None,
) -> list[dict[str, str | int]]:
) -> list[dict[str, object]]:
"""
List documents in the collection with pagination.
@@ -830,7 +829,7 @@ class WeaviateStorage(BaseStorage):
limit=limit, offset=offset, return_metadata=["creation_time"]
)
documents = []
documents: list[dict[str, object]] = []
for obj in response.objects:
props = self._coerce_properties(
obj.properties,
@@ -849,7 +848,7 @@ class WeaviateStorage(BaseStorage):
else:
word_count = 0
doc_info: dict[str, str | int] = {
doc_info: dict[str, object] = {
"id": str(obj.uuid),
"title": str(props.get("title", "Untitled")),
"source_url": str(props.get("source_url", "")),


@@ -2,7 +2,7 @@
import json
from datetime import UTC, datetime
from typing import TypedDict, cast
from typing import Protocol, TypedDict, cast
import httpx
@@ -10,6 +10,41 @@ from ..core.exceptions import IngestionError
from ..core.models import Document
class HttpResponse(Protocol):
"""Protocol for HTTP response."""
def raise_for_status(self) -> None: ...
def json(self) -> dict[str, object]: ...
class AsyncHttpClient(Protocol):
"""Protocol for async HTTP client."""
async def post(
self,
url: str,
*,
json: dict[str, object] | None = None
) -> HttpResponse: ...
async def aclose(self) -> None: ...
class LlmResponse(TypedDict):
"""Type for LLM API response structure."""
choices: list[dict[str, object]]
class LlmChoice(TypedDict):
"""Type for individual choice in LLM response."""
message: dict[str, object]
class LlmMessage(TypedDict):
"""Type for message in LLM choice."""
content: str
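Because these are structural Protocols, any object with matching methods satisfies AsyncHttpClient without inheriting from it; a hypothetical in-memory stub for tests:

from dataclasses import dataclass

@dataclass
class _StubResponse:
    payload: dict[str, object]

    def raise_for_status(self) -> None:
        return None

    def json(self) -> dict[str, object]:
        return self.payload

class _StubClient:
    async def post(self, url: str, *, json: dict[str, object] | None = None) -> _StubResponse:
        # Canned LLM-style payload; the shape mirrors the choices/message structure parsed below.
        return _StubResponse({"choices": [{"message": {"content": "{}"}}]})

    async def aclose(self) -> None:
        return None

A MetadataTagger instance could then have its client swapped for _StubClient() in a test, with no httpx involvement.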
class DocumentMetadata(TypedDict, total=False):
"""Structured metadata for documents."""
@@ -27,7 +62,7 @@ class MetadataTagger:
endpoint: str
model: str
client: httpx.AsyncClient
client: AsyncHttpClient
def __init__(
self,
@@ -60,7 +95,10 @@ class MetadataTagger:
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
self.client = httpx.AsyncClient(timeout=60.0, headers=headers)
# Create client with proper typing - httpx.AsyncClient implements AsyncHttpClient protocol
AsyncClientClass = getattr(httpx, "AsyncClient")
raw_client = AsyncClientClass(timeout=60.0, headers=headers)
self.client = cast(AsyncHttpClient, raw_client)
async def tag_document(
self, document: Document, custom_instructions: str | None = None
@@ -87,29 +125,52 @@ class MetadataTagger:
)
# Merge with existing metadata - preserve ALL existing fields and add LLM-generated ones
from typing import cast
from ..core.models import DocumentMetadata as CoreDocumentMetadata
# Start with a copy of existing metadata to preserve all fields
# Cast to avoid TypedDict key errors during manipulation
updated_metadata = cast(dict[str, object], dict(document.metadata))
updated_metadata = dict(document.metadata)
# Update/enhance with LLM-generated metadata, preserving existing values when new ones are empty
if metadata.get("title") and not updated_metadata.get("title"):
updated_metadata["title"] = str(metadata["title"]) # type: ignore[typeddict-item]
if metadata.get("summary") and not updated_metadata.get("description"):
updated_metadata["description"] = str(metadata["summary"])
if title_val := metadata.get("title"):
if not updated_metadata.get("title") and isinstance(title_val, str):
updated_metadata["title"] = title_val
if summary_val := metadata.get("summary"):
if not updated_metadata.get("description"):
updated_metadata["description"] = str(summary_val)
# Ensure required fields have values
updated_metadata.setdefault("source_url", "")
updated_metadata.setdefault("timestamp", datetime.now(UTC))
updated_metadata.setdefault("content_type", "text/plain")
updated_metadata.setdefault("word_count", len(document.content.split()))
updated_metadata.setdefault("char_count", len(document.content))
_ = updated_metadata.setdefault("source_url", "")
_ = updated_metadata.setdefault("timestamp", datetime.now(UTC))
_ = updated_metadata.setdefault("content_type", "text/plain")
_ = updated_metadata.setdefault("word_count", len(document.content.split()))
_ = updated_metadata.setdefault("char_count", len(document.content))
# Cast to the expected type since we're preserving all fields from the original metadata
document.metadata = cast(CoreDocumentMetadata, updated_metadata)
# Build a proper DocumentMetadata instance with only valid keys
new_metadata: CoreDocumentMetadata = {
"source_url": str(updated_metadata.get("source_url", "")),
"timestamp": (
lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC)
)(updated_metadata.get("timestamp", datetime.now(UTC))),
"content_type": str(updated_metadata.get("content_type", "text/plain")),
"word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)),
"char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)),
}
# Add optional fields if they exist
if "title" in updated_metadata and updated_metadata["title"]:
new_metadata["title"] = str(updated_metadata["title"])
if "description" in updated_metadata and updated_metadata["description"]:
new_metadata["description"] = str(updated_metadata["description"])
if "tags" in updated_metadata and isinstance(updated_metadata["tags"], list):
tags_list = cast(list[object], updated_metadata["tags"])
new_metadata["tags"] = [str(tag) for tag in tags_list if tag is not None]
if "category" in updated_metadata and updated_metadata["category"]:
new_metadata["category"] = str(updated_metadata["category"])
if "language" in updated_metadata and updated_metadata["language"]:
new_metadata["language"] = str(updated_metadata["language"])
document.metadata = new_metadata
return document
@@ -199,27 +260,27 @@ Return a JSON object with the following structure:
if not isinstance(result_raw, dict):
raise IngestionError("Invalid response format from LLM")
result = cast(dict[str, object], result_raw)
result = cast(LlmResponse, result_raw)
# Extract content from response
choices = result.get("choices", [])
if not choices or not isinstance(choices, list):
raise IngestionError("No response from LLM")
first_choice_raw = cast(object, choices[0])
first_choice_raw = choices[0]
if not isinstance(first_choice_raw, dict):
raise IngestionError("Invalid choice format")
first_choice = cast(dict[str, object], first_choice_raw)
first_choice = cast(LlmChoice, first_choice_raw)
message_raw = first_choice.get("message", {})
if not isinstance(message_raw, dict):
raise IngestionError("Invalid message format")
message = cast(dict[str, object], message_raw)
message = cast(LlmMessage, message_raw)
content_str = str(message.get("content", "{}"))
try:
raw_metadata = json.loads(content_str)
raw_metadata = cast(dict[str, object], json.loads(content_str))
except json.JSONDecodeError as e:
raise IngestionError(f"Failed to parse LLM response: {e}") from e


@@ -51,7 +51,7 @@ class Vectorizer:
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
self.client = httpx.AsyncClient(timeout=60.0, headers=headers) # type: ignore[attr-defined]
self.client: httpx.AsyncClient = httpx.AsyncClient(timeout=60.0, headers=headers)
async def vectorize(self, text: str) -> list[float]:
"""


@@ -846,3 +846,29 @@ metadata.author
2025-09-18 08:29:04 | INFO | httpx | HTTP Request: POST http://prefect.lab/api/flow_runs/3d6f8223-0b7e-43fd-a1f4-c102f3fc8919/set_state "HTTP/1.1 201 Created"
2025-09-18 08:29:04 | INFO | prefect.flow_runs | Finished in state Completed()
2025-09-18 08:29:06 | INFO | httpx | HTTP Request: POST http://prefect.lab/api/logs/ "HTTP/1.1 201 Created"
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | Initializing collection management TUI
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | Scanning available storage backends
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/.well-known/openid-configuration "HTTP/1.1 404 Not Found"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/meta "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET https://pypi.org/pypi/weaviate-client/json "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/schema "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://chat.lab/api/v1/knowledge/list "HTTP/1.1 401 Unauthorized"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | weaviate connected successfully
2025-09-18 22:21:09 | WARNING | ingest_pipeline.cli.tui.utils.runners | open_webui connection failed
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | r2r connected successfully
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | Launching TUI with 2 backend(s): weaviate, r2r
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/schema "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: POST http://weaviate.yo/v1/graphql "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: POST http://weaviate.yo/v1/graphql "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: POST http://weaviate.yo/v1/graphql "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
2025-09-18 23:06:37 | INFO | ingest_pipeline.cli.tui.utils.runners | Shutting down storage connections
2025-09-18 23:06:37 | INFO | ingest_pipeline.cli.tui.utils.runners | All storage connections closed gracefully

prefect.yaml (new file, 154 lines)

@@ -0,0 +1,154 @@
# Prefect deployment configuration for RAG Manager
name: rag-manager
prefect-version: 3.0.0
# Build steps - prepare the environment
build:
- prefect.deployments.steps.run_shell_script:
id: prepare-environment
script: |
echo "Preparing RAG Manager environment..."
# Ensure virtual environment is activated
source .venv/bin/activate || echo "Virtual environment not found"
# Install dependencies
uv sync --frozen
# Register custom blocks
prefect block register --module ingest_pipeline.core.models
# Push steps - handle deployment artifacts
push: null
# Work pool configuration
work_pool:
name: "{{ prefect.variables.work_pool_name | default('default') }}"
work_queue_name: "{{ prefect.variables.work_queue_name | default('default') }}"
job_variables:
env:
# Prefect configuration
PREFECT_API_URL: "{{ $PREFECT_API_URL }}"
PREFECT_API_KEY: "{{ $PREFECT_API_KEY }}"
# Application configuration from variables
DEFAULT_BATCH_SIZE: "{{ prefect.variables.default_batch_size | default('50') }}"
MAX_CRAWL_DEPTH: "{{ prefect.variables.max_crawl_depth | default('5') }}"
MAX_CRAWL_PAGES: "{{ prefect.variables.max_crawl_pages | default('100') }}"
MAX_CONCURRENT_TASKS: "{{ prefect.variables.max_concurrent_tasks | default('5') }}"
# Service endpoints from variables
LLM_ENDPOINT: "{{ prefect.variables.llm_endpoint | default('http://llm.lab') }}"
WEAVIATE_ENDPOINT: "{{ prefect.variables.weaviate_endpoint | default('http://weaviate.yo') }}"
OPENWEBUI_ENDPOINT: "{{ prefect.variables.openwebui_endpoint | default('http://chat.lab') }}"
FIRECRAWL_ENDPOINT: "{{ prefect.variables.firecrawl_endpoint | default('http://crawl.lab:30002') }}"
# Deployment definitions
deployments:
# Web ingestion deployment
- name: web-ingestion
version: "{{ prefect.variables.deployment_version | default('1.0.0') }}"
tags:
- "{{ prefect.variables.environment | default('development') }}"
- web
- ingestion
description: "Automated web content ingestion using Firecrawl"
entrypoint: ingest_pipeline/flows/ingestion.py:create_ingestion_flow
parameters:
source_type: web
storage_backend: "{{ prefect.variables.default_storage_backend | default('weaviate') }}"
validate_first: true
storage_block_name: "{{ prefect.variables.default_storage_block }}"
ingestor_config_block_name: "{{ prefect.variables.default_firecrawl_block }}"
schedule:
interval: "{{ prefect.variables.default_schedule_interval | default(3600) }}"
timezone: UTC
work_pool:
name: "{{ prefect.variables.web_work_pool | default('default') }}"
job_variables:
env:
INGESTION_TYPE: "web"
MAX_PAGES: "{{ prefect.variables.max_crawl_pages | default('100') }}"
# Repository ingestion deployment
- name: repository-ingestion
version: "{{ prefect.variables.deployment_version | default('1.0.0') }}"
tags:
- "{{ prefect.variables.environment | default('development') }}"
- repository
- ingestion
description: "Automated repository content ingestion using Repomix"
entrypoint: ingest_pipeline/flows/ingestion.py:create_ingestion_flow
parameters:
source_type: repository
storage_backend: "{{ prefect.variables.default_storage_backend | default('weaviate') }}"
validate_first: true
storage_block_name: "{{ prefect.variables.default_storage_block }}"
ingestor_config_block_name: "{{ prefect.variables.default_repomix_block }}"
schedule: null # Manual trigger only
work_pool:
name: "{{ prefect.variables.repo_work_pool | default('default') }}"
job_variables:
env:
INGESTION_TYPE: "repository"
# R2R specialized deployment
- name: firecrawl-to-r2r
version: "{{ prefect.variables.deployment_version | default('1.0.0') }}"
tags:
- "{{ prefect.variables.environment | default('development') }}"
- firecrawl
- r2r
- specialized
description: "Optimized Firecrawl to R2R ingestion flow"
entrypoint: ingest_pipeline/flows/ingestion.py:firecrawl_to_r2r_flow
parameters:
storage_block_name: "{{ prefect.variables.r2r_storage_block }}"
schedule:
cron: "{{ prefect.variables.r2r_cron_schedule | default('0 2 * * *') }}"
timezone: UTC
work_pool:
name: "{{ prefect.variables.r2r_work_pool | default('default') }}"
job_variables:
env:
INGESTION_TYPE: "r2r"
SPECIALIZED_FLOW: "true"
# Automation definitions (commented out - would be created via API)
# automations:
# - name: Cancel Long Running Flows
# description: Cancels flows running longer than 30 minutes
# trigger:
# type: event
# posture: Proactive
# expect: [prefect.flow-run.Running]
# match_related:
# prefect.resource.role: flow
# prefect.resource.name: ingestion_pipeline
# threshold: 1
# within: 1800
# actions:
# - type: cancel-flow-run
# source: inferred
# enabled: true
# Variables that should be set for optimal operation
# Use: prefect variable set <name> <value>
# Required variables:
# - default_storage_backend: weaviate|open_webui|r2r
# - llm_endpoint: URL for LLM service
# - weaviate_endpoint: URL for Weaviate instance
# - openwebui_endpoint: URL for OpenWebUI instance
# - firecrawl_endpoint: URL for Firecrawl service
#
# Optional variables with defaults:
# - default_batch_size: 50
# - max_crawl_depth: 5
# - max_crawl_pages: 100
# - max_concurrent_tasks: 5
# - default_schedule_interval: 3600 (1 hour)
# - deployment_version: 1.0.0
# - environment: development
# Block types that should be registered:
# - storage-config: Storage backend configurations
# - firecrawl-config: Firecrawl scraping parameters
# - repomix-config: Repository processing settings
# - r2r-config: R2R-specific chunking and graph settings
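# The variables above can also be seeded from Python rather than the CLI; a hedged
# sketch (values are illustrative, and the overwrite keyword is assumed to be
# supported by the installed Prefect version):
#
#   from prefect.variables import Variable
#
#   Variable.set("default_storage_backend", "weaviate", overwrite=True)
#   Variable.set("default_batch_size", "50", overwrite=True)
#   Variable.set("default_schedule_interval", "3600", overwrite=True)
#   Variable.set("llm_endpoint", "http://llm.lab", overwrite=True)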


@@ -65,7 +65,7 @@ ignore = [
"ingest_pipeline/cli/main.py" = ["B008"] # Typer uses function calls in defaults
[tool.mypy]
python_version = "3.11"
python_version = "3.12"
strict = true
warn_return_any = true
warn_unused_configs = true
@@ -73,6 +73,18 @@ ignore_missing_imports = true
# Allow AsyncGenerator types in overrides
disable_error_code = ["override"]
[tool.basedpyright]
include = ["ingest_pipeline"]
exclude = ["**/__pycache__", "**/.pytest_cache", "**/node_modules", ".venv", "build", "dist"]
pythonVersion = "3.12"
venvPath = "."
venv = ".venv"
typeCheckingMode = "standard"
useLibraryCodeForTypes = true
reportMissingTypeStubs = "none"
reportMissingModuleSource = "none"
reportAttributeAccessIssue = "warning"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

pyrightconfig.json (new file, 45 lines)

@@ -0,0 +1,45 @@
{
"include": [
"ingest_pipeline"
],
"exclude": [
"**/__pycache__",
"**/.pytest_cache",
"**/node_modules",
".venv",
"build",
"dist"
],
"pythonVersion": "3.12",
"venvPath": ".",
"venv": ".venv",
"typeCheckingMode": "standard",
"useLibraryCodeForTypes": true,
"stubPath": "./typings",
"reportCallInDefaultInitializer": "none",
"reportUnknownVariableType": "warning",
"reportUnknownMemberType": "warning",
"reportUnknownArgumentType": "warning",
"reportUnknownLambdaType": "warning",
"reportUnknownParameterType": "warning",
"reportMissingParameterType": "warning",
"reportUnannotatedClassAttribute": "warning",
"reportMissingTypeStubs": "none",
"reportMissingModuleSource": "none",
"reportImportCycles": "none",
"reportUnusedImport": "warning",
"reportUnusedClass": "warning",
"reportUnusedFunction": "warning",
"reportUnusedVariable": "warning",
"reportDuplicateImport": "warning",
"reportWildcardImportFromLibrary": "warning",
"reportAny": "warning",
"reportUnusedCallResult": "none",
"reportUnnecessaryIsInstance": "none",
"reportImplicitOverride": "none",
"reportDeprecated": "warning",
"reportIncompatibleMethodOverride": "error",
"reportIncompatibleVariableOverride": "error",
"reportInconsistentConstructor": "none",
"analyzeUnannotatedFunctions": true
}

File diff suppressed because it is too large Load Diff

tests/__init__.py (new, empty file)

Three binary files not shown.

tests/conftest.py (new file, 545 lines)

@@ -0,0 +1,545 @@
from __future__ import annotations
from collections import deque
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Protocol, TypedDict, cast
from urllib.parse import urlparse
import httpx
import pytest
from pydantic import HttpUrl
from ingest_pipeline.core.models import (
Document,
DocumentMetadata,
IngestionSource,
StorageBackend,
StorageConfig,
VectorConfig,
)
from typings import EmbeddingData, EmbeddingResponse
from .openapi_mocks import (
FirecrawlMockService,
OpenAPISpec,
OpenWebUIMockService,
R2RMockService,
)
# Type aliases for mock responses
MockResponseData = dict[str, object]
PROJECT_ROOT = Path(__file__).resolve().parent.parent
class RequestRecord(TypedDict):
"""Captured HTTP request payload."""
method: str
url: str
json_body: dict[str, object] | None
params: dict[str, object] | None
files: object | None
@dataclass(slots=True)
class StubbedResponse:
"""In-memory HTTP response for httpx client mocking."""
payload: object
status_code: int = 200
def json(self) -> object:
return self.payload
def raise_for_status(self) -> None:
if self.status_code < 400:
return
# Create a minimal exception for testing - we don't need full httpx objects for tests
class TestHTTPError(Exception):
request: object | None
response: object | None
def __init__(self, message: str) -> None:
super().__init__(message)
self.request = None
self.response = None
raise TestHTTPError(f"Stubbed HTTP error with status {self.status_code}")
@dataclass(slots=True)
class AsyncClientStub:
"""Replacement for httpx.AsyncClient used in tests."""
responses: deque[StubbedResponse]
requests: list[RequestRecord]
owner: HttpxStub
timeout: object | None = None
headers: dict[str, str] = field(default_factory=dict)
base_url: str = ""
def __init__(
self,
*,
responses: deque[StubbedResponse],
requests: list[RequestRecord],
timeout: object | None = None,
headers: dict[str, str] | None = None,
base_url: str | None = None,
owner: HttpxStub,
**_: object,
) -> None:
self.responses = responses
self.requests = requests
self.timeout = timeout
self.headers = dict(headers or {})
self.base_url = str(base_url or "")
self.owner = owner
def _normalize_url(self, url: str) -> str:
if url.startswith("http://") or url.startswith("https://"):
return url
if not self.base_url:
return url
prefix = self.base_url.rstrip("/")
suffix = url.lstrip("/")
return f"{prefix}/{suffix}" if suffix else prefix
def _consume(
self,
*,
method: str,
url: str,
json: dict[str, object] | None,
params: dict[str, object] | None,
files: object | None,
) -> StubbedResponse:
if self.responses:
return self.responses.popleft()
dispatched = self.owner.dispatch(
method=method,
url=url,
json=json,
params=params,
files=files,
)
if dispatched is not None:
return dispatched
raise AssertionError(f"No stubbed response for {method} {url}")
def _record(
self,
*,
method: str,
url: str,
json: dict[str, object] | None,
params: dict[str, object] | None,
files: object | None,
) -> str:
normalized = self._normalize_url(url)
record: RequestRecord = {
"method": method,
"url": normalized,
"json_body": json,
"params": params,
"files": files,
}
self.requests.append(record)
return normalized
async def post(
self,
url: str,
*,
json: dict[str, object] | None = None,
files: object | None = None,
params: dict[str, object] | None = None,
) -> StubbedResponse:
normalized = self._record(
method="POST",
url=url,
json=json,
params=params,
files=files,
)
return self._consume(
method="POST",
url=normalized,
json=json,
params=params,
files=files,
)
async def get(
self,
url: str,
*,
params: dict[str, object] | None = None,
) -> StubbedResponse:
normalized = self._record(
method="GET",
url=url,
json=None,
params=params,
files=None,
)
return self._consume(
method="GET",
url=normalized,
json=None,
params=params,
files=None,
)
async def delete(
self,
url: str,
*,
params: dict[str, object] | None = None,
json: dict[str, object] | None = None,
) -> StubbedResponse:
normalized = self._record(
method="DELETE",
url=url,
json=json,
params=params,
files=None,
)
return self._consume(
method="DELETE",
url=normalized,
json=json,
params=params,
files=None,
)
async def aclose(self) -> None:
return None
async def __aenter__(self) -> AsyncClientStub:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object | None,
) -> None:
return None
@dataclass(slots=True)
class HttpxStub:
"""Helper exposing queued responses and captured requests."""
responses: deque[StubbedResponse] = field(default_factory=deque)
requests: list[RequestRecord] = field(default_factory=list)
clients: list[AsyncClientStub] = field(default_factory=list)
services: dict[str, MockService] = field(default_factory=dict)
def queue_json(self, payload: object, status_code: int = 200) -> None:
self.responses.append(StubbedResponse(payload=payload, status_code=status_code))
def register_service(self, base_url: str, service: MockService) -> None:
normalized = base_url.rstrip("/") or base_url
self.services[normalized] = service
def dispatch(
self,
*,
method: str,
url: str,
json: Mapping[str, object] | None,
params: Mapping[str, object] | None,
files: object | None,
) -> StubbedResponse | None:
parsed = urlparse(url)
base = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else ""
base = base.rstrip("/") if base else base
path = parsed.path or "/"
service = self.services.get(base)
if service is None:
return None
status, payload = service.handle(
method=method,
path=path,
json=json,
params=params,
files=files,
)
return StubbedResponse(payload=payload, status_code=status)
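# Illustrative sketch of how the stub is driven in a test (names such as
# ``stub`` and ``service`` are placeholders; the ``httpx_stub`` fixture below
# performs the monkeypatching). Queued responses are consumed before any
# registered service is consulted:
#
#     stub.queue_json({"status": "ok"})                      # served to the next request
#     stub.register_service("http://storage.local", service) # used once the queue is empty
#     async with httpx.AsyncClient(base_url="http://storage.local") as client:
#         response = await client.get("/api/v1/knowledge/list")
#     assert stub.requests[0]["method"] == "GET"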
class DocumentFactory(Protocol):
"""Callable protocol for building Document instances."""
def __call__(
self,
*,
content: str,
metadata_updates: dict[str, object] | None = None,
) -> Document:
"""Create Document for testing."""
...
class VectorConfigFactory(Protocol):
"""Callable protocol for building VectorConfig instances."""
def __call__(self, *, model: str, dimension: int, endpoint: str) -> VectorConfig:
"""Create VectorConfig for testing."""
...
class EmbeddingPayloadFactory(Protocol):
"""Callable protocol for synthesizing embedding responses."""
def __call__(self, *, dimension: int) -> EmbeddingResponse:
"""Create embedding payload with the requested dimension."""
...
class MockService(Protocol):
"""Protocol for mock services that can handle HTTP requests."""
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, object] | None,
params: Mapping[str, object] | None,
files: object | None,
) -> tuple[int, object]:
"""Handle HTTP request and return status code and response payload."""
...
@pytest.fixture(scope="session")
def base_metadata() -> DocumentMetadata:
"""Provide reusable base metadata for document fixtures."""
required = {
"source_url": "https://example.com/article",
"timestamp": datetime.now(UTC),
"content_type": "text/plain",
"word_count": 100,
"char_count": 500,
}
return cast(DocumentMetadata, cast(object, required))
@pytest.fixture(scope="module")
def document_factory(base_metadata: DocumentMetadata) -> DocumentFactory:
"""Build Document models with deterministic defaults."""
def _factory(
*,
content: str,
metadata_updates: dict[str, object] | None = None,
) -> Document:
metadata_dict = dict(base_metadata)
if metadata_updates:
metadata_dict.update(metadata_updates)
return Document(
content=content,
metadata=cast(DocumentMetadata, cast(object, metadata_dict)),
source=IngestionSource.WEB,
)
return _factory
@pytest.fixture(scope="session")
def vector_config_factory() -> VectorConfigFactory:
"""Construct VectorConfig instances for tests."""
def _factory(*, model: str, dimension: int, endpoint: str) -> VectorConfig:
return VectorConfig(model=model, dimension=dimension, embedding_endpoint=HttpUrl(endpoint))
return _factory
@pytest.fixture(scope="session")
def storage_config() -> StorageConfig:
"""Provide canonical storage configuration for adapters."""
return StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=HttpUrl("http://storage.local"),
collection_name="documents",
)
@pytest.fixture(scope="session")
def r2r_storage_config() -> StorageConfig:
"""Provide storage configuration for R2R adapter tests."""
return StorageConfig(
backend=StorageBackend.R2R,
endpoint=HttpUrl("http://r2r.local"),
collection_name="documents",
)
@pytest.fixture(scope="session")
def openwebui_spec() -> OpenAPISpec:
"""Load OpenWebUI OpenAPI specification."""
return OpenAPISpec.from_file(PROJECT_ROOT / "chat.json")
@pytest.fixture(scope="session")
def r2r_spec() -> OpenAPISpec:
"""Load R2R OpenAPI specification."""
return OpenAPISpec.from_file(PROJECT_ROOT / "r2r.json")
@pytest.fixture(scope="session")
def firecrawl_spec() -> OpenAPISpec:
"""Load Firecrawl OpenAPI specification."""
return OpenAPISpec.from_file(PROJECT_ROOT / "firecrawl.json")
@pytest.fixture(scope="function")
def httpx_stub(monkeypatch: pytest.MonkeyPatch) -> HttpxStub:
"""Replace httpx.AsyncClient with an in-memory stub."""
stub = HttpxStub()
def _client_factory(**kwargs: object) -> AsyncClientStub:
client = AsyncClientStub(
responses=stub.responses,
requests=stub.requests,
timeout=kwargs.get("timeout"),
headers=cast(dict[str, str] | None, kwargs.get("headers")),
base_url=cast(str | None, kwargs.get("base_url")),
owner=stub,
)
stub.clients.append(client)
return client
monkeypatch.setattr(httpx, "AsyncClient", _client_factory)
monkeypatch.delenv("LLM_API_KEY", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
return stub
@pytest.fixture(scope="function")
def openwebui_service(
httpx_stub: HttpxStub,
openwebui_spec: OpenAPISpec,
storage_config: StorageConfig,
) -> OpenWebUIMockService:
"""Stateful mock for OpenWebUI APIs."""
service = OpenWebUIMockService(
base_url=str(storage_config.endpoint),
spec=openwebui_spec,
)
httpx_stub.register_service(service.base_url, service)
return service
@pytest.fixture(scope="function")
def r2r_service(
httpx_stub: HttpxStub,
r2r_spec: OpenAPISpec,
r2r_storage_config: StorageConfig,
) -> R2RMockService:
"""Stateful mock for R2R APIs."""
service = R2RMockService(
base_url=str(r2r_storage_config.endpoint),
spec=r2r_spec,
)
httpx_stub.register_service(service.base_url, service)
return service
@pytest.fixture(scope="function")
def firecrawl_service(
httpx_stub: HttpxStub,
firecrawl_spec: OpenAPISpec,
) -> FirecrawlMockService:
"""Stateful mock for Firecrawl APIs."""
service = FirecrawlMockService(
base_url="http://crawl.lab:30002",
spec=firecrawl_spec,
)
httpx_stub.register_service(service.base_url, service)
return service
@pytest.fixture(scope="function")
def firecrawl_client_stub(
monkeypatch: pytest.MonkeyPatch,
firecrawl_service: FirecrawlMockService,
) -> object:
"""Patch AsyncFirecrawl to use the mock service."""
class AsyncFirecrawlStub:
_service: FirecrawlMockService
def __init__(self, *args: object, **kwargs: object) -> None:
self._service = firecrawl_service
async def map(self, url: str, limit: int | None = None, **_: object) -> SimpleNamespace:
payload = cast(MockResponseData, self._service.map_response(url, limit))
links_data = cast(list[MockResponseData], payload.get("links", []))
links = [SimpleNamespace(url=cast(str, item.get("url", ""))) for item in links_data]
return SimpleNamespace(success=payload.get("success", True), links=links)
async def scrape(self, url: str, formats: list[str] | None = None, **_: object) -> SimpleNamespace:
payload = cast(MockResponseData, self._service.scrape_response(url, formats))
data = cast(MockResponseData, payload.get("data", {}))
metadata_payload = cast(MockResponseData, data.get("metadata", {}))
metadata_obj = SimpleNamespace(**metadata_payload)
return SimpleNamespace(
markdown=data.get("markdown"),
html=data.get("html"),
rawHtml=data.get("rawHtml"),
links=data.get("links", []),
metadata=metadata_obj,
)
async def close(self) -> None:
return None
async def __aenter__(self) -> AsyncFirecrawlStub:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object | None,
) -> None:
return None
monkeypatch.setattr(
"ingest_pipeline.ingestors.firecrawl.AsyncFirecrawl",
AsyncFirecrawlStub,
)
return AsyncFirecrawlStub
@pytest.fixture(scope="function")
def embedding_payload_factory() -> EmbeddingPayloadFactory:
"""Provide embedding payloads tailored to the requested dimension."""
def _factory(*, dimension: int) -> EmbeddingResponse:
vector = [float(index) for index in range(dimension)]
data_entry: EmbeddingData = {"embedding": vector}
return {"data": [data_entry]}
return _factory
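# For reference (illustrative): _factory(dimension=3) yields
# {"data": [{"embedding": [0.0, 1.0, 2.0]}]}, mirroring an OpenAI-style
# embeddings response body.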

807
tests/openapi_mocks.py Normal file
View File

@@ -0,0 +1,807 @@
from __future__ import annotations
import copy
import json as json_module
import time
from collections.abc import Mapping
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
class OpenAPISpec:
"""Utility for generating example payloads from an OpenAPI document."""
def __init__(self, raw: Mapping[str, Any]):
self._raw = raw
self._paths = raw.get("paths", {})
self._components = raw.get("components", {})
@classmethod
def from_file(cls, path: Path) -> OpenAPISpec:
with path.open("r", encoding="utf-8") as handle:
raw = json_module.load(handle)
return cls(raw)
def resolve_ref(self, ref: str) -> Mapping[str, Any]:
if not ref.startswith("#/"):
raise ValueError(f"Unsupported $ref format: {ref}")
parts = ref.lstrip("#/").split("/")
node: Any = self._raw
for part in parts:
if not isinstance(node, Mapping) or part not in node:
raise KeyError(f"Unable to resolve reference: {ref}")
node = node[part]
if not isinstance(node, Mapping):
raise TypeError(f"Reference {ref} did not resolve to an object")
return node
def generate(self, schema: Mapping[str, Any] | None, name: str = "value") -> Any:
if schema is None:
return None
if "$ref" in schema:
resolved = self.resolve_ref(schema["$ref"])
return self.generate(resolved, name=name)
if "example" in schema:
return copy.deepcopy(schema["example"])
if "default" in schema and schema["default"] is not None:
return copy.deepcopy(schema["default"])
if "enum" in schema:
enum_values = schema["enum"]
if enum_values:
return copy.deepcopy(enum_values[0])
if "anyOf" in schema:
for option in schema["anyOf"]:
candidate = self.generate(option, name=name)
if candidate is not None:
return candidate
if "oneOf" in schema:
return self.generate(schema["oneOf"][0], name=name)
if "allOf" in schema:
result: dict[str, Any] = {}
for option in schema["allOf"]:
fragment = self.generate(option, name=name)
if isinstance(fragment, Mapping):
result.update(fragment)
return result
type_name = schema.get("type")
if type_name == "object":
properties = schema.get("properties", {})
required = schema.get("required", [])
result: dict[str, Any] = {}
keys = list(properties.keys()) or list(required)
for key in keys:
prop_schema = properties.get(key, {})
result[key] = self.generate(prop_schema, name=key)
additional = schema.get("additionalProperties")
if additional and isinstance(additional, Mapping) and not properties:
result["key"] = self.generate(additional, name="key")
return result
if type_name == "array":
item_schema = schema.get("items", {})
item = self.generate(item_schema, name=name)
return [] if item is None else [item]
if type_name == "string":
format_hint = schema.get("format")
if format_hint == "date-time":
return "2024-01-01T00:00:00+00:00"
if format_hint == "date":
return "2024-01-01"
if format_hint == "uuid":
return "00000000-0000-4000-8000-000000000000"
if format_hint == "uri":
return "https://example.com"
return f"{name}-value"
if type_name == "integer":
minimum = schema.get("minimum")
if minimum is not None:
return int(minimum)
return 1
if type_name == "number":
return 1.0
if type_name == "boolean":
return True
if type_name == "null":
return None
return {}
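# Example of the synthesis rules above (illustrative): an object schema such as
# {"type": "object", "properties": {"name": {"type": "string"}, "count": {"type": "integer"}}}
# generates {"name": "name-value", "count": 1}, since plain strings fall back to
# "<property>-value" and integers default to 1 when no minimum is declared.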
def _split_path(self, path: str) -> list[str]:
if path in {"", "/"}:
return []
return [segment for segment in path.strip("/").split("/") if segment]
def find_operation(self, method: str, path: str) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
method = method.lower()
normalized = "/" if path in {"", "/"} else "/" + path.strip("/")
actual_segments = self._split_path(normalized)
for template, operations in self._paths.items():
operation = operations.get(method)
if operation is None:
continue
template_path = template if template else "/"
template_segments = self._split_path(template_path)
if len(template_segments) != len(actual_segments):
continue
params: dict[str, str] = {}
matched = True
for template_part, actual_part in zip(template_segments, actual_segments, strict=False):
if template_part.startswith("{") and template_part.endswith("}"):
params[template_part[1:-1]] = actual_part
elif template_part != actual_part:
matched = False
break
if matched:
return operation, params
return None, {}
def build_response(
self,
operation: Mapping[str, Any],
status: str | None = None,
) -> tuple[int, Any]:
responses = operation.get("responses", {})
target_status = status or (
"200" if "200" in responses else next(iter(responses), "200")
)
status_code = 200
try:
status_code = int(target_status)
except ValueError:
pass
response_entry = responses.get(target_status, {})
content = response_entry.get("content", {})
media = None
for candidate in ("application/json", "application/problem+json", "text/plain"):
if candidate in content:
media = content[candidate]
break
schema = media.get("schema") if media else None
payload = self.generate(schema, name="response")
return status_code, payload
def generate_from_ref(
self,
ref: str,
overrides: Mapping[str, Any] | None = None,
) -> Any:
base = self.generate({"$ref": ref})
if overrides is None or not isinstance(base, Mapping):
return copy.deepcopy(base)
merged = copy.deepcopy(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
merged[key] = self.generate_from_mapping(
merged[key],
value,
)
else:
merged[key] = copy.deepcopy(value)
return merged
def generate_from_mapping(
self,
base: Mapping[str, Any],
overrides: Mapping[str, Any],
) -> Mapping[str, Any]:
result = copy.deepcopy(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(result.get(key), Mapping):
result[key] = self.generate_from_mapping(result[key], value)
else:
result[key] = copy.deepcopy(value)
return result
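# Minimal usage sketch (the file name mirrors the spec loaded by the test
# fixtures; the schema ref is one used by the mocks below):
#
#     spec = OpenAPISpec.from_file(Path("chat.json"))
#     entry = spec.generate_from_ref(
#         "#/components/schemas/KnowledgeUserResponse",
#         overrides={"name": "documents"},
#     )
#     assert entry["name"] == "documents"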
class OpenAPIMockService:
"""Base class for stateful mock services backed by an OpenAPI spec."""
def __init__(self, base_url: str, spec: OpenAPISpec):
self.base_url = base_url.rstrip("/")
self.spec = spec
@staticmethod
def _normalize_path(path: str) -> str:
if not path or path == "/":
return "/"
return "/" + path.strip("/")
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
operation, _ = self.spec.find_operation(method, path)
if operation is None:
return 404, {"detail": f"Unhandled {method.upper()} {path}"}
return self.spec.build_response(operation)
class OpenWebUIMockService(OpenAPIMockService):
"""Stateful mock capturing OpenWebUI knowledge and file operations."""
def __init__(self, base_url: str, spec: OpenAPISpec):
super().__init__(base_url, spec)
self._knowledge: dict[str, dict[str, Any]] = {}
self._files: dict[str, dict[str, Any]] = {}
self._knowledge_counter = 1
self._file_counter = 1
self._default_user = "user-1"
@staticmethod
def _timestamp() -> int:
return int(time.time())
def ensure_knowledge(
self,
*,
name: str,
description: str = "",
knowledge_id: str | None = None,
) -> dict[str, Any]:
identifier = knowledge_id or f"kb-{self._knowledge_counter}"
self._knowledge_counter += 1
entry = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
ts = self._timestamp()
entry.update(
{
"id": identifier,
"user_id": self._default_user,
"name": name,
"description": description or entry.get("description") or "",
"created_at": ts,
"updated_at": ts,
"data": entry.get("data") or {},
"meta": entry.get("meta") or {},
"access_control": entry.get("access_control") or {},
"files": [],
}
)
self._knowledge[identifier] = entry
return copy.deepcopy(entry)
def create_file(
self,
*,
filename: str,
user_id: str | None = None,
file_id: str | None = None,
) -> dict[str, Any]:
identifier = file_id or f"file-{self._file_counter}"
self._file_counter += 1
entry = self.spec.generate_from_ref("#/components/schemas/FileModelResponse")
ts = self._timestamp()
entry.update(
{
"id": identifier,
"user_id": user_id or self._default_user,
"filename": filename,
"meta": entry.get("meta") or {},
"created_at": ts,
"updated_at": ts,
}
)
self._files[identifier] = entry
return copy.deepcopy(entry)
def _build_file_metadata(self, file_id: str) -> dict[str, Any]:
metadata = self.spec.generate_from_ref("#/components/schemas/FileMetadataResponse")
source = self._files.get(file_id)
ts = self._timestamp()
metadata.update(
{
"id": file_id,
"meta": (source or {}).get("meta", {}),
"created_at": ts,
"updated_at": ts,
}
)
return metadata
def get_knowledge(self, knowledge_id: str) -> dict[str, Any] | None:
entry = self._knowledge.get(knowledge_id)
return copy.deepcopy(entry) if entry is not None else None
def find_knowledge_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
for identifier, entry in self._knowledge.items():
if entry.get("name") == name:
return identifier, copy.deepcopy(entry)
return None
def attach_existing_file(self, knowledge_id: str, file_id: str) -> None:
if knowledge_id not in self._knowledge:
raise KeyError(f"Knowledge {knowledge_id} not found")
knowledge = self._knowledge[knowledge_id]
metadata = self._build_file_metadata(file_id)
knowledge.setdefault("files", [])
knowledge["files"] = [item for item in knowledge["files"] if item.get("id") != file_id]
knowledge["files"].append(metadata)
knowledge["updated_at"] = self._timestamp()
def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
payload.update({
"id": knowledge["id"],
"user_id": knowledge["user_id"],
"name": knowledge["name"],
"description": knowledge.get("description", ""),
"data": copy.deepcopy(knowledge.get("data", {})),
"meta": copy.deepcopy(knowledge.get("meta", {})),
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
"created_at": knowledge.get("created_at", self._timestamp()),
"updated_at": knowledge.get("updated_at", self._timestamp()),
"files": copy.deepcopy(knowledge.get("files", [])),
})
return payload
def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeFilesResponse")
payload.update(
{
"id": knowledge["id"],
"user_id": knowledge["user_id"],
"name": knowledge["name"],
"description": knowledge.get("description", ""),
"data": copy.deepcopy(knowledge.get("data", {})),
"meta": copy.deepcopy(knowledge.get("meta", {})),
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
"created_at": knowledge.get("created_at", self._timestamp()),
"updated_at": knowledge.get("updated_at", self._timestamp()),
"files": copy.deepcopy(knowledge.get("files", [])),
}
)
return payload
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
normalized = self._normalize_path(path)
segments = [segment for segment in normalized.strip("/").split("/") if segment]
method_upper = method.upper()
segment_count = len(segments)
# Handle file endpoints before the knowledge-only guard so they stay reachable.
if method_upper == "POST" and segments == ["api", "v1", "files"]:
filename = "uploaded.txt"
if files and isinstance(files, Mapping):
file_entry = files.get("file")
if isinstance(file_entry, tuple) and len(file_entry) >= 1:
filename = str(file_entry[0]) or filename
entry = self.create_file(filename=filename)
return 200, entry
if method_upper == "DELETE" and segments[:3] == ["api", "v1", "files"] and segment_count == 4:
file_id = segments[3]
self._files.pop(file_id, None)
for knowledge in self._knowledge.values():
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
return 200, {"deleted": True}
if segments[:3] != ["api", "v1", "knowledge"]:
return super().handle(method=method, path=path, json=json, params=params, files=files)
if method_upper == "GET" and segment_count == 4 and segments[3] == "list":
body = [self._knowledge_user_response(entry) for entry in self._knowledge.values()]
return 200, body
if method_upper == "GET" and segment_count == 3:
body = [self._knowledge_user_response(entry) for entry in self._knowledge.values()]
return 200, body
if method_upper == "POST" and segment_count == 4 and segments[3] == "create":
payload = json or {}
entry = self.ensure_knowledge(
name=str(payload.get("name", "knowledge")),
description=str(payload.get("description", "")),
)
return 200, entry
if segment_count >= 4:
knowledge_id = segments[3]
knowledge = self._knowledge.get(knowledge_id)
if knowledge is None:
return 404, {"detail": f"Knowledge {knowledge_id} not found"}
if method_upper == "GET" and segment_count == 4:
return 200, self._knowledge_files_response(knowledge)
if method_upper == "POST" and segment_count == 6 and segments[4:] == ["file", "add"]:
payload = json or {}
file_id = str(payload.get("file_id", ""))
if not file_id:
return 422, {"detail": "file_id is required"}
if file_id not in self._files:
self.create_file(filename=f"{file_id}.txt")
metadata = self._build_file_metadata(file_id)
knowledge.setdefault("files", [])
knowledge["files"] = [item for item in knowledge["files"] if item.get("id") != file_id]
knowledge["files"].append(metadata)
knowledge["updated_at"] = self._timestamp()
return 200, self._knowledge_files_response(knowledge)
if method_upper == "POST" and segment_count == 6 and segments[4:] == ["file", "remove"]:
payload = json or {}
file_id = str(payload.get("file_id", ""))
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
knowledge["updated_at"] = self._timestamp()
return 200, self._knowledge_files_response(knowledge)
if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete":
self._knowledge.pop(knowledge_id, None)
return 200, True
if method_upper == "POST" and segments == ["api", "v1", "files"]:
filename = "uploaded.txt"
if files and isinstance(files, Mapping):
file_entry = files.get("file")
if isinstance(file_entry, tuple) and len(file_entry) >= 1:
filename = str(file_entry[0]) or filename
entry = self.create_file(filename=filename)
return 200, entry
if method_upper == "DELETE" and segments[:3] == ["api", "v1", "files"] and len(segments) == 4:
file_id = segments[3]
self._files.pop(file_id, None)
for knowledge in self._knowledge.values():
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
return 200, {"deleted": True}
return super().handle(method=method, path=path, json=json, params=params, files=files)
class R2RMockService(OpenAPIMockService):
"""Stateful mock of core R2R collection endpoints."""
def __init__(self, base_url: str, spec: OpenAPISpec):
super().__init__(base_url, spec)
self._collections: dict[str, dict[str, Any]] = {}
self._documents: dict[str, dict[str, Any]] = {}
@staticmethod
def _iso_now() -> str:
return datetime.now(tz=UTC).isoformat()
def create_collection(
self,
*,
name: str,
description: str = "",
collection_id: str | None = None,
) -> dict[str, Any]:
identifier = collection_id or str(uuid4())
entry = self.spec.generate_from_ref("#/components/schemas/CollectionResponse")
timestamp = self._iso_now()
entry.update(
{
"id": identifier,
"owner_id": entry.get("owner_id") or str(uuid4()),
"name": name,
"description": description or entry.get("description") or "",
"graph_cluster_status": entry.get("graph_cluster_status") or "idle",
"graph_sync_status": entry.get("graph_sync_status") or "synced",
"created_at": timestamp,
"updated_at": timestamp,
"user_count": entry.get("user_count", 1) or 1,
"document_count": entry.get("document_count", 0) or 0,
"documents": entry.get("documents", []),
}
)
self._collections[identifier] = entry
return copy.deepcopy(entry)
def get_collection(self, collection_id: str) -> dict[str, Any] | None:
entry = self._collections.get(collection_id)
return copy.deepcopy(entry) if entry is not None else None
def find_collection_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
for identifier, entry in self._collections.items():
if entry.get("name") == name:
return identifier, copy.deepcopy(entry)
return None
def _set_collection_document_ids(self, collection_id: str, document_id: str, *, add: bool) -> None:
collection = self._collections.get(collection_id)
if collection is None:
collection = self.create_collection(name=f"Collection {collection_id}", collection_id=collection_id)
documents = collection.setdefault("documents", [])
if add:
if document_id not in documents:
documents.append(document_id)
else:
documents[:] = [doc for doc in documents if doc != document_id]
collection["document_count"] = len(documents)
def create_document(
self,
*,
document_id: str,
content: str,
metadata: dict[str, Any],
collection_ids: list[str],
) -> dict[str, Any]:
entry = self.spec.generate_from_ref("#/components/schemas/DocumentResponse")
timestamp = self._iso_now()
entry.update(
{
"id": document_id,
"owner_id": entry.get("owner_id") or str(uuid4()),
"collection_ids": collection_ids,
"metadata": copy.deepcopy(metadata),
"document_type": entry.get("document_type") or "text",
"version": entry.get("version") or "1.0",
"size_in_bytes": metadata.get("char_count") or len(content.encode("utf-8")),
"created_at": entry.get("created_at") or timestamp,
"updated_at": timestamp,
"summary": entry.get("summary"),
"summary_embedding": entry.get("summary_embedding") or None,
"chunks": entry.get("chunks") or [],
"content": content,
}
)
entry["__content"] = content
self._documents[document_id] = entry
for collection_id in collection_ids:
self._set_collection_document_ids(collection_id, document_id, add=True)
return copy.deepcopy(entry)
def get_document(self, document_id: str) -> dict[str, Any] | None:
entry = self._documents.get(document_id)
if entry is None:
return None
return copy.deepcopy(entry)
def delete_document(self, document_id: str) -> bool:
entry = self._documents.pop(document_id, None)
if entry is None:
return False
for collection_id in entry.get("collection_ids", []):
self._set_collection_document_ids(collection_id, document_id, add=False)
return True
def append_document_metadata(self, document_id: str, metadata_list: list[dict[str, Any]]) -> dict[str, Any] | None:
entry = self._documents.get(document_id)
if entry is None:
return None
metadata = entry.setdefault("metadata", {})
for item in metadata_list:
for key, value in item.items():
metadata[key] = value
entry["updated_at"] = self._iso_now()
return copy.deepcopy(entry)
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
normalized = self._normalize_path(path)
method_upper = method.upper()
if normalized == "/v3/collections" and method_upper == "GET":
payload = self.spec.generate_from_ref(
"#/components/schemas/PaginatedR2RResult_list_CollectionResponse__"
)
results = []
for _identifier, entry in self._collections.items():
clone = copy.deepcopy(entry)
clone["document_count"] = len(clone.get("documents", []))
results.append(clone)
payload.update(
{
"results": results,
"total_entries": len(results),
}
)
return 200, payload
if normalized == "/v3/collections" and method_upper == "POST":
body = json or {}
entry = self.create_collection(
name=str(body.get("name", "Collection")),
description=str(body.get("description", "")),
)
payload = self.spec.generate_from_ref("#/components/schemas/R2RResults_CollectionResponse_")
payload.update({"results": entry})
return 200, payload
segments = [segment for segment in normalized.strip("/").split("/") if segment]
if segments[:2] == ["v3", "documents"]:
if method_upper == "POST" and len(segments) == 2:
metadata_raw = {}
content = ""
doc_id = str(uuid4())
collection_ids: list[str] = []
if isinstance(files, Mapping):
if "metadata" in files:
metadata_entry = files["metadata"]
if isinstance(metadata_entry, tuple) and len(metadata_entry) >= 2:
try:
metadata_raw = json_module.loads(metadata_entry[1] or "{}")
except json_module.JSONDecodeError:
metadata_raw = {}
if "raw_text" in files:
raw_text_entry = files["raw_text"]
if isinstance(raw_text_entry, tuple) and len(raw_text_entry) >= 2 and raw_text_entry[1] is not None:
content = str(raw_text_entry[1])
if "id" in files:
id_entry = files["id"]
if isinstance(id_entry, tuple) and len(id_entry) >= 2 and id_entry[1]:
doc_id = str(id_entry[1])
if "collection_ids" in files:
coll_entry = files["collection_ids"]
if isinstance(coll_entry, tuple) and len(coll_entry) >= 2 and coll_entry[1]:
try:
parsed = json_module.loads(coll_entry[1])
if isinstance(parsed, list):
collection_ids = [str(item) for item in parsed]
except json_module.JSONDecodeError:
collection_ids = []
if not collection_ids:
name = metadata_raw.get("collection_name") or metadata_raw.get("collection")
if isinstance(name, str):
located = self.find_collection_by_name(name)
if located:
collection_ids = [located[0]]
if not collection_ids and self._collections:
collection_ids = [next(iter(self._collections))]
document = self.create_document(
document_id=doc_id,
content=content,
metadata=metadata_raw,
collection_ids=collection_ids,
)
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_IngestionResponse_"
)
ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse")
ingestion.update(
{
"message": ingestion.get("message") or "Ingestion task queued successfully.",
"document_id": document["id"],
"task_id": ingestion.get("task_id") or str(uuid4()),
}
)
response_payload["results"] = ingestion
return 202, response_payload
if method_upper == "GET" and len(segments) == 3:
doc_id = segments[2]
document = self.get_document(doc_id)
if document is None:
return 404, {"detail": f"Document {doc_id} not found"}
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_DocumentResponse_"
)
response_payload["results"] = document
return 200, response_payload
if method_upper == "DELETE" and len(segments) == 3:
doc_id = segments[2]
success = self.delete_document(doc_id)
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_GenericBooleanResponse_"
)
results = response_payload.get("results")
if isinstance(results, Mapping):
results = copy.deepcopy(results)
results.update({"success": success})
else:
results = {"success": success}
response_payload["results"] = results
status = 200 if success else 404
return status, response_payload
if method_upper == "PATCH" and len(segments) == 4 and segments[3] == "metadata":
doc_id = segments[2]
metadata_list = []
if isinstance(json, list):
metadata_list = [dict(item) for item in json]
document = self.append_document_metadata(doc_id, metadata_list)
if document is None:
return 404, {"detail": f"Document {doc_id} not found"}
response_payload = self.spec.generate_from_ref(
"#/components/schemas/R2RResults_DocumentResponse_"
)
response_payload["results"] = document
return 200, response_payload
return super().handle(method=method, path=path, json=json, params=params, files=files)
class FirecrawlMockService(OpenAPIMockService):
"""Stateful mock for Firecrawl map and scrape endpoints."""
def __init__(self, base_url: str, spec: OpenAPISpec) -> None:
super().__init__(base_url, spec)
self._maps: dict[str, list[str]] = {}
self._pages: dict[str, dict[str, Any]] = {}
def register_map_result(self, origin: str, links: list[str]) -> None:
self._maps[origin] = list(links)
def register_page(self, url: str, *, markdown: str | None = None, html: str | None = None, metadata: Mapping[str, Any] | None = None, links: list[str] | None = None) -> None:
self._pages[url] = {
"markdown": markdown,
"html": html,
"metadata": dict(metadata or {}),
"links": list(links or []),
}
def _build_map_payload(self, target_url: str, limit: int | None) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/MapResponse")
links = self._maps.get(target_url, [target_url])
if limit is not None:
try:
limit_value = int(limit)
if limit_value >= 0:
links = links[:limit_value]
except (TypeError, ValueError):
pass
payload["success"] = True
payload["links"] = [{"url": link} for link in links]
return payload
def _default_metadata(self, url: str) -> dict[str, Any]:
metadata = self.spec.generate_from_ref("#/components/schemas/ScrapeMetadata")
metadata.update(
{
"url": url,
"sourceURL": url,
"scrapeId": metadata.get("scrapeId") or str(uuid4()),
"statusCode": metadata.get("statusCode", 200) or 200,
"contentType": metadata.get("contentType", "text/html") or "text/html",
"creditsUsed": metadata.get("creditsUsed", 1) or 1,
}
)
return metadata
def _build_scrape_payload(self, target_url: str, formats: list[str] | None) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/ScrapeResponse")
payload["success"] = True
page_info = self._pages.get(target_url, {})
data = payload.get("data", {})
markdown = page_info.get("markdown") or f"# Content for {target_url}\n"
html = page_info.get("html") or f"<h1>Content for {target_url}</h1>"
data.update(
{
"markdown": markdown,
"html": html,
"rawHtml": page_info.get("rawHtml", html),
"links": page_info.get("links", []),
}
)
metadata_payload = self._default_metadata(target_url)
metadata_payload.update(page_info.get("metadata", {}))
data["metadata"] = metadata_payload
payload["data"] = data
return payload
def map_response(self, url: str, limit: int | None = None) -> dict[str, Any]:
return self._build_map_payload(url, limit)
def scrape_response(self, url: str, formats: list[str] | None = None) -> dict[str, Any]:
return self._build_scrape_payload(url, formats)
def handle(
self,
*,
method: str,
path: str,
json: Mapping[str, Any] | None,
params: Mapping[str, Any] | None,
files: Any,
) -> tuple[int, Any]:
normalized = self._normalize_path(path)
method_upper = method.upper()
if normalized == "/v2/map" and method_upper == "POST":
body = json or {}
target_url = str(body.get("url", ""))
limit = body.get("limit")
return 200, self._build_map_payload(target_url, limit)
if normalized == "/v2/scrape" and method_upper == "POST":
body = json or {}
target_url = str(body.get("url", ""))
formats = body.get("formats")
return 200, self._build_scrape_payload(target_url, formats)
return super().handle(method=method, path=path, json=json, params=params, files=files)
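# Quick routing check (illustrative; the base URL matches the firecrawl_service
# fixture and the spec file is the one loaded in conftest):
#
#     service = FirecrawlMockService(
#         base_url="http://crawl.lab:30002",
#         spec=OpenAPISpec.from_file(Path("firecrawl.json")),
#     )
#     status, payload = service.handle(
#         method="POST", path="/v2/map",
#         json={"url": "https://example.com", "limit": 5},
#         params=None, files=None,
#     )
#     assert status == 200 and payload["success"] is True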

0
tests/unit/__init__.py Normal file
View File

Binary file not shown.

View File

@@ -0,0 +1 @@
"""CLI tests package."""

Binary file not shown.

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
import pytest
from ingest_pipeline.cli import main
from ingest_pipeline.core.models import (
IngestionResult,
IngestionSource,
IngestionStatus,
StorageBackend,
)
@pytest.mark.parametrize(
("collection_arg", "expected"),
(
(None, "docs_example_com_web"),
("explicit_collection", "explicit_collection"),
),
)
@pytest.mark.asyncio
async def test_run_ingestion_collection_resolution(monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str) -> None:
recorded: dict[str, object] = {}
async def fake_flow(**kwargs: object) -> IngestionResult:
recorded.update(kwargs)
return IngestionResult(
job_id=uuid4(),
status=IngestionStatus.COMPLETED,
documents_processed=7,
documents_failed=0,
duration_seconds=1.5,
error_messages=[],
)
monkeypatch.setattr(main, "create_ingestion_flow", fake_flow)
result = await main.run_ingestion(
url="https://docs.example.com/guide",
source_type=IngestionSource.WEB,
storage_backend=StorageBackend.WEAVIATE,
collection_name=collection_arg,
validate_first=False,
)
assert recorded["collection_name"] == expected
assert result.documents_processed == 7
@pytest.mark.parametrize(
("cron_value", "serve_now", "expected_type", "serve_expected"),
(
("0 * * * *", False, "cron", False),
(None, True, "interval", True),
),
)
def test_schedule_creates_deployment(
monkeypatch: pytest.MonkeyPatch,
cron_value: str | None,
serve_now: bool,
expected_type: str,
serve_expected: bool,
) -> None:
recorded: dict[str, object] = {}
served = {"called": False}
def fake_create_scheduled_deployment(**kwargs: object) -> str:
recorded.update(kwargs)
return "deployment"
def fake_serve_deployments(deployments: list[str]) -> None:
served["called"] = True
assert deployments == ["deployment"]
monkeypatch.setattr(main, "create_scheduled_deployment", fake_create_scheduled_deployment)
monkeypatch.setattr(main, "serve_deployments", fake_serve_deployments)
main.schedule(
name="nightly",
source_url="https://example.com",
source_type=IngestionSource.WEB,
storage=StorageBackend.WEAVIATE,
cron=cron_value,
interval=30,
serve_now=serve_now,
)
assert recorded["schedule_type"] == expected_type
assert served["called"] is serve_expected
def test_serve_tui_launch(monkeypatch: pytest.MonkeyPatch) -> None:
invoked = {"count": 0}
def fake_dashboard() -> None:
invoked["count"] = invoked["count"] + 1
monkeypatch.setattr("ingest_pipeline.cli.tui.dashboard", fake_dashboard)
main.serve(ui="tui")
assert invoked["count"] == 1
@pytest.mark.asyncio
async def test_run_search_collects_results(monkeypatch: pytest.MonkeyPatch) -> None:
messages: list[object] = []
class DummyConsole:
def print(self, message: object) -> None:
messages.append(message)
class WeaviateStub:
def __init__(self, config: object) -> None:
self.config = config
self.initialized = False
async def initialize(self) -> None:
self.initialized = True
async def search(
self,
query: str,
limit: int = 10,
threshold: float = 0.7,
*,
collection_name: str | None = None,
) -> object:
yield SimpleNamespace(title="Title", content="Body text", score=0.91)
dummy_settings = SimpleNamespace(
weaviate_endpoint="http://weaviate.local",
weaviate_api_key="token",
openwebui_endpoint="http://chat.local",
openwebui_api_key=None,
)
monkeypatch.setattr(main, "console", DummyConsole())
monkeypatch.setattr(main, "get_settings", lambda: dummy_settings)
monkeypatch.setattr("ingest_pipeline.storage.weaviate.WeaviateStorage", WeaviateStub)
await main.run_search("query", collection=None, backend="weaviate", limit=1)
assert messages[-1] == "\n✅ [green]Found 1 results[/green]"

View File

@@ -0,0 +1 @@
"""Flow tests package."""

Binary file not shown.

View File

@@ -0,0 +1,131 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from ingest_pipeline.core.models import (
IngestionJob,
IngestionSource,
StorageBackend,
)
from ingest_pipeline.flows import ingestion
from ingest_pipeline.flows.ingestion import (
FirecrawlIngestor,
_chunk_urls,
_deduplicate_urls,
filter_existing_documents_task,
ingest_documents_task,
)
@pytest.mark.parametrize(
("urls", "size", "expected"),
(
(["a", "b", "c", "d"], 2, [["a", "b"], ["c", "d"]]),
(["solo"], 3, [["solo"]]),
),
)
def test_chunk_urls_batches_urls(urls: list[str], size: int, expected: list[list[str]]) -> None:
assert _chunk_urls(urls, size) == expected
def test_chunk_urls_rejects_zero() -> None:
with pytest.raises(ValueError):
_chunk_urls(["item"], 0)
def test_deduplicate_urls_preserves_order() -> None:
urls = ["https://example.com", "https://example.com", "https://other.com"]
unique = _deduplicate_urls(urls)
assert unique == ["https://example.com", "https://other.com"]
@pytest.mark.asyncio
async def test_filter_existing_documents_task_filters_known_urls(
monkeypatch: pytest.MonkeyPatch,
) -> None:
existing_url = "https://keep.example.com"
new_url = "https://new.example.com"
existing_id = str(FirecrawlIngestor.compute_document_id(existing_url))
class StubStorage:
display_name = "stub-storage"
async def check_exists(
self,
document_id: str,
*,
collection_name: str | None = None,
stale_after_days: int,
) -> bool:
return document_id == existing_id
monkeypatch.setattr(
ingestion,
"get_run_logger",
lambda: SimpleNamespace(info=lambda *args, **kwargs: None),
)
eligible = await filter_existing_documents_task.fn(
[existing_url, new_url],
StubStorage(),
stale_after_days=30,
)
assert eligible == [new_url]
@pytest.mark.asyncio
async def test_ingest_documents_task_invokes_helpers(
monkeypatch: pytest.MonkeyPatch,
) -> None:
job = IngestionJob(
source_url="https://docs.example.com",
source_type=IngestionSource.WEB,
storage_backend=StorageBackend.WEAVIATE,
)
ingestor_sentinel = object()
storage_sentinel = object()
progress_events: list[tuple[int, str]] = []
def record_progress(percent: int, message: str) -> None:
progress_events.append((percent, message))
async def fake_create_storage(
job_arg: IngestionJob,
collection_name: str | None,
storage_block_name: str | None = None,
) -> object:
return storage_sentinel
async def fake_create_ingestor(job_arg: IngestionJob, config_block_name: str | None = None) -> object:
return ingestor_sentinel
monkeypatch.setattr(ingestion, "_create_ingestor", fake_create_ingestor)
monkeypatch.setattr(ingestion, "_create_storage", fake_create_storage)
fake_process = AsyncMock(return_value=(5, 1))
monkeypatch.setattr(ingestion, "_process_documents", fake_process)
result = await ingest_documents_task.fn(
job,
collection_name="unit",
progress_callback=record_progress,
)
assert result == (5, 1)
assert progress_events == [
(35, "Creating ingestor and storage clients..."),
(40, "Starting document processing..."),
]
fake_process.assert_awaited_once_with(
ingestor_sentinel,
storage_sentinel,
job,
50,
"unit",
record_progress,
)

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from datetime import timedelta
from types import SimpleNamespace
import pytest
from ingest_pipeline.flows import scheduler
def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, object] = {}
class DummyFlow:
def to_deployment(self, **kwargs: object) -> SimpleNamespace:
captured.update(kwargs)
return SimpleNamespace(**kwargs)
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
deployment = scheduler.create_scheduled_deployment(
name="cron-ingestion",
source_url="https://example.com",
source_type="web",
schedule_type="cron",
cron_expression="0 * * * *",
)
assert captured["schedule"].cron == "0 * * * *"
assert captured["parameters"]["source_type"] == "web"
assert deployment.tags == ["web", "weaviate"]
def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, object] = {}
class DummyFlow:
def to_deployment(self, **kwargs: object) -> SimpleNamespace:
captured.update(kwargs)
return SimpleNamespace(**kwargs)
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
deployment = scheduler.create_scheduled_deployment(
name="interval-ingestion",
source_url="https://repo.example.com",
source_type="repository",
storage_backend="open_webui",
schedule_type="interval",
interval_minutes=15,
tags=["custom"],
)
assert captured["schedule"].interval == timedelta(minutes=15)
assert captured["tags"] == ["custom"]
assert deployment.parameters["storage_backend"] == "open_webui"
def test_serve_deployments_invokes_prefect(monkeypatch: pytest.MonkeyPatch) -> None:
called: dict[str, object] = {}
def fake_serve(*deployments: object, limit: int) -> None:
called["deployments"] = deployments
called["limit"] = limit
monkeypatch.setattr(scheduler, "serve", fake_serve)
deployment = SimpleNamespace(name="only")
scheduler.serve_deployments([deployment])
assert called["deployments"] == (deployment,)
assert called["limit"] == 10

View File

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
import pytest
from ingest_pipeline.core.models import IngestionJob, IngestionSource, StorageBackend
from ingest_pipeline.ingestors.firecrawl import FirecrawlIngestor
@pytest.mark.asyncio
async def test_firecrawl_ingest_flows(
firecrawl_service,
firecrawl_client_stub,
) -> None:
base_url = "https://example.com"
page_urls = [f"{base_url}/docs", f"{base_url}/about"]
firecrawl_service.register_map_result(base_url, page_urls)
for index, url in enumerate(page_urls, start=1):
firecrawl_service.register_page(
url,
markdown=f"# Page {index}\nContent body",
metadata={
"title": f"Page {index}",
"description": f"Description {index}",
"statusCode": 200,
"language": "en",
"sourceURL": url,
},
)
ingestor = FirecrawlIngestor()
job = IngestionJob(
source_type=IngestionSource.WEB,
source_url=base_url,
storage_backend=StorageBackend.WEAVIATE,
)
documents = [document async for document in ingestor.ingest(job)]
assert {doc.metadata["title"] for doc in documents} == {"Page 1", "Page 2"}
assert all(doc.content.startswith("# Page") for doc in documents)
@pytest.mark.asyncio
async def test_firecrawl_validate_source(
firecrawl_service,
firecrawl_client_stub,
) -> None:
target_url = "https://validate.example.com"
firecrawl_service.register_page(
target_url,
markdown="# Validation Page\n",
metadata={
"title": "Validation Page",
"statusCode": 200,
"language": "en",
"sourceURL": target_url,
},
)
ingestor = FirecrawlIngestor()
assert await ingestor.validate_source(target_url) is True
assert await ingestor.estimate_size(target_url) == 1

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
import pytest
from ingest_pipeline.core.models import IngestionJob, IngestionSource, StorageBackend
from ingest_pipeline.ingestors.repomix import RepomixIngestor
@pytest.mark.parametrize(
("content", "expected_keys"),
(
("## File: src/app.py\nprint(\"hi\")", ["src/app.py"]),
("plain content without markers", ["repository"]),
),
)
def test_split_by_files_detects_file_markers(content: str, expected_keys: list[str]) -> None:
ingestor = RepomixIngestor()
results = ingestor._split_by_files(content)
assert list(results) == expected_keys
@pytest.mark.parametrize(
("content", "chunk_size", "expected"),
(
("line-one\nline-two\nline-three", 9, ["line-one", "line-two", "line-three"]),
("single-line", 50, ["single-line"]),
),
)
def test_chunk_content_respects_max_size(
content: str,
chunk_size: int,
expected: list[str],
) -> None:
ingestor = RepomixIngestor()
chunks = ingestor._chunk_content(content, chunk_size=chunk_size)
assert chunks == expected
@pytest.mark.parametrize(
("file_path", "content", "expected"),
(
("src/app.py", "def feature():\n return True", "python"),
("scripts/run", "#!/usr/bin/env python\nprint(\"ok\")", "python"),
("documentation.md", "# Title", "markdown"),
("unknown.ext", "text", None),
),
)
def test_detect_programming_language_infers_extension(
file_path: str,
content: str,
expected: str | None,
) -> None:
ingestor = RepomixIngestor()
detected = ingestor._detect_programming_language(file_path, content)
assert detected == expected
def test_create_document_enriches_metadata() -> None:
ingestor = RepomixIngestor()
job = IngestionJob(
source_url="https://example.com/repo.git",
source_type=IngestionSource.REPOSITORY,
storage_backend=StorageBackend.WEAVIATE,
)
document = ingestor._create_document(
"src/module.py",
"def alpha():\n return 42\n",
job,
chunk_index=1,
git_metadata={"branch_name": "main", "commit_hash": "deadbeef"},
repo_info={"repository_name": "demo"},
)
assert document.metadata["repository_name"] == "demo"
assert document.metadata["branch_name"] == "main"
assert document.metadata["commit_hash"] == "deadbeef"
assert document.metadata["title"].endswith("(chunk 1)")
assert document.collection == job.storage_backend.value

View File

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import pytest
from ingest_pipeline.core.models import Document, StorageConfig
from ingest_pipeline.storage.base import BaseStorage
class StubStorage(BaseStorage):
"""Concrete BaseStorage for exercising shared logic."""
def __init__(
self,
config: StorageConfig,
*,
result: Document | None = None,
error: BaseException | None = None,
) -> None:
super().__init__(config)
self._result = result
self._error = error
async def initialize(self) -> None: # pragma: no cover - not used in tests
return None
async def store(
self,
document: Document,
*,
collection_name: str | None = None,
) -> str: # pragma: no cover
return ""
async def store_batch(
self,
documents: list[Document],
*,
collection_name: str | None = None,
) -> list[str]: # pragma: no cover
return []
async def delete(
self,
document_id: str,
*,
collection_name: str | None = None,
) -> bool: # pragma: no cover
return False
async def retrieve(
self,
document_id: str,
*,
collection_name: str | None = None,
) -> Document | None:
if self._error is not None:
raise self._error
return self._result
@pytest.mark.asyncio
@pytest.mark.parametrize(
"delta_days",
[pytest.param(0, id="fresh"), pytest.param(40, id="stale")],
)
async def test_check_exists_uses_staleness(document_factory, storage_config, delta_days) -> None:
"""Return True only when document timestamp is within freshness window."""
timestamp = datetime.now(UTC) - timedelta(days=delta_days)
document = document_factory(content=f"doc-{delta_days}", metadata_updates={"timestamp": timestamp})
storage = StubStorage(storage_config, result=document)
outcome = await storage.check_exists("identifier", stale_after_days=30)
assert outcome is (delta_days <= 30)
@pytest.mark.asyncio
async def test_check_exists_returns_false_for_missing(storage_config) -> None:
"""Return False when retrieve yields no document."""
storage = StubStorage(storage_config, result=None)
assert await storage.check_exists("missing") is False
@pytest.mark.asyncio
@pytest.mark.parametrize(
"error_factory",
[
pytest.param(lambda: NotImplementedError("unsupported"), id="not-implemented"),
pytest.param(lambda: RuntimeError("unexpected"), id="generic-error"),
],
)
async def test_check_exists_swallows_errors(storage_config, error_factory) -> None:
"""Return False when retrieval raises errors."""
storage = StubStorage(storage_config, error=error_factory())
assert await storage.check_exists("identifier") is False

View File

@@ -0,0 +1,136 @@
from __future__ import annotations
from typing import Any
import pytest
from ingest_pipeline.core.models import StorageConfig
from ingest_pipeline.storage.openwebui import OpenWebUIStorage
def _make_storage(config: StorageConfig) -> OpenWebUIStorage:
return OpenWebUIStorage(config)
@pytest.mark.asyncio
async def test_get_knowledge_id_returns_existing(
storage_config: StorageConfig,
openwebui_service,
httpx_stub,
) -> None:
"""Return cached identifier when knowledge base already exists."""
openwebui_service.ensure_knowledge(
name=storage_config.collection_name,
knowledge_id="kb-123",
)
storage = _make_storage(storage_config)
knowledge_id = await storage._get_knowledge_id(None, create=False)
assert knowledge_id == "kb-123"
urls = [request["url"] for request in httpx_stub.requests]
assert "http://storage.local/api/v1/knowledge/list" in urls
await storage.client.aclose()
@pytest.mark.asyncio
async def test_get_knowledge_id_creates_when_missing(
storage_config: StorageConfig,
openwebui_service,
httpx_stub,
) -> None:
"""Create knowledge base when missing and caching the result."""
derived_config = storage_config.model_copy(update={"collection_name": "custom"})
storage = _make_storage(derived_config)
knowledge_id = await storage._get_knowledge_id("custom", create=True)
assert knowledge_id is not None
urls = [request["url"] for request in httpx_stub.requests]
assert "http://storage.local/api/v1/knowledge/list" in urls
assert any(
url.startswith("http://storage.local/api/v1/knowledge/") and url.endswith("/create")
for url in urls
)
await storage.client.aclose()
@pytest.mark.asyncio
async def test_store_uploads_and_attaches_document(
document_factory,
storage_config: StorageConfig,
openwebui_service,
httpx_stub,
) -> None:
"""Upload file then attach it to the knowledge base."""
storage = _make_storage(storage_config)
document = document_factory(content="hello world", metadata_updates={"title": "example"})
file_id = await storage.store(document)
assert file_id is not None
urls = [request["url"] for request in httpx_stub.requests]
assert any(url.endswith("/api/v1/files/") for url in urls)
assert any("/file/add" in url for url in urls)
knowledge_entry = openwebui_service.find_knowledge_by_name(storage_config.collection_name)
assert knowledge_entry is not None
_, knowledge = knowledge_entry
assert len(knowledge.get("files", [])) == 1
assert knowledge["files"][0]["id"] == file_id
await storage.client.aclose()
@pytest.mark.asyncio
async def test_store_batch_handles_multiple_documents(
document_factory,
storage_config: StorageConfig,
openwebui_service,
httpx_stub,
) -> None:
"""Store documents in batch and return collected file identifiers."""
storage = _make_storage(storage_config)
first = document_factory(content="alpha", metadata_updates={"title": "alpha"})
second = document_factory(content="beta", metadata_updates={"title": "beta"})
file_ids = await storage.store_batch([first, second])
assert len(file_ids) == 2
files_payloads: list[Any] = [request["files"] for request in httpx_stub.requests if request["method"] == "POST"]
assert any(payload is not None for payload in files_payloads)
knowledge_entry = openwebui_service.find_knowledge_by_name(storage_config.collection_name)
assert knowledge_entry is not None
_, knowledge = knowledge_entry
assert {meta["id"] for meta in knowledge.get("files", [])} == set(file_ids)
await storage.client.aclose()
@pytest.mark.asyncio
async def test_delete_removes_file(
storage_config: StorageConfig,
openwebui_service,
httpx_stub,
) -> None:
"""Remove file from knowledge base and delete the uploaded resource."""
openwebui_service.ensure_knowledge(
name=storage_config.collection_name,
knowledge_id="kb-55",
)
openwebui_service.create_file(filename="to-delete.txt", file_id="file-xyz")
openwebui_service.attach_existing_file("kb-55", "file-xyz")
storage = _make_storage(storage_config)
result = await storage.delete("file-xyz")
assert result is True
urls = [request["url"] for request in httpx_stub.requests]
assert "http://storage.local/api/v1/knowledge/kb-55/file/remove" in urls
assert "http://storage.local/api/v1/files/file-xyz" in urls
knowledge = openwebui_service.get_knowledge("kb-55")
assert knowledge is not None
assert knowledge.get("files", []) == []
await storage.client.aclose()

View File

@@ -0,0 +1,343 @@
from __future__ import annotations
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, Self
import pytest
from ingest_pipeline.core.models import StorageConfig
from ingest_pipeline.storage.r2r.storage import (
R2RStorage,
_as_datetime,
_as_int,
_as_mapping,
_as_sequence,
_extract_id,
)
@pytest.fixture
def r2r_client_stub(
monkeypatch: pytest.MonkeyPatch,
r2r_service,
) -> object:
class DummyR2RException(Exception):
def __init__(self, message: str, status_code: int | None = None) -> None:
super().__init__(message)
self.status_code = status_code
class MockResponse:
def __init__(self, json_data: dict[str, Any], status_code: int = 200) -> None:
self._json_data = json_data
self.status_code = status_code
def json(self) -> dict[str, Any]:
return self._json_data
def raise_for_status(self) -> None:
if self.status_code >= 400:
from httpx import HTTPStatusError
raise HTTPStatusError("HTTP error", request=None, response=self)
class MockAsyncClient:
def __init__(self, service: Any) -> None:
self._service = service
async def get(self, url: str) -> MockResponse:
if "/v3/collections" in url:
# Return existing collections
collections = []
for collection_id, collection_data in self._service._collections.items():
collections.append({
"id": collection_id,
"name": collection_data["name"],
"description": collection_data.get("description", ""),
})
return MockResponse({"results": collections})
return MockResponse({})
async def post(self, url: str, *, json: dict[str, Any] | None = None, files: dict[str, Any] | None = None) -> MockResponse:
if "/v3/collections" in url and json:
# Create new collection
new_collection_id = f"col-{len(self._service._collections) + 1}"
self._service.create_collection(
name=json["name"],
collection_id=new_collection_id,
description=json.get("description", ""),
)
return MockResponse({
"results": {
"id": new_collection_id,
"name": json["name"],
"description": json.get("description", ""),
}
})
elif "/v3/documents" in url and files:
# Create document
import json as json_lib
document_id = files.get("id", (None, f"doc-{len(self._service._documents) + 1}"))[1]
content = files.get("raw_text", (None, ""))[1]
metadata_str = files.get("metadata", (None, "{}"))[1]
metadata = json_lib.loads(metadata_str) if metadata_str else {}
# Store document in mock service
document_data = {
"id": document_id,
"content": content,
"metadata": metadata,
}
self._service._documents[document_id] = document_data
# Update collection document count if specified
collection_ids = files.get("collection_ids")
if collection_ids:
# collection_ids is passed as a tuple (None, "[collection_id]") in the files format
collection_ids_str = collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids
try:
collection_list = json_lib.loads(collection_ids_str) if isinstance(collection_ids_str, str) else collection_ids_str
if isinstance(collection_list, list) and len(collection_list) > 0:
# Extract the collection ID - it could be a string or dict
first_collection = collection_list[0]
if isinstance(first_collection, dict) and "id" in first_collection:
collection_id = first_collection["id"]
elif isinstance(first_collection, str):
collection_id = first_collection
else:
collection_id = None
if collection_id and collection_id in self._service._collections:
# Update the collection's document count to match the actual number of documents
total_docs = len(self._service._documents)
self._service._collections[collection_id]["document_count"] = total_docs
except (json_lib.JSONDecodeError, TypeError, KeyError):
pass # Ignore parsing errors
return MockResponse({
"results": {
"document_id": document_id,
"message": "Document created successfully",
}
})
return MockResponse({})
async def aclose(self) -> None:
return None
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
return None
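# Minimal async facades mirroring the documents and retrieval attributes that
# R2RStorage accesses on the real R2R SDK client.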
class DocumentsAPI:
def __init__(self, service: Any) -> None:
self._service = service
async def retrieve(self, document_id: str) -> dict[str, Any]:
document = self._service.get_document(document_id)
if document is None:
raise DummyR2RException("Not found", status_code=404)
return {"results": document}
async def delete(self, document_id: str) -> dict[str, Any]:
if not self._service.delete_document(document_id):
raise DummyR2RException("Not found", status_code=404)
return {"results": {"success": True}}
async def append_metadata(self, id: str, metadata: list[dict[str, Any]]) -> dict[str, Any]:
document = self._service.append_document_metadata(id, metadata)
if document is None:
raise DummyR2RException("Not found", status_code=404)
return {"results": document}
class RetrievalAPI:
def __init__(self, service: Any) -> None:
self._service = service
async def search(self, query: str, search_settings: Mapping[str, Any]) -> dict[str, Any]:
results = [
{"document_id": doc_id, "score": 1.0}
for doc_id in self._service._documents
]
return {"results": results}
class DummyClient:
def __init__(self, service: Any) -> None:
self.documents = DocumentsAPI(service)
self.retrieval = RetrievalAPI(service)
async def aclose(self) -> None:
return None
async def close(self) -> None:
return None
# Mock the AsyncClient that R2RStorage uses internally
mock_async_client = MockAsyncClient(r2r_service)
monkeypatch.setattr(
"ingest_pipeline.storage.r2r.storage.AsyncClient",
lambda: mock_async_client,
)
client = DummyClient(r2r_service)
monkeypatch.setattr(
"ingest_pipeline.storage.r2r.storage.R2RAsyncClient",
lambda endpoint: client,
)
monkeypatch.setattr(
"ingest_pipeline.storage.r2r.storage.R2RException",
DummyR2RException,
)
return client
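# The parametrized tests below pin down the behaviour of the module-level coercion
# helpers imported above (_as_mapping, _as_sequence, _extract_id, _as_datetime,
# _as_int) without requiring a live R2R instance.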
@pytest.mark.parametrize(
("value", "expected"),
[
pytest.param({"a": 1}, {"a": 1}, id="mapping"),
pytest.param(SimpleNamespace(a=2), {"a": 2}, id="namespace"),
pytest.param(5, {}, id="other"),
],
)
def test_as_mapping_normalizes(value, expected) -> None:
"""Convert inputs into dictionaries where possible."""
assert _as_mapping(value) == expected
@pytest.mark.parametrize(
("value", "expected"),
[
pytest.param([1, 2], (1, 2), id="list"),
pytest.param((3, 4), (3, 4), id="tuple"),
pytest.param("ab", ("a", "b"), id="string"),
pytest.param(7, (), id="non-iterable"),
],
)
def test_as_sequence_coerces_iterables(value, expected) -> None:
"""Represent values as tuples for downstream iteration."""
assert _as_sequence(value) == expected
@pytest.mark.parametrize(
("source", "fallback", "expected"),
[
pytest.param({"id": "abc"}, "x", "abc", id="mapping"),
pytest.param(SimpleNamespace(id=123), "x", "123", id="attribute"),
pytest.param({}, "fallback", "fallback", id="fallback"),
],
)
def test_extract_id_falls_back(source, fallback, expected) -> None:
"""Prefer embedded identifier values and fall back otherwise."""
assert _extract_id(source, fallback) == expected
@pytest.mark.parametrize(
("value", "expected_year"),
[
pytest.param(datetime(2024, 1, 1, tzinfo=UTC), 2024, id="datetime"),
pytest.param("2024-02-01T00:00:00+00:00", 2024, id="iso"),
pytest.param("invalid", datetime.now(UTC).year, id="fallback"),
],
)
def test_as_datetime_recognizes_formats(value, expected_year) -> None:
"""Produce timezone-aware datetime objects."""
assert _as_datetime(value).year == expected_year
@pytest.mark.parametrize(
("value", "expected"),
[
pytest.param(True, 1, id="bool"),
pytest.param(5, 5, id="int"),
pytest.param(3.9, 3, id="float"),
pytest.param("7", 7, id="string"),
pytest.param("8.2", 8, id="float-string"),
pytest.param("bad", 2, id="default"),
],
)
def test_as_int_handles_numeric_coercions(value, expected) -> None:
"""Convert assorted numeric representations."""
assert _as_int(value, default=2) == expected
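# Taken together, the cases above imply roughly the following shape for _as_int.
# This is only an illustrative sketch inferred from the tests; the real helper
# lives in ingest_pipeline.storage.r2r.storage and may differ in detail.
#
#     def _as_int(value: object, default: int = 0) -> int:
#         if isinstance(value, bool):
#             return int(value)
#         if isinstance(value, (int, float)):
#             return int(value)
#         try:
#             return int(float(str(value)))
#         except (TypeError, ValueError):
#             return default
#
# The async tests below exercise R2RStorage end to end against the in-memory
# r2r_service fixture, with the SDK and httpx clients replaced by r2r_client_stub.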
@pytest.mark.asyncio
async def test_ensure_collection_finds_existing(
r2r_storage_config: StorageConfig,
r2r_service,
httpx_stub,
r2r_client_stub,
) -> None:
"""Return collection identifier when already present."""
r2r_service.create_collection(
name=r2r_storage_config.collection_name,
collection_id="col-1",
)
storage = R2RStorage(r2r_storage_config)
collection_id = await storage._ensure_collection(r2r_storage_config.collection_name)
assert collection_id == "col-1"
assert storage.default_collection_id == "col-1"
await storage.client.aclose()
@pytest.mark.asyncio
async def test_ensure_collection_creates_when_missing(
r2r_storage_config: StorageConfig,
r2r_service,
httpx_stub,
r2r_client_stub,
) -> None:
"""Create collection via POST when absent."""
storage = R2RStorage(r2r_storage_config)
collection_id = await storage._ensure_collection("alternate")
assert collection_id is not None
located = r2r_service.find_collection_by_name("alternate")
assert located is not None
identifier, _ = located
assert identifier == collection_id
await storage.client.aclose()
@pytest.mark.asyncio
async def test_store_batch_creates_documents(
document_factory,
r2r_storage_config: StorageConfig,
r2r_service,
httpx_stub,
r2r_client_stub,
) -> None:
"""Store documents and persist them via the R2R mock service."""
storage = R2RStorage(r2r_storage_config)
documents = [
document_factory(content="first document", metadata_updates={"title": "First"}),
document_factory(content="second document", metadata_updates={"title": "Second"}),
]
stored_ids = await storage.store_batch(documents)
assert len(stored_ids) == 2
for doc_id, original in zip(stored_ids, documents, strict=False):
stored = r2r_service.get_document(doc_id)
assert stored is not None
assert stored["metadata"]["source_url"] == original.metadata["source_url"]
collection = r2r_service.find_collection_by_name(r2r_storage_config.collection_name)
assert collection is not None
_, collection_payload = collection
assert collection_payload["document_count"] == 2
await storage.client.aclose()

View File

@@ -0,0 +1,140 @@
from __future__ import annotations
from datetime import UTC, datetime
from typing import cast
import pytest
from ingest_pipeline.core.exceptions import StorageError
from ingest_pipeline.core.models import IngestionSource, StorageConfig
from ingest_pipeline.storage.weaviate import WeaviateStorage
from ingest_pipeline.utils.vectorizer import Vectorizer
def _build_storage(config: StorageConfig) -> WeaviateStorage:
storage = WeaviateStorage.__new__(WeaviateStorage)
storage.config = config
storage.client = None
storage.vectorizer = cast(Vectorizer, None)
storage._default_collection = config.collection_name
return storage
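# _build_storage constructs a WeaviateStorage via __new__, bypassing __init__, so
# the pure helper methods can be exercised without connecting to Weaviate or
# loading a vectorizer.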
@pytest.mark.parametrize(
"raw, expected",
[
pytest.param([1, 2, 3], [1.0, 2.0, 3.0], id="list"),
pytest.param({"default": [4, 5]}, [4.0, 5.0], id="mapping"),
pytest.param([[6, 7]], [6.0, 7.0], id="nested"),
pytest.param("invalid", None, id="invalid"),
],
)
def test_extract_vector_normalizes(raw, expected) -> None:
"""Normalize various vector payload structures."""
assert WeaviateStorage._extract_vector(raw) == expected
@pytest.mark.parametrize(
"value, expected",
[
pytest.param(IngestionSource.WEB, IngestionSource.WEB, id="enum"),
pytest.param("documentation", IngestionSource.DOCUMENTATION, id="string"),
pytest.param("unknown", IngestionSource.WEB, id="fallback"),
pytest.param(42, IngestionSource.WEB, id="non-string"),
],
)
def test_parse_source_normalizes(value, expected) -> None:
"""Ensure ingestion source strings coerce into enums."""
assert WeaviateStorage._parse_source(value) is expected
def test_coerce_properties_returns_mapping() -> None:
"""Pass through mapping when provided."""
payload = {"key": "value"}
assert WeaviateStorage._coerce_properties(payload, context="test") is payload
def test_coerce_properties_allows_missing() -> None:
"""Return None when allow_missing is True and payload absent."""
assert WeaviateStorage._coerce_properties(None, context="test", allow_missing=True) is None
@pytest.mark.parametrize(
"payload",
[pytest.param(None, id="none"), pytest.param([1, 2, 3], id="sequence")],
)
def test_coerce_properties_raises_on_invalid(payload) -> None:
"""Raise storage error when payload is not a mapping."""
with pytest.raises(StorageError):
WeaviateStorage._coerce_properties(payload, context="invalid")
@pytest.mark.parametrize(
"input_name, expected",
[
pytest.param(None, "Documents", id="default"),
pytest.param(" custom ", "Custom", id="trimmed"),
],
)
def test_normalize_collection_name(storage_config, input_name, expected) -> None:
"""Normalize with capitalization and fallback to config value."""
storage = _build_storage(storage_config)
assert storage._normalize_collection_name(input_name) == expected
def test_normalize_collection_name_rejects_empty(storage_config) -> None:
"""Raise when provided name is blank."""
storage = _build_storage(storage_config)
with pytest.raises(StorageError):
storage._normalize_collection_name(" ")
@pytest.mark.parametrize(
"value, expected",
[
pytest.param(5, 5, id="int"),
pytest.param(7.9, 7, id="float"),
pytest.param("12", 12, id="string"),
pytest.param(None, 0, id="none"),
],
)
def test_safe_convert_count(storage_config, value, expected) -> None:
"""Convert assorted count representations into integers."""
storage = _build_storage(storage_config)
assert storage._safe_convert_count(value) == expected
def test_build_document_metadata(storage_config) -> None:
"""Build metadata mapping with type coercions."""
storage = _build_storage(storage_config)
props = {
"source_url": "https://example.com",
"title": "Doc",
"description": "Details",
"timestamp": "2024-01-01T00:00:00+00:00",
"content_type": "text/plain",
"word_count": "5",
"char_count": 128.0,
}
metadata = storage._build_document_metadata(props)
assert metadata["source_url"] == "https://example.com"
assert metadata["title"] == "Doc"
assert metadata["description"] == "Details"
assert metadata["timestamp"] == datetime(2024, 1, 1, tzinfo=UTC)
assert metadata["word_count"] == 5
assert metadata["char_count"] == 128

Some files were not shown because too many files have changed in this diff.