updates
@@ -7,7 +7,8 @@
      "WebSearch",
      "Bash(cat:*)",
      "mcp__firecrawl__firecrawl_search",
      "mcp__firecrawl__firecrawl_scrape"
      "mcp__firecrawl__firecrawl_scrape",
      "Bash(python:*)"
    ],
    "deny": [],
    "ask": []
100  .gitignore  vendored  Normal file
@@ -0,0 +1,100 @@
# Environment files
.env
.env.local
.env.*.local

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Virtual environments
.venv/
.env/
venv/
ENV/
env/

# Package managers
.uv_cache/
.uv-cache/
.pip-cache/
pip-log.txt
pip-delete-this-directory.txt

# Testing
.tox/
.coverage
.coverage.*
.cache
.pytest_cache/
htmlcov/
.nox/
coverage.xml
*.cover
.hypothesis/

# Type checking
.mypy_cache/
.dmypy.json
dmypy.json
.pyre/
.pytype/

# Linting and formatting
.ruff_cache/
.flake8_cache/
.black_cache/

# Jupyter Notebook
.ipynb_checkpoints

# IDE and editors
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Logs
logs/
*.log
log/

# Temporary files
*.tmp
*.temp
.tmp/
.temp/

# Project specific
repomix-output.xml
*.json.bak
chat.json
36  .vscode/settings.json  vendored
@@ -1,7 +1,37 @@
{
  "chatgpt.openOnStartup": true,
  "python.languageServer": "None",
  "python.analysis.typeCheckingMode": "off",
  "python.defaultInterpreterPath": "./.venv/bin/python",
  "python.terminal.activateEnvironment": true
  "python.terminal.activateEnvironment": true,
  "python.linting.enabled": true,
  "python.linting.mypyEnabled": true,
  "python.linting.mypyPath": "./.venv/bin/mypy",
  "python.linting.pylintEnabled": false,
  "python.linting.flake8Enabled": false,
  "python.analysis.typeCheckingMode": "basic",
  "python.analysis.autoImportCompletions": true,
  "python.analysis.stubPath": "./.venv/lib/python3.12/site-packages",
  "basedpyright.analysis.typeCheckingMode": "standard",
  "basedpyright.analysis.autoSearchPaths": true,
  "basedpyright.analysis.autoImportCompletions": true,
  "basedpyright.analysis.diagnosticMode": "workspace",
  "basedpyright.analysis.stubPath": "./.venv/lib/python3.12/site-packages",
  "basedpyright.analysis.extraPaths": [
    "./ingest_pipeline",
    "./.venv/lib/python3.12/site-packages"
  ],
  "pyright.analysis.typeCheckingMode": "standard",
  "pyright.analysis.autoSearchPaths": true,
  "pyright.analysis.autoImportCompletions": true,
  "pyright.analysis.diagnosticMode": "workspace",
  "pyright.analysis.stubPath": "./.venv/lib/python3.12/site-packages",
  "pyright.analysis.extraPaths": [
    "./ingest_pipeline",
    "./.venv/lib/python3.12/site-packages"
  ],
  "files.exclude": {
    "**/__pycache__": true,
    "**/.pytest_cache": true,
    "**/node_modules": true,
    ".mypy_cache": true
  }
}
@@ -6,11 +6,26 @@
    "**/__pycache__",
    "**/.pytest_cache",
    "**/node_modules",
    ".venv"
    ".venv",
    "build",
    "dist"
  ],
  "pythonPath": "./.venv/bin/python",
  "pythonVersion": "3.12",
  "venvPath": ".",
  "venv": ".venv",
  "typeCheckingMode": "standard",
  "useLibraryCodeForTypes": true,
  "stubPath": "./.venv/lib/python3.12/site-packages",
  "executionEnvironments": [
    {
      "root": ".",
      "pythonVersion": "3.12",
      "extraPaths": [
        "./ingest_pipeline",
        "./.venv/lib/python3.12/site-packages"
      ]
    }
  ],
  "reportCallInDefaultInitializer": "none",
  "reportUnknownVariableType": "warning",
  "reportUnknownMemberType": "warning",
@@ -21,9 +36,12 @@
  "reportUnannotatedClassAttribute": "warning",
  "reportMissingTypeStubs": "none",
  "reportMissingModuleSource": "none",
  "reportImportCycles": "none",
  "reportAttributeAccessIssue": "warning",
  "reportAny": "warning",
  "reportUnusedCallResult": "none",
  "reportUnnecessaryIsInstance": "none",
  "reportImplicitOverride": "none",
  "reportDeprecated": "warning"
  "reportDeprecated": "warning",
  "analyzeUnannotatedFunctions": true
}
62  ingest_pipeline/automations/__init__.py  Normal file
@@ -0,0 +1,62 @@
"""Prefect Automations for ingestion pipeline monitoring and management."""

# Automation configurations as YAML-ready dictionaries
AUTOMATION_TEMPLATES = {
    "cancel_long_running": """
name: Cancel Long Running Ingestion Flows
description: Cancels ingestion flows running longer than 30 minutes
trigger:
  type: event
  posture: Proactive
  expect: [prefect.flow-run.Running]
  match_related:
    prefect.resource.role: flow
    prefect.resource.name: ingestion_pipeline
  threshold: 1
  within: 1800
actions:
  - type: cancel-flow-run
    source: inferred
enabled: true
""",

    "retry_failed": """
name: Retry Failed Ingestion Flows
description: Retries failed ingestion flows with original parameters
trigger:
  type: event
  posture: Reactive
  expect: [prefect.flow-run.Failed]
  match_related:
    prefect.resource.role: flow
    prefect.resource.name: ingestion_pipeline
  threshold: 1
  within: 0
actions:
  - type: run-deployment
    source: inferred
    parameters:
      validate_first: false
enabled: true
""",

    "resource_monitoring": """
name: Manage Work Pool Based on Resources
description: Pauses work pool when system resources are constrained
trigger:
  type: event
  posture: Reactive
  expect: [system.resource.high_usage]
  threshold: 1
  within: 120
actions:
  - type: pause-work-pool
    work_pool_name: default
enabled: true
""",
}


def get_automation_yaml_templates() -> dict[str, str]:
    """Get automation templates as YAML strings."""
    return AUTOMATION_TEMPLATES.copy()
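Usage sketch (editor's illustration, not part of this commit): the templates above are plain YAML strings, so they can be parsed and written out for review before being applied with Prefect's automation tooling. PyYAML and the output directory name are assumptions here.

# Illustrative only: render each automation template to a YAML file for review.
from pathlib import Path

import yaml

from ingest_pipeline.automations import get_automation_yaml_templates

out_dir = Path("automation-specs")  # hypothetical output location
out_dir.mkdir(exist_ok=True)
for key, template in get_automation_yaml_templates().items():
    spec = yaml.safe_load(template)  # confirms the template parses as YAML
    (out_dir / f"{key}.yaml").write_text(yaml.safe_dump(spec, sort_keys=False))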
@@ -4,13 +4,19 @@ import asyncio
from typing import Annotated

import typer
from pydantic import SecretStr
from rich.console import Console
from rich.panel import Panel
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
from rich.table import Table

from ..config import configure_prefect, get_settings
from ..core.models import IngestionResult, IngestionSource, StorageBackend
from ..core.models import (
    IngestionResult,
    IngestionSource,
    StorageBackend,
    StorageConfig,
)
from ..flows.ingestion import create_ingestion_flow
from ..flows.scheduler import create_scheduled_deployment, serve_deployments

@@ -439,6 +445,22 @@ def search(
    asyncio.run(run_search(query, collection, backend.value, limit))


@app.command(name="blocks")
def blocks_command() -> None:
    """🧩 List and manage Prefect Blocks."""
    console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
    console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks")
    console.print("Use 'prefect block ls' to list available blocks")


@app.command(name="variables")
def variables_command() -> None:
    """📊 Manage Prefect Variables."""
    console.print("[bold cyan]📊 Prefect Variables Management[/bold cyan]")
    console.print("Use 'prefect variable set VARIABLE_NAME value' to set variables")
    console.print("Use 'prefect variable ls' to list variables")


async def run_ingestion(
    url: str,
    source_type: IngestionSource,
@@ -470,8 +492,8 @@ async def run_list_collections() -> None:
    """
    List collections across storage backends.
    """
    from ..config import configure_prefect, get_settings
    from ..core.models import StorageBackend, StorageConfig
    from ..config import get_settings
    from ..core.models import StorageBackend
    from ..storage.openwebui import OpenWebUIStorage
    from ..storage.weaviate import WeaviateStorage

@@ -485,7 +507,7 @@ async def run_list_collections() -> None:
        weaviate_config = StorageConfig(
            backend=StorageBackend.WEAVIATE,
            endpoint=settings.weaviate_endpoint,
            api_key=settings.weaviate_api_key,
            api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
            collection_name="default",
        )
        weaviate = WeaviateStorage(weaviate_config)
@@ -494,7 +516,8 @@ async def run_list_collections() -> None:
        overview = await weaviate.describe_collections()
        for item in overview:
            name = str(item.get("name", "Unknown"))
            count = int(item.get("count", 0))
            count_val = item.get("count", 0)
            count = int(count_val) if isinstance(count_val, (int, str)) else 0
            weaviate_collections.append((name, count))
    except Exception as e:
        console.print(f"❌ [red]Weaviate connection failed: {e}[/red]")
@@ -505,7 +528,7 @@ async def run_list_collections() -> None:
        openwebui_config = StorageConfig(
            backend=StorageBackend.OPEN_WEBUI,
            endpoint=settings.openwebui_endpoint,
            api_key=settings.openwebui_api_key,
            api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None,
            collection_name="default",
        )
        openwebui = OpenWebUIStorage(openwebui_config)
@@ -514,7 +537,8 @@ async def run_list_collections() -> None:
        overview = await openwebui.describe_collections()
        for item in overview:
            name = str(item.get("name", "Unknown"))
            count = int(item.get("count", 0))
            count_val = item.get("count", 0)
            count = int(count_val) if isinstance(count_val, (int, str)) else 0
            openwebui_collections.append((name, count))
    except Exception as e:
        console.print(f"❌ [red]OpenWebUI connection failed: {e}[/red]")
@@ -551,8 +575,8 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in
    """
    Search across collections.
    """
    from ..config import configure_prefect, get_settings
    from ..core.models import StorageBackend, StorageConfig
    from ..config import get_settings
    from ..core.models import StorageBackend
    from ..storage.weaviate import WeaviateStorage

    settings = get_settings()
@@ -569,7 +593,7 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in
        weaviate_config = StorageConfig(
            backend=StorageBackend.WEAVIATE,
            endpoint=settings.weaviate_endpoint,
            api_key=settings.weaviate_api_key,
            api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
            collection_name=collection or "default",
        )
        weaviate = WeaviateStorage(weaviate_config)
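The recurring coercion pattern in these CLI hunks, shown in isolation as a sketch (the values are placeholders, not taken from the repository):

# Illustrative only: wrap an optional API key and coerce loosely-typed counts.
from pydantic import SecretStr

raw_key: str | None = "example-key"  # placeholder
api_key = SecretStr(raw_key) if raw_key is not None else None
if api_key is not None:
    _ = api_key.get_secret_value()  # unwrap to a plain string only where needed

count_val: object = "42"  # backends may return int or str
count = int(count_val) if isinstance(count_val, (int, str)) else 0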
@@ -17,7 +17,8 @@ from textual.timer import Timer
from ...storage.base import BaseStorage
from ...storage.openwebui import OpenWebUIStorage
from ...storage.weaviate import WeaviateStorage
from .screens import CollectionOverviewScreen, HelpScreen
from .screens.dashboard import CollectionOverviewScreen
from .screens.help import HelpScreen
from .styles import TUI_CSS
from .utils.storage_manager import StorageManager

@@ -35,6 +36,8 @@ class CollectionManagementApp(App[None]):
    """Enhanced modern Textual application with comprehensive keyboard navigation."""

    CSS: ClassVar[str] = TUI_CSS
    TITLE = "Collection Management"
    SUB_TITLE = "Document Ingestion Pipeline"

    def safe_notify(
        self,
@@ -89,8 +92,8 @@ class CollectionManagementApp(App[None]):
        self.weaviate = weaviate
        self.openwebui = openwebui
        self.r2r = r2r
        self.title: str = ""
        self.sub_title: str = ""
        # Remove direct assignment to read-only title properties
        # These should be set through class attributes or overridden methods
        self.log_queue = log_queue
        self._log_formatter = log_formatter or logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
@@ -146,7 +149,7 @@ class CollectionManagementApp(App[None]):
        if drained and self._log_viewer is not None:
            self._log_viewer.append_logs(drained)

    def attach_log_viewer(self, viewer: "LogViewerScreen") -> None:
    def attach_log_viewer(self, viewer: LogViewerScreen) -> None:
        """Register an active log viewer and hydrate it with existing entries."""
        self._log_viewer = viewer
        viewer.replace_logs(list(self._log_buffer))
@@ -154,7 +157,7 @@ class CollectionManagementApp(App[None]):
        # Drain once more to deliver any entries gathered between instantiation and mount
        self._drain_log_queue()

    def detach_log_viewer(self, viewer: "LogViewerScreen") -> None:
    def detach_log_viewer(self, viewer: LogViewerScreen) -> None:
        """Remove the current log viewer when it is dismissed."""
        if self._log_viewer is viewer:
            self._log_viewer = None
@@ -273,7 +276,7 @@ class CollectionManagementApp(App[None]):
        if len(self.screen_stack) > 1:  # Don't close the main screen
            _ = self.pop_screen()
        else:
            _ = self.notify("Cannot close main screen. Use Q to quit.", severity="warning")
            self.notify("Cannot close main screen. Use Q to quit.", severity="warning")

    def action_dashboard_tab(self) -> None:
        """Switch to dashboard tab in current screen."""
@@ -305,7 +308,7 @@ class CollectionManagementApp(App[None]):
            _ = event.prevent_default()
        elif event.key == "ctrl+alt+r":
            # Force refresh all connections
            _ = self.notify("🔄 Refreshing all connections...", severity="information")
            self.notify("🔄 Refreshing all connections...", severity="information")
            # This could trigger a full reinit if needed
            _ = event.prevent_default()
        # No else clause needed - just handle our events

@@ -163,12 +163,12 @@ class CollapsibleSidebar(Container):
        else:
            _ = self.remove_class("collapsed")

    def expand(self) -> None:
    def expand_sidebar(self) -> None:
        """Expand sidebar."""
        if self.collapsed:
            self.toggle()

    def collapse(self) -> None:
    def collapse_sidebar(self) -> None:
        """Collapse sidebar."""
        if not self.collapsed:
            self.toggle()
@@ -2,14 +2,14 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar

from textual import work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import ModalScreen, Screen
from textual.widget import Widget
from textual.widgets import Button, DataTable, LoadingIndicator, Static
from typing_extensions import override

@@ -19,16 +19,25 @@ if TYPE_CHECKING:
T = TypeVar("T")


class BaseScreen(Screen, ABC):
class BaseScreen(Screen[object]):
    """Base screen with common functionality."""

    def __init__(self, storage_manager: StorageManager, **kwargs: Any) -> None:
    def __init__(
        self,
        storage_manager: StorageManager,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        **kwargs: object
    ) -> None:
        """Initialize base screen."""
        super().__init__(**kwargs)
        super().__init__(name=name, id=id, classes=classes)
        # Ignore any additional kwargs to avoid type issues
        self.storage_manager = storage_manager


class CRUDScreen(BaseScreen, Generic[T], ABC):
class CRUDScreen(BaseScreen, Generic[T]):
    """Base class for Create/Read/Update/Delete operations."""

    BINDINGS = [
@@ -39,9 +48,17 @@ class CRUDScreen(BaseScreen, Generic[T], ABC):
        Binding("escape", "app.pop_screen", "Back"),
    ]

    def __init__(self, storage_manager: StorageManager, **kwargs: Any) -> None:
    def __init__(
        self,
        storage_manager: StorageManager,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        **kwargs: object
    ) -> None:
        """Initialize CRUD screen."""
        super().__init__(storage_manager, **kwargs)
        super().__init__(storage_manager, name=name, id=id, classes=classes)
        self.items: list[T] = []
        self.selected_item: T | None = None
        self.loading = False
@@ -77,35 +94,29 @@ class CRUDScreen(BaseScreen, Generic[T], ABC):
        table.add_columns(*self.get_table_columns())
        return table

    @abstractmethod
    def get_table_columns(self) -> list[str]:
        """Get table column headers."""
        pass
        raise NotImplementedError("Subclasses must implement get_table_columns")

    @abstractmethod
    async def load_items(self) -> list[T]:
        """Load items from storage."""
        pass
        raise NotImplementedError("Subclasses must implement load_items")

    @abstractmethod
    def item_to_row(self, item: T) -> list[str]:
        """Convert item to table row."""
        pass
        raise NotImplementedError("Subclasses must implement item_to_row")

    @abstractmethod
    async def create_item_dialog(self) -> T | None:
        """Show create item dialog."""
        pass
        raise NotImplementedError("Subclasses must implement create_item_dialog")

    @abstractmethod
    async def edit_item_dialog(self, item: T) -> T | None:
        """Show edit item dialog."""
        pass
        raise NotImplementedError("Subclasses must implement edit_item_dialog")

    @abstractmethod
    async def delete_item(self, item: T) -> bool:
        """Delete item."""
        pass
        raise NotImplementedError("Subclasses must implement delete_item")

    def on_mount(self) -> None:
        """Initialize screen."""
@@ -176,17 +187,21 @@ class CRUDScreen(BaseScreen, Generic[T], ABC):
        self.refresh_items()


class ListScreen(BaseScreen, Generic[T], ABC):
class ListScreen(BaseScreen, Generic[T]):
    """Base for paginated list views."""

    def __init__(
        self,
        storage_manager: StorageManager,
        page_size: int = 20,
        **kwargs: Any,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        **kwargs: object,
    ) -> None:
        """Initialize list screen."""
        super().__init__(storage_manager, **kwargs)
        super().__init__(storage_manager, name=name, id=id, classes=classes)
        self.page_size = page_size
        self.current_page = 0
        self.total_items = 0
@@ -204,25 +219,21 @@ class ListScreen(BaseScreen, Generic[T], ABC):
            classes="list-container",
        )

    @abstractmethod
    def get_title(self) -> str:
        """Get screen title."""
        pass
        raise NotImplementedError("Subclasses must implement get_title")

    @abstractmethod
    def create_filters(self) -> Container:
        """Create filter widgets."""
        pass
        raise NotImplementedError("Subclasses must implement create_filters")

    @abstractmethod
    def create_list_view(self) -> Any:
    def create_list_view(self) -> Widget:
        """Create list view widget."""
        pass
        raise NotImplementedError("Subclasses must implement create_list_view")

    @abstractmethod
    async def load_page(self, page: int, page_size: int) -> tuple[list[T], int]:
        """Load page of items."""
        pass
        raise NotImplementedError("Subclasses must implement load_page")

    def create_pagination(self) -> Container:
        """Create pagination controls."""
@@ -249,10 +260,9 @@ class ListScreen(BaseScreen, Generic[T], ABC):
        finally:
            self.set_loading(False)

    @abstractmethod
    async def update_list_view(self) -> None:
        """Update list view with current items."""
        pass
        raise NotImplementedError("Subclasses must implement update_list_view")

    def update_pagination_info(self) -> None:
        """Update pagination information."""
@@ -285,7 +295,7 @@ class ListScreen(BaseScreen, Generic[T], ABC):
        self.load_current_page()


class FormScreen(ModalScreen[T], Generic[T], ABC):
class FormScreen(ModalScreen[T], Generic[T]):
    """Base for input forms with validation."""

    BINDINGS = [
@@ -294,9 +304,18 @@ class FormScreen(ModalScreen[T], Generic[T], ABC):
        Binding("enter", "save", "Save"),
    ]

    def __init__(self, item: T | None = None, **kwargs: Any) -> None:
    def __init__(
        self,
        item: T | None = None,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        **kwargs: object
    ) -> None:
        """Initialize form screen."""
        super().__init__(**kwargs)
        super().__init__(name=name, id=id, classes=classes)
        # Ignore any additional kwargs to avoid type issues
        self.item = item
        self.is_edit_mode = item is not None

@@ -315,35 +334,30 @@ class FormScreen(ModalScreen[T], Generic[T], ABC):
            classes="form-container",
        )

    @abstractmethod
    def get_item_type(self) -> str:
        """Get item type name for title."""
        pass
        raise NotImplementedError("Subclasses must implement get_item_type")

    @abstractmethod
    def create_form_fields(self) -> Container:
        """Create form input fields."""
        pass
        raise NotImplementedError("Subclasses must implement create_form_fields")

    @abstractmethod
    def validate_form(self) -> tuple[bool, list[str]]:
        """Validate form data."""
        pass
        raise NotImplementedError("Subclasses must implement validate_form")

    @abstractmethod
    def get_form_data(self) -> T:
        """Get item from form data."""
        pass
        raise NotImplementedError("Subclasses must implement get_form_data")

    def on_mount(self) -> None:
        """Initialize form."""
        if self.is_edit_mode and self.item:
            self.populate_form(self.item)

    @abstractmethod
    def populate_form(self, item: T) -> None:
        """Populate form with item data."""
        pass
        raise NotImplementedError("Subclasses must implement populate_form")

    def action_save(self) -> None:
        """Save form data."""
@@ -212,8 +212,9 @@ class CollectionOverviewScreen(Screen[None]):
        """Update the metrics cards display."""
        try:
            dashboard_tab = self.query_one("#dashboard")
            metrics_cards = dashboard_tab.query(MetricsCard)
            if len(metrics_cards) >= 4:
            metrics_cards_query = dashboard_tab.query(MetricsCard)
            if len(metrics_cards_query) >= 4:
                metrics_cards = list(metrics_cards_query)
                self._update_card_values(metrics_cards)
                self._update_status_card(metrics_cards[3])
        except NoMatches:
@@ -221,13 +222,13 @@
        except Exception as exc:
            LOGGER.exception("Failed to update dashboard metrics", exc_info=exc)

    def _update_card_values(self, metrics_cards: list) -> None:
    def _update_card_values(self, metrics_cards: list[MetricsCard]) -> None:
        """Update individual metric card values."""
        metrics_cards[0].query_one(".metrics-value", Static).update(f"{self.total_collections:,}")
        metrics_cards[1].query_one(".metrics-value", Static).update(f"{self.total_documents:,}")
        metrics_cards[2].query_one(".metrics-value", Static).update(str(self.active_backends))

    def _update_status_card(self, status_card: object) -> None:
    def _update_status_card(self, status_card: MetricsCard) -> None:
        """Update the system status card."""
        if self.active_backends > 0 and self.total_collections > 0:
            status_text, status_class = "🟢 Healthy", "status-active"
@@ -362,8 +363,10 @@
        collections: list[CollectionInfo] = []

        for item in overview:
            count_val = int(item.get("count", 0))
            size_mb_val = float(item.get("size_mb", 0.0))
            count_raw = item.get("count", 0)
            count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0
            size_mb_raw = item.get("size_mb", 0.0)
            size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0
            collections.append(
                CollectionInfo(
                    name=str(item.get("name", "Unknown")),
@@ -393,8 +396,10 @@
        collections: list[CollectionInfo] = []

        for item in overview:
            count_val = int(item.get("count", 0))
            size_mb_val = float(item.get("size_mb", 0.0))
            count_raw = item.get("count", 0)
            count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0
            size_mb_raw = item.get("size_mb", 0.0)
            size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0
            collection_name = str(item.get("name", "Unknown"))
            collections.append(
                CollectionInfo(
@@ -504,9 +509,7 @@
    def action_manage(self) -> None:
        """Manage documents in selected collection."""
        if selected := self.get_selected_collection():
            # Get the appropriate storage backend for the collection
            storage_backend = self._get_storage_for_collection(selected)
            if storage_backend:
            if storage_backend := self._get_storage_for_collection(selected):
                from .documents import DocumentManagementScreen

                self.app.push_screen(DocumentManagementScreen(selected, storage_backend))
@@ -1,10 +1,10 @@
"""Dialog screens for confirmations and user interactions."""

from pathlib import Path
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING

from textual.app import ComposeResult
from textual.binding import Binding, BindingType
from textual.binding import Binding
from textual.containers import Container, Horizontal
from textual.screen import ModalScreen, Screen
from textual.widgets import Button, Footer, Header, LoadingIndicator, RichLog, Static
@@ -23,7 +23,7 @@ class ConfirmDeleteScreen(Screen[None]):
    collection: CollectionInfo
    parent_screen: "CollectionOverviewScreen"

    BINDINGS: ClassVar[list[BindingType]] = [
    BINDINGS = [
        Binding("escape", "app.pop_screen", "Cancel"),
        Binding("y", "confirm_delete", "Yes"),
        Binding("n", "app.pop_screen", "No"),
@@ -145,7 +145,7 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
    collection: CollectionInfo
    parent_screen: "DocumentManagementScreen"

    BINDINGS: ClassVar[list[BindingType]] = [
    BINDINGS = [
        Binding("escape", "app.pop_screen", "Cancel"),
        Binding("y", "confirm_delete", "Yes"),
        Binding("n", "app.pop_screen", "No"),
@@ -205,9 +205,12 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
        loading.display = True

        try:
            if self.parent_screen.weaviate:
                # Delete documents
                results = await self.parent_screen.weaviate.delete_documents(
            if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage:
                # Delete documents via storage
                # The storage should have delete_documents method for weaviate
                storage = self.parent_screen.storage
                if hasattr(storage, 'delete_documents'):
                    results = await storage.delete_documents(
                        self.doc_ids,
                        collection_name=self.collection["name"],
                    )
@@ -238,7 +241,7 @@ class LogViewerScreen(ModalScreen[None]):
    _log_widget: RichLog | None
    _log_file: Path | None

    BINDINGS: ClassVar[list[BindingType]] = [
    BINDINGS = [
        Binding("escape", "close", "Close"),
        Binding("ctrl+l", "close", "Close"),
        Binding("s", "show_path", "Log File"),
@@ -264,21 +267,21 @@
    def on_mount(self) -> None:
        """Attach this viewer to the parent application once mounted."""
        self._log_widget = self.query_one(RichLog)
        from ..app import CollectionManagementApp

        if isinstance(self.app, CollectionManagementApp):
        if hasattr(self.app, 'attach_log_viewer'):
            self.app.attach_log_viewer(self)

    def on_unmount(self) -> None:
        """Detach from the parent application when closed."""
        from ..app import CollectionManagementApp

        if isinstance(self.app, CollectionManagementApp):
        if hasattr(self.app, 'detach_log_viewer'):
            self.app.detach_log_viewer(self)

    def _get_log_widget(self) -> RichLog:
        if self._log_widget is None:
            self._log_widget = self.query_one(RichLog)
        if self._log_widget is None:
            raise RuntimeError("RichLog widget not found")
        return self._log_widget

    def replace_logs(self, lines: list[str]) -> None:
@@ -10,7 +10,6 @@ from textual.widgets import Button, Footer, Header, Label, LoadingIndicator, Sta
from typing_extensions import override

from ....storage.base import BaseStorage
from ....storage.weaviate import WeaviateStorage
from ..models import CollectionInfo, DocumentInfo
from ..widgets import EnhancedDataTable

@@ -115,9 +114,9 @@ class DocumentManagementScreen(Screen[None]):
                description=str(doc.get("description", "")),
                content_type=str(doc.get("content_type", "text/plain")),
                content_preview=str(doc.get("content_preview", "")),
                word_count=int(doc.get("word_count", 0))
                if str(doc.get("word_count", 0)).isdigit()
                else 0,
                word_count=(
                    lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0
                )(doc.get("word_count", 0)),
                timestamp=str(doc.get("timestamp", "")),
            )
            for i, doc in enumerate(raw_docs)

@@ -383,12 +383,12 @@ class IngestionScreen(ModalScreen[None]):
            # Run the Prefect flow for this backend using asyncio.run with timeout
            import asyncio

            async def run_flow_with_timeout() -> IngestionResult:
            async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult:
                return await asyncio.wait_for(
                    create_ingestion_flow(
                        source_url=source_url,
                        source_type=self.selected_type,
                        storage_backend=backend,
                        storage_backend=current_backend,
                        collection_name=final_collection_name,
                        progress_callback=progress_reporter,
                    ),
@@ -403,7 +403,7 @@
                if result.error_messages:
                    flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages])

            except asyncio.TimeoutError:
            except TimeoutError:
                error_msg = f"{backend.value}: Timeout after 10 minutes"
                flow_errors.append(error_msg)
                progress_reporter(0, f"❌ {backend.value} timed out")

@@ -112,7 +112,7 @@ class SearchScreen(Screen[None]):

    async def search_collection(self, query: str) -> None:
        """Search the collection."""
        loading = self.query_one("#loading")
        loading = self.query_one("#loading", LoadingIndicator)
        table = self.query_one("#results_table", EnhancedDataTable)
        status = self.query_one("#search_status", Static)

@@ -127,7 +127,7 @@
        finally:
            loading.display = False

    def _setup_search_ui(self, loading, table, status, query: str) -> None:
    def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None:
        """Setup the search UI elements."""
        loading.display = True
        status.update(f"🔍 Searching for '{query}'...")
@@ -142,7 +142,7 @@
            return await self.search_openwebui(query)
        return []

    def _populate_results_table(self, table, results: list[dict[str, str | float]]) -> None:
    def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None:
        """Populate the results table with search results."""
        for result in results:
            row_data = self._format_result_row(result)
@@ -183,7 +183,7 @@
        content = str(content) if content is not None else ""
        return f"{content[:60]}..." if len(content) > 60 else content

    def _format_score(self, score) -> str:
    def _format_score(self, score: object) -> str:
        """Format search score for display."""
        if isinstance(score, (int, float)):
            return f"{score:.3f}"
@@ -193,7 +193,7 @@
        return str(score)

    def _update_search_status(
        self, status, query: str, results: list[dict[str, str | float]], table
        self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable
    ) -> None:
        """Update search status and notifications based on results."""
        if not results:
@@ -1106,7 +1106,7 @@ def get_css_for_theme(theme_type: ThemeType) -> str:
    return css


def apply_theme_to_app(app, theme_type: ThemeType) -> None:
def apply_theme_to_app(app: object, theme_type: ThemeType) -> None:
    """Apply a theme to a Textual app instance."""
    try:
        css = set_theme(theme_type)
@@ -1114,8 +1114,8 @@ def apply_theme_to_app(app, theme_type: ThemeType) -> None:
            app.stylesheet.clear()
            app.stylesheet.parse(css)
        elif hasattr(app, "CSS"):
            app.CSS = css
        else:
            setattr(app, "CSS", css)
        elif hasattr(app, "refresh"):
            # Fallback: try to refresh the app with new CSS
            app.refresh()
    except Exception as e:
@@ -1127,7 +1127,7 @@ def apply_theme_to_app(app, theme_type: ThemeType) -> None:
class ThemeSwitcher:
    """Helper class for managing theme switching in TUI applications."""

    def __init__(self, app=None):
    def __init__(self, app: object | None = None) -> None:
        self.app = app
        self.theme_history = [ThemeType.DARK]
@@ -8,12 +8,10 @@ from logging import Logger
from logging.handlers import QueueHandler, RotatingFileHandler
from pathlib import Path
from queue import Queue
from typing import NamedTuple, cast
from typing import NamedTuple

from ....config import configure_prefect, get_settings
from ....core.models import StorageBackend
from ....storage.openwebui import OpenWebUIStorage
from ....storage.weaviate import WeaviateStorage
from .storage_manager import StorageManager


@@ -113,11 +111,16 @@ async def run_textual_tui() -> None:
    )

    # Get individual storage instances for backward compatibility
    weaviate = cast(WeaviateStorage | None, storage_manager.get_backend(StorageBackend.WEAVIATE))
    openwebui = cast(
        OpenWebUIStorage | None, storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
    )
    r2r = storage_manager.get_backend(StorageBackend.R2R)
    from ....storage.openwebui import OpenWebUIStorage
    from ....storage.weaviate import WeaviateStorage

    weaviate_backend = storage_manager.get_backend(StorageBackend.WEAVIATE)
    openwebui_backend = storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
    r2r_backend = storage_manager.get_backend(StorageBackend.R2R)

    # Type-safe casting to specific storage types
    weaviate = weaviate_backend if isinstance(weaviate_backend, WeaviateStorage) else None
    openwebui = openwebui_backend if isinstance(openwebui_backend, OpenWebUIStorage) else None

    # Import here to avoid circular import
    from ..app import CollectionManagementApp
@@ -125,7 +128,7 @@ async def run_textual_tui() -> None:
        storage_manager,
        weaviate,
        openwebui,
        r2r,
        r2r_backend,
        log_queue=logging_context.queue,
        log_formatter=logging_context.formatter,
        log_file=logging_context.log_file,
@@ -11,12 +11,13 @@ from ....core.exceptions import StorageError
from ....core.models import Document, StorageBackend, StorageConfig
from ..models import CollectionInfo, StorageCapabilities

from ....storage.base import BaseStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.weaviate import WeaviateStorage

if TYPE_CHECKING:
    from ....config.settings import Settings
    from ....storage.weaviate import WeaviateStorage
    from ....storage.r2r.storage import R2RStorage
    from ....storage.openwebui import OpenWebUIStorage
    from ....storage.base import BaseStorage


class StorageBackendProtocol(Protocol):
@@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, cast
import json
from typing import cast

from textual.app import ComposeResult
from textual.containers import Container, Horizontal
@@ -61,9 +62,17 @@ class ScrapeOptionsForm(Container):
    }
    """

    def __init__(self, **kwargs: Any) -> None:
    def __init__(
        self,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        disabled: bool = False,
        markup: bool = True,
    ) -> None:
        """Initialize scrape options form."""
        super().__init__(**kwargs)
        super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)

    @override
    def compose(self) -> ComposeResult:
@@ -136,10 +145,8 @@
            classes="form-section",
        )

    def get_scrape_options(self) -> dict[str, Any]:
    def get_scrape_options(self) -> dict[str, object]:
        """Get scraping options from form."""
        options: dict[str, Any] = {}

        # Collect formats
        formats = []
        if self.query_one("#format_markdown", Checkbox).value:
@@ -148,11 +155,12 @@
            formats.append("html")
        if self.query_one("#format_screenshot", Checkbox).value:
            formats.append("screenshot")
        options["formats"] = formats

        # Content filtering
        options["only_main_content"] = self.query_one("#only_main_content", Switch).value

        options: dict[str, object] = {
            "formats": formats,
            "only_main_content": self.query_one(
                "#only_main_content", Switch
            ).value,
        }
        include_tags_input = self.query_one("#include_tags", Input).value
        if include_tags_input.strip():
            options["include_tags"] = [tag.strip() for tag in include_tags_input.split(",")]
@@ -171,22 +179,26 @@

        return options

    def set_scrape_options(self, options: dict[str, Any]) -> None:
    def set_scrape_options(self, options: dict[str, object]) -> None:
        """Set form values from options."""
        # Set formats
        formats = options.get("formats", ["markdown"])
        self.query_one("#format_markdown", Checkbox).value = "markdown" in formats
        self.query_one("#format_html", Checkbox).value = "html" in formats
        self.query_one("#format_screenshot", Checkbox).value = "screenshot" in formats
        formats_list = formats if isinstance(formats, list) else []
        self.query_one("#format_markdown", Checkbox).value = "markdown" in formats_list
        self.query_one("#format_html", Checkbox).value = "html" in formats_list
        self.query_one("#format_screenshot", Checkbox).value = "screenshot" in formats_list

        # Set content filtering
        self.query_one("#only_main_content", Switch).value = options.get("only_main_content", True)
        main_content_val = options.get("only_main_content", True)
        self.query_one("#only_main_content", Switch).value = bool(main_content_val)

        if include_tags := options.get("include_tags", []):
            self.query_one("#include_tags", Input).value = ", ".join(include_tags)
            include_list = include_tags if isinstance(include_tags, list) else []
            self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list)

        if exclude_tags := options.get("exclude_tags", []):
            self.query_one("#exclude_tags", Input).value = ", ".join(exclude_tags)
            exclude_list = exclude_tags if isinstance(exclude_tags, list) else []
            self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list)

        # Set performance
        wait_for = options.get("wait_for")
@@ -231,9 +243,17 @@ class MapOptionsForm(Container):
    }
    """

    def __init__(self, **kwargs: Any) -> None:
    def __init__(
        self,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        disabled: bool = False,
        markup: bool = True,
    ) -> None:
        """Initialize map options form."""
        super().__init__(**kwargs)
        super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)

    @override
    def compose(self) -> ComposeResult:
@@ -286,9 +306,9 @@
            classes="form-section",
        )

    def get_map_options(self) -> dict[str, Any]:
    def get_map_options(self) -> dict[str, object]:
        """Get mapping options from form."""
        options: dict[str, Any] = {}
        options: dict[str, object] = {}

        # Discovery settings
        search_pattern = self.query_one("#search_pattern", Input).value
@@ -314,12 +334,13 @@

        return options

    def set_map_options(self, options: dict[str, Any]) -> None:
    def set_map_options(self, options: dict[str, object]) -> None:
        """Set form values from options."""
        if search := options.get("search"):
            self.query_one("#search_pattern", Input).value = str(search)

        self.query_one("#include_subdomains", Switch).value = options.get("include_subdomains", False)
        subdomains_val = options.get("include_subdomains", False)
        self.query_one("#include_subdomains", Switch).value = bool(subdomains_val)

        # Set limits
        limit = options.get("limit")
@@ -373,9 +394,17 @@ class ExtractOptionsForm(Container):
    }
    """

    def __init__(self, **kwargs: Any) -> None:
    def __init__(
        self,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        disabled: bool = False,
        markup: bool = True,
    ) -> None:
        """Initialize extract options form."""
        super().__init__(**kwargs)
        super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)

    @override
    def compose(self) -> ComposeResult:
@@ -429,9 +458,9 @@
            classes="form-section",
        )

    def get_extract_options(self) -> dict[str, Any]:
    def get_extract_options(self) -> dict[str, object]:
        """Get extraction options from form."""
        options: dict[str, Any] = {}
        options: dict[str, object] = {}

        # Extract prompt
        prompt = self.query_one("#extract_prompt", TextArea).text
@@ -442,8 +471,6 @@
        schema_text = self.query_one("#extract_schema", TextArea).text
        if schema_text.strip():
            try:
                import json

                schema = json.loads(schema_text)
                options["extract_schema"] = schema
            except json.JSONDecodeError:
@@ -452,7 +479,7 @@

        return options

    def set_extract_options(self, options: dict[str, Any]) -> None:
    def set_extract_options(self, options: dict[str, object]) -> None:
        """Set form values from options."""
        if prompt := options.get("extract_prompt"):
            self.query_one("#extract_prompt", TextArea).text = str(prompt)
@@ -551,9 +578,17 @@ class FirecrawlConfigWidget(Container):
    }
    """

    def __init__(self, **kwargs: Any) -> None:
    def __init__(
        self,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        disabled: bool = False,
        markup: bool = True,
    ) -> None:
        """Initialize Firecrawl config widget."""
        super().__init__(**kwargs)
        super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
        self.current_tab = "scrape"

    @override
@@ -604,7 +639,7 @@

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button presses."""
        if event.button.id.startswith("tab_"):
        if event.button.id and event.button.id.startswith("tab_"):
            tab_name = event.button.id[4:]  # Remove "tab_" prefix
            self.show_tab(tab_name)

@@ -622,15 +657,15 @@
                pass
        elif self.current_tab == "map":
            try:
                form = self.query_one("#map_form", MapOptionsForm)
                map_opts = form.get_map_options()
                map_form = self.query_one("#map_form", MapOptionsForm)
                map_opts = map_form.get_map_options()
                options.update(cast(FirecrawlOptions, map_opts))
            except Exception:
                pass
        elif self.current_tab == "extract":
            try:
                form = self.query_one("#extract_form", ExtractOptionsForm)
                extract_opts = form.get_extract_options()
                extract_form = self.query_one("#extract_form", ExtractOptionsForm)
                extract_opts = extract_form.get_extract_options()
                options.update(cast(FirecrawlOptions, extract_opts))
            except Exception:
                pass
@@ -98,9 +98,11 @@ class ChunkViewer(Container):
                "id": str(chunk_data.get("id", "")),
                "document_id": self.document_id,
                "content": str(chunk_data.get("text", "")),
                "start_index": int(chunk_data.get("start_index", 0)),
                "end_index": int(chunk_data.get("end_index", 0)),
                "metadata": dict(chunk_data.get("metadata", {})),
                "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)),
                "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)),
                "metadata": (
                    dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {}
                ),
            }
            self.chunks.append(chunk_info)

@@ -217,6 +219,8 @@ class EntityGraph(Container):

        # Parse entities from R2R response
        entities_list = entities_data.get("entities", [])
        if not isinstance(entities_list, list):
            entities_list = []
        for entity_data in entities_list:
            entity_info: EntityInfo = {
                "id": str(entity_data.get("id", "")),
@@ -260,7 +264,7 @@

        tree.root.expand()

    def on_tree_node_selected(self, event: Tree.NodeSelected) -> None:
    def on_tree_node_selected(self, event: Tree.NodeSelected[EntityInfo]) -> None:
        """Handle entity selection."""
        if hasattr(event.node, "data") and event.node.data:
            entity = event.node.data
@@ -488,7 +492,8 @@ class DocumentOverview(Container):
        doc_table = self.query_one("#doc_info_table", DataTable)
        doc_table.add_columns("Property", "Value")

        document_info = overview_data.get("document", {})
        document_info_raw = overview_data.get("document", {})
        document_info = document_info_raw if isinstance(document_info_raw, dict) else {}
        doc_table.add_row("ID", str(document_info.get("id", "N/A")))
        doc_table.add_row("Title", str(document_info.get("title", "N/A")))
        doc_table.add_row("Created", str(document_info.get("created_at", "N/A")))
@@ -4,15 +4,43 @@ from __future__ import annotations

from contextlib import ExitStack

from prefect.settings import (
    PREFECT_API_KEY,
    PREFECT_API_URL,
    PREFECT_DEFAULT_WORK_POOL_NAME,
    Setting,
    temporary_settings,
)
from prefect.settings import Setting, temporary_settings

from .settings import Settings, get_settings

# Import Prefect settings with version compatibility - avoid static analysis issues
def _setup_prefect_settings() -> tuple[object, object, object]:
    """Setup Prefect settings with proper fallbacks."""
    try:
        import prefect.settings as ps

        # Try to get the settings directly
        api_key = getattr(ps, "PREFECT_API_KEY", None)
        api_url = getattr(ps, "PREFECT_API_URL", None)
        work_pool = getattr(ps, "PREFECT_DEFAULT_WORK_POOL_NAME", None)

        if api_key is not None:
            return api_key, api_url, work_pool

        # Fallback to registry-based approach
        registry = getattr(ps, "PREFECT_SETTING_REGISTRY", None)
        if registry is not None:
            Setting = getattr(ps, "Setting", None)
            if Setting is not None:
                api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None)
                api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None)
                work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None)
                return api_key, api_url, work_pool

    except ImportError:
        pass

    # Ultimate fallback
    return None, None, None

PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()

# Import after Prefect settings setup to avoid circular dependencies
from .settings import Settings, get_settings  # noqa: E402

__all__ = ["Settings", "get_settings", "configure_prefect"]

@@ -25,11 +53,11 @@ def configure_prefect(settings: Settings) -> None:

    overrides: dict[Setting, str] = {}

    if settings.prefect_api_url is not None:
    if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting):
        overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
    if settings.prefect_api_key:
    if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting):
        overrides[PREFECT_API_KEY] = settings.prefect_api_key
    if settings.prefect_work_pool:
    if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting):
        overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool

    if not overrides:
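A minimal call-pattern sketch for the module above (editor's illustration, assumed from the names exported in __all__, not shown in this diff):

# Illustrative only: apply PREFECT_* overrides derived from application settings.
from ingest_pipeline.config import configure_prefect, get_settings

settings = get_settings()
configure_prefect(settings)  # presumably a no-op when no overrides are configured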
@@ -3,6 +3,7 @@
from functools import lru_cache
from typing import Annotated, Literal

from prefect.variables import Variable
from pydantic import Field, HttpUrl, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -147,3 +148,77 @@ def get_settings() -> Settings:
|
||||
Settings instance
|
||||
"""
|
||||
return Settings()
|
||||
|
||||
|
||||
class PrefectVariableConfig:
|
||||
"""Helper class for managing Prefect variables with fallbacks to settings."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._settings = get_settings()
|
||||
self._variable_names = [
|
||||
"default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages",
|
||||
"default_storage_backend", "default_collection_prefix", "max_concurrent_tasks",
|
||||
"request_timeout", "default_schedule_interval"
|
||||
]
|
||||
|
||||
def _get_fallback_value(self, name: str, default_value: object = None) -> object:
|
||||
"""Get fallback value from settings or default."""
|
||||
return default_value or getattr(self._settings, name, default_value)
|
||||
|
||||
def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
|
||||
"""Get variable value with fallback synchronously."""
|
||||
fallback = self._get_fallback_value(name, default_value)
|
||||
# Ensure fallback is a type that Variable expects
|
||||
variable_fallback = str(fallback) if fallback is not None else None
|
||||
try:
|
||||
result = Variable.get(name, default=variable_fallback)
|
||||
# Variable can return various types, convert to our expected types
|
||||
if isinstance(result, (str, int, float)):
|
||||
return result
|
||||
elif result is None:
|
||||
return None
|
||||
else:
|
||||
# Convert other types to string
|
||||
return str(result)
|
||||
except Exception:
|
||||
# Return fallback with proper type
|
||||
if isinstance(fallback, (str, int, float)) or fallback is None:
|
||||
return fallback
|
||||
return str(fallback) if fallback is not None else None
|
||||
|
||||
async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
|
||||
"""Get variable value with fallback asynchronously."""
|
||||
fallback = self._get_fallback_value(name, default_value)
|
||||
variable_fallback = str(fallback) if fallback is not None else None
|
||||
try:
|
||||
result = await Variable.aget(name, default=variable_fallback)
|
||||
# Variable can return various types, convert to our expected types
|
||||
if isinstance(result, (str, int, float)):
|
||||
return result
|
||||
elif result is None:
|
||||
return None
|
||||
else:
|
||||
# Convert other types to string
|
||||
return str(result)
|
||||
except Exception:
|
||||
# Return fallback with proper type
|
||||
if isinstance(fallback, (str, int, float)) or fallback is None:
|
||||
return fallback
|
||||
return str(fallback) if fallback is not None else None
|
||||
|
||||
def get_ingestion_config(self) -> dict[str, str | int | float | None]:
|
||||
"""Get all ingestion-related configuration variables synchronously."""
|
||||
return {name: self.get_with_fallback(name) for name in self._variable_names}
|
||||
|
||||
async def get_ingestion_config_async(self) -> dict[str, str | int | float | None]:
|
||||
"""Get all ingestion-related configuration variables asynchronously."""
|
||||
result = {}
|
||||
for name in self._variable_names:
|
||||
result[name] = await self.get_with_fallback_async(name)
|
||||
return result
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_prefect_config() -> PrefectVariableConfig:
|
||||
"""Get cached Prefect variable configuration helper."""
|
||||
return PrefectVariableConfig()
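A minimal usage sketch of this helper; the import path and the chosen variable name are assumptions rather than something this commit pins down:

from ingest_pipeline.config.settings import get_prefect_config  # assumed import path

prefect_config = get_prefect_config()
batch_size = prefect_config.get_with_fallback("default_batch_size", default_value=50)
ingestion_config = prefect_config.get_ingestion_config()  # dict keyed by the variable names above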
|
||||
|
||||
Binary file not shown.
@@ -5,7 +5,8 @@ from enum import Enum
|
||||
from typing import Annotated, TypedDict
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
from prefect.blocks.core import Block
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
|
||||
class IngestionStatus(str, Enum):
|
||||
@@ -44,19 +45,27 @@ class VectorConfig(BaseModel):
|
||||
batch_size: Annotated[int, Field(gt=0, le=1000)] = 100
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
|
||||
class StorageConfig(Block):
|
||||
"""Configuration for storage backend."""
|
||||
|
||||
_block_type_name = "Storage Configuration"
|
||||
_block_type_slug = "storage-config"
|
||||
_description = "Configures storage backend connections and settings for document ingestion"
|
||||
|
||||
backend: StorageBackend
|
||||
endpoint: HttpUrl
|
||||
api_key: str | None = Field(default=None)
|
||||
api_key: SecretStr | None = Field(default=None)
|
||||
collection_name: str = Field(default="documents")
|
||||
batch_size: Annotated[int, Field(gt=0, le=1000)] = 100
|
||||
|
||||
|
||||
class FirecrawlConfig(BaseModel):
|
||||
class FirecrawlConfig(Block):
|
||||
"""Configuration for Firecrawl ingestion (operational parameters only)."""
|
||||
|
||||
_block_type_name = "Firecrawl Configuration"
|
||||
_block_type_slug = "firecrawl-config"
|
||||
_description = "Configures Firecrawl web scraping and crawling parameters"
|
||||
|
||||
formats: list[str] = Field(default_factory=lambda: ["markdown", "html"])
|
||||
max_depth: Annotated[int, Field(ge=1, le=20)] = 5
|
||||
limit: Annotated[int, Field(ge=1, le=1000)] = 100
|
||||
@@ -64,9 +73,13 @@ class FirecrawlConfig(BaseModel):
|
||||
include_subdomains: bool = Field(default=False)
|
||||
|
||||
|
||||
class RepomixConfig(BaseModel):
|
||||
class RepomixConfig(Block):
|
||||
"""Configuration for Repomix ingestion."""
|
||||
|
||||
_block_type_name = "Repomix Configuration"
|
||||
_block_type_slug = "repomix-config"
|
||||
_description = "Configures repository ingestion patterns and file processing settings"
|
||||
|
||||
include_patterns: list[str] = Field(
|
||||
default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"]
|
||||
)
|
||||
@@ -77,9 +90,13 @@ class RepomixConfig(BaseModel):
|
||||
respect_gitignore: bool = Field(default=True)
|
||||
|
||||
|
||||
class R2RConfig(BaseModel):
|
||||
class R2RConfig(Block):
|
||||
"""Configuration for R2R ingestion."""
|
||||
|
||||
_block_type_name = "R2R Configuration"
|
||||
_block_type_slug = "r2r-config"
|
||||
_description = "Configures R2R-specific ingestion settings including chunking and graph enrichment"
|
||||
|
||||
chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000
|
||||
chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200
|
||||
enable_graph_enrichment: bool = Field(default=False)
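Because these configs are now Prefect Blocks, they can be registered once and then referenced by name from the tasks that accept a block name. A hedged sketch, with placeholder endpoint and block name:

from ingest_pipeline.core.models import StorageBackend, StorageConfig  # assumed import path

# One-time setup; Block.save / Block.aload are standard Prefect Block APIs,
# but the endpoint and the "prod-r2r" name here are illustrative only.
block = StorageConfig(
    backend=StorageBackend.R2R,
    endpoint="http://localhost:7272",  # placeholder endpoint
    collection_name="documents",
)
block.save("prod-r2r", overwrite=True)  # later loadable as "storage-config/prod-r2r"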
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -3,12 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Literal, assert_never, cast
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Literal, TypeAlias, assert_never, cast
|
||||
|
||||
from prefect import flow, get_run_logger, task
|
||||
from prefect.cache_policies import NO_CACHE
|
||||
from prefect.futures import wait
|
||||
from prefect.blocks.core import Block
|
||||
from prefect.variables import Variable
|
||||
from pydantic.types import SecretStr
|
||||
|
||||
from ..config.settings import Settings
|
||||
from ..core.exceptions import IngestionError
|
||||
@@ -31,8 +32,14 @@ from ..utils.metadata_tagger import MetadataTagger
|
||||
|
||||
SourceTypeLiteral = Literal["web", "repository", "documentation"]
|
||||
StorageBackendLiteral = Literal["weaviate", "open_webui", "r2r"]
|
||||
SourceTypeLike = IngestionSource | SourceTypeLiteral
|
||||
StorageBackendLike = StorageBackend | StorageBackendLiteral
|
||||
SourceTypeLike: TypeAlias = IngestionSource | SourceTypeLiteral
|
||||
StorageBackendLike: TypeAlias = StorageBackend | StorageBackendLiteral
|
||||
|
||||
|
||||
def _safe_cache_key(prefix: str, params: dict[str, object], key: str) -> str:
|
||||
"""Create a type-safe cache key from task parameters."""
|
||||
value = params.get(key, "")
|
||||
return f"{prefix}_{hash(str(value))}"
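One caveat worth noting: Python's built-in hash() is salted per process, so keys produced by _safe_cache_key are not stable across runs. A deterministic variant (a sketch, not part of this change) could hash with hashlib instead:

import hashlib

def _stable_cache_key(prefix: str, params: dict[str, object], key: str) -> str:
    # sha256 of the stringified parameter gives the same key in every process.
    value = str(params.get(key, ""))
    digest = hashlib.sha256(value.encode("utf-8")).hexdigest()[:16]
    return f"{prefix}_{digest}"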
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -65,16 +72,22 @@ async def validate_source_task(source_url: str, source_type: IngestionSource) ->
|
||||
|
||||
|
||||
@task(name="initialize_storage", retries=3, retry_delay_seconds=5, tags=["storage"])
|
||||
async def initialize_storage_task(config: StorageConfig) -> BaseStorage:
|
||||
async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
|
||||
"""
|
||||
Initialize storage backend.
|
||||
|
||||
Args:
|
||||
config: Storage configuration
|
||||
config: Storage configuration block or block name
|
||||
|
||||
Returns:
|
||||
Initialized storage adapter
|
||||
"""
|
||||
# Load block if string provided
|
||||
if isinstance(config, str):
|
||||
# Use Block.aload with type slug for better type inference
|
||||
loaded_block = await Block.aload(f"storage-config/{config}")
|
||||
config = cast(StorageConfig, loaded_block)
|
||||
|
||||
if config.backend == StorageBackend.WEAVIATE:
|
||||
storage = WeaviateStorage(config)
|
||||
elif config.backend == StorageBackend.OPEN_WEBUI:
|
||||
@@ -90,38 +103,48 @@ async def initialize_storage_task(config: StorageConfig) -> BaseStorage:
|
||||
return storage
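A hypothetical call site for the widened signature; "prod-r2r" stands for a previously saved storage-config block:

# Inside a flow: either pass a StorageConfig object directly, or the name of a
# saved "storage-config/..." block (the block name here is illustrative).
storage = await initialize_storage_task("prod-r2r")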
|
||||
|
||||
|
||||
@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"])
|
||||
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig) -> list[str]:
|
||||
@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"],
|
||||
cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"))
|
||||
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
|
||||
"""Map a site using Firecrawl and return discovered URLs."""
|
||||
# Load block if string provided
|
||||
if isinstance(config, str):
|
||||
# Use Block.aload with type slug for better type inference
|
||||
loaded_block = await Block.aload(f"firecrawl-config/{config}")
|
||||
config = cast(FirecrawlConfig, loaded_block)
|
||||
|
||||
ingestor = FirecrawlIngestor(config)
|
||||
mapped = await ingestor.map_site(source_url)
|
||||
return mapped or [source_url]
|
||||
|
||||
|
||||
@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["r2r", "dedup"], cache_policy=NO_CACHE)
|
||||
@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"],
|
||||
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list
|
||||
async def filter_existing_documents_task(
|
||||
urls: list[str],
|
||||
storage_client: R2RStorageType,
|
||||
storage_client: BaseStorage,
|
||||
stale_after_days: int = 30,
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Filter URLs whose documents are missing or stale in R2R."""
|
||||
"""Filter URLs to only those that need scraping (missing or stale in storage)."""
|
||||
logger = get_run_logger()
|
||||
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
|
||||
eligible: list[str] = []
|
||||
|
||||
for url in urls:
|
||||
document_id = str(FirecrawlIngestor.compute_document_id(url))
|
||||
existing: Document | None = await storage_client.retrieve(document_id)
|
||||
if existing is None:
|
||||
eligible.append(url)
|
||||
continue
|
||||
exists = await storage_client.check_exists(
|
||||
document_id,
|
||||
collection_name=collection_name,
|
||||
stale_after_days=stale_after_days
|
||||
)
|
||||
|
||||
timestamp = existing.metadata["timestamp"]
|
||||
if timestamp < cutoff:
|
||||
if not exists:
|
||||
eligible.append(url)
|
||||
|
||||
if skipped := len(urls) - len(eligible):
|
||||
logger.info("Skipping %s up-to-date pages", skipped)
|
||||
skipped = len(urls) - len(eligible)
|
||||
if skipped > 0:
|
||||
logger.info("Skipping %s up-to-date documents in %s", skipped, storage_client.display_name)
|
||||
|
||||
return eligible
|
||||
|
||||
@@ -134,7 +157,8 @@ async def scrape_firecrawl_batch_task(
|
||||
) -> list[FirecrawlPage]:
|
||||
"""Scrape a batch of URLs via Firecrawl."""
|
||||
ingestor = FirecrawlIngestor(config)
|
||||
return await ingestor.scrape_pages(batch_urls)
|
||||
result: list[FirecrawlPage] = await ingestor.scrape_pages(batch_urls)
|
||||
return result
|
||||
|
||||
|
||||
@task(name="annotate_firecrawl_metadata", retries=1, retry_delay_seconds=10, tags=["metadata"])
|
||||
@@ -153,7 +177,8 @@ async def annotate_firecrawl_metadata_task(
|
||||
|
||||
settings = get_settings()
|
||||
async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
|
||||
return await tagger.tag_batch(documents)
|
||||
tagged_documents: list[Document] = await tagger.tag_batch(documents)
|
||||
return tagged_documents
|
||||
except IngestionError as exc: # pragma: no cover - logging side effect
|
||||
logger = get_run_logger()
|
||||
logger.warning("Metadata tagging failed: %s", exc)
|
||||
@@ -164,7 +189,7 @@ async def annotate_firecrawl_metadata_task(
|
||||
return documents
|
||||
|
||||
|
||||
@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"], cache_policy=NO_CACHE)
|
||||
@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"])
|
||||
async def upsert_r2r_documents_task(
|
||||
storage_client: R2RStorageType,
|
||||
documents: list[Document],
|
||||
@@ -187,12 +212,14 @@ async def upsert_r2r_documents_task(
|
||||
return processed, failed
|
||||
|
||||
|
||||
@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"], cache_policy=NO_CACHE)
|
||||
@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"])
|
||||
async def ingest_documents_task(
|
||||
job: IngestionJob,
|
||||
collection_name: str | None = None,
|
||||
batch_size: int = 50,
|
||||
batch_size: int | None = None,
|
||||
storage_client: BaseStorage | None = None,
|
||||
storage_block_name: str | None = None,
|
||||
ingestor_config_block_name: str | None = None,
|
||||
progress_callback: Callable[[int, str], None] | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
@@ -201,8 +228,10 @@ async def ingest_documents_task(
|
||||
Args:
|
||||
job: Ingestion job configuration
|
||||
collection_name: Target collection name
|
||||
batch_size: Number of documents per batch
|
||||
batch_size: Number of documents per batch (uses Variable if None)
|
||||
storage_client: Optional pre-initialized storage client
|
||||
storage_block_name: Optional storage block name to load
|
||||
ingestor_config_block_name: Optional ingestor config block name to load
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
@@ -211,8 +240,22 @@ async def ingest_documents_task(
|
||||
if progress_callback:
|
||||
progress_callback(35, "Creating ingestor and storage clients...")
|
||||
|
||||
ingestor = _create_ingestor(job)
|
||||
storage = storage_client or await _create_storage(job, collection_name)
|
||||
# Use Variable for batch size if not provided
|
||||
if batch_size is None:
|
||||
try:
|
||||
batch_size_var = await Variable.aget("default_batch_size", default="50")
|
||||
# Convert Variable result to int, handling various types
|
||||
if isinstance(batch_size_var, int):
|
||||
batch_size = batch_size_var
|
||||
elif isinstance(batch_size_var, (str, float)):
|
||||
batch_size = int(float(str(batch_size_var)))
|
||||
else:
|
||||
batch_size = 50
|
||||
except Exception:
|
||||
batch_size = 50
|
||||
|
||||
ingestor = await _create_ingestor(job, ingestor_config_block_name)
|
||||
storage = storage_client or await _create_storage(job, collection_name, storage_block_name)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(40, "Starting document processing...")
|
||||
@@ -220,30 +263,50 @@ async def ingest_documents_task(
|
||||
return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback)
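The Variable fallbacks above assume values such as default_batch_size exist on the Prefect server. A hedged sketch of seeding them (the names match this diff; the values and the exact Variable.set signature are assumptions):

from prefect.variables import Variable

# Values are placeholders; Variable.set is the standard Prefect API, though its
# exact signature varies slightly between Prefect releases.
Variable.set("default_batch_size", "50", overwrite=True)
Variable.set("default_collection_prefix", "docs", overwrite=True)
Variable.set("default_schedule_interval", "60", overwrite=True)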
|
||||
|
||||
|
||||
def _create_ingestor(job: IngestionJob) -> BaseIngestor:
|
||||
async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
|
||||
"""Create appropriate ingestor based on job source type."""
|
||||
if job.source_type == IngestionSource.WEB:
|
||||
config = FirecrawlConfig()
|
||||
if config_block_name:
|
||||
# Use Block.aload with type slug for better type inference
|
||||
loaded_block = await Block.aload(f"firecrawl-config/{config_block_name}")
|
||||
config = cast(FirecrawlConfig, loaded_block)
|
||||
else:
|
||||
# Fallback to default configuration
|
||||
config = FirecrawlConfig()
|
||||
return FirecrawlIngestor(config)
|
||||
elif job.source_type == IngestionSource.REPOSITORY:
|
||||
config = RepomixConfig()
|
||||
if config_block_name:
|
||||
# Use Block.aload with type slug for better type inference
|
||||
loaded_block = await Block.aload(f"repomix-config/{config_block_name}")
|
||||
config = cast(RepomixConfig, loaded_block)
|
||||
else:
|
||||
# Fallback to default configuration
|
||||
config = RepomixConfig()
|
||||
return RepomixIngestor(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {job.source_type}")
|
||||
|
||||
|
||||
async def _create_storage(job: IngestionJob, collection_name: str | None) -> BaseStorage:
|
||||
async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage:
|
||||
"""Create and initialize storage client."""
|
||||
if collection_name is None:
|
||||
collection_name = f"docs_{job.source_type.value}"
|
||||
# Use variable for default collection prefix
|
||||
prefix = await Variable.aget("default_collection_prefix", default="docs")
|
||||
collection_name = f"{prefix}_{job.source_type.value}"
|
||||
|
||||
from ..config import get_settings
|
||||
if storage_block_name:
|
||||
# Load storage config from block
|
||||
loaded_block = await Block.aload(f"storage-config/{storage_block_name}")
|
||||
storage_config = cast(StorageConfig, loaded_block)
|
||||
# Override collection name if provided
|
||||
storage_config.collection_name = collection_name
|
||||
else:
|
||||
# Fallback to building config from settings
|
||||
from ..config import get_settings
|
||||
settings = get_settings()
|
||||
storage_config = _build_storage_config(job, settings, collection_name)
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
storage_config = _build_storage_config(job, settings, collection_name)
|
||||
storage = _instantiate_storage(job.storage_backend, storage_config)
|
||||
|
||||
await storage.initialize()
|
||||
return storage
|
||||
|
||||
@@ -257,16 +320,19 @@ def _build_storage_config(
|
||||
StorageBackend.OPEN_WEBUI: settings.openwebui_endpoint,
|
||||
StorageBackend.R2R: settings.get_storage_endpoint("r2r"),
|
||||
}
|
||||
storage_api_keys = {
|
||||
storage_api_keys: dict[StorageBackend, str | None] = {
|
||||
StorageBackend.WEAVIATE: settings.get_api_key("weaviate"),
|
||||
StorageBackend.OPEN_WEBUI: settings.get_api_key("openwebui"),
|
||||
StorageBackend.R2R: None, # R2R is self-hosted, no API key needed
|
||||
}
|
||||
|
||||
api_key_raw: str | None = storage_api_keys[job.storage_backend]
|
||||
api_key: SecretStr | None = SecretStr(api_key_raw) if api_key_raw is not None else None
|
||||
|
||||
return StorageConfig(
|
||||
backend=job.storage_backend,
|
||||
endpoint=storage_endpoints[job.storage_backend],
|
||||
api_key=storage_api_keys[job.storage_backend],
|
||||
api_key=api_key,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@@ -324,7 +390,20 @@ async def _process_documents(
|
||||
if progress_callback:
|
||||
progress_callback(45, "Ingesting documents from source...")
|
||||
|
||||
async for document in ingestor.ingest(job):
|
||||
# Use smart ingestion with deduplication if storage supports it
|
||||
if hasattr(storage, 'check_exists'):
|
||||
try:
|
||||
# Try to use the smart ingestion method
|
||||
document_generator = ingestor.ingest_with_dedup(
|
||||
job, storage, collection_name=collection_name
|
||||
)
|
||||
except Exception:
|
||||
# Fall back to regular ingestion if smart method fails
|
||||
document_generator = ingestor.ingest(job)
|
||||
else:
|
||||
document_generator = ingestor.ingest(job)
|
||||
|
||||
async for document in document_generator:
|
||||
batch.append(document)
|
||||
total_documents += 1
|
||||
|
||||
@@ -425,16 +504,35 @@ async def firecrawl_to_r2r_flow(
|
||||
r2r_storage = cast("R2RStorageType", storage_client)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(45, "Discovering pages with Firecrawl...")
|
||||
progress_callback(45, "Checking for existing content before mapping...")
|
||||
|
||||
discovered_urls = await map_firecrawl_site_task(str(job.source_url), firecrawl_config)
|
||||
# Smart mapping: try single URL first to avoid expensive map operation
|
||||
base_url = str(job.source_url)
|
||||
single_url_id = str(FirecrawlIngestor.compute_document_id(base_url))
|
||||
base_exists = await r2r_storage.check_exists(
|
||||
single_url_id, collection_name=resolved_collection, stale_after_days=30
|
||||
)
|
||||
|
||||
if base_exists:
|
||||
# Check if this is a recent single-page update
|
||||
logger.info("Base URL %s exists and is fresh, skipping expensive mapping", base_url)
|
||||
if progress_callback:
|
||||
progress_callback(100, "Content is up to date, no processing needed")
|
||||
return 0, 0
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(50, "Discovering pages with Firecrawl...")
|
||||
|
||||
discovered_urls = await map_firecrawl_site_task(base_url, firecrawl_config)
|
||||
unique_urls = _deduplicate_urls(discovered_urls)
|
||||
logger.info("Discovered %s unique URLs from Firecrawl map", len(unique_urls))
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(55, f"Found {len(unique_urls)} pages, filtering existing content...")
|
||||
progress_callback(60, f"Found {len(unique_urls)} pages, filtering existing content...")
|
||||
|
||||
eligible_urls = await filter_existing_documents_task(unique_urls, r2r_storage)
|
||||
eligible_urls = await filter_existing_documents_task(
|
||||
unique_urls, r2r_storage, collection_name=resolved_collection
|
||||
)
|
||||
|
||||
if not eligible_urls:
|
||||
logger.info("All Firecrawl pages are up to date for %s", job.source_url)
|
||||
@@ -443,7 +541,7 @@ async def firecrawl_to_r2r_flow(
|
||||
return 0, 0
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(65, f"Scraping {len(eligible_urls)} new/updated pages...")
|
||||
progress_callback(70, f"Scraping {len(eligible_urls)} new/updated pages...")
|
||||
|
||||
batch_size = min(settings.default_batch_size, firecrawl_config.limit)
|
||||
url_batches = _chunk_urls(eligible_urls, batch_size)
|
||||
@@ -462,7 +560,7 @@ async def firecrawl_to_r2r_flow(
|
||||
scraped_pages.extend(batch_pages)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(75, f"Processing {len(scraped_pages)} scraped pages...")
|
||||
progress_callback(80, f"Processing {len(scraped_pages)} scraped pages...")
|
||||
|
||||
documents = await annotate_firecrawl_metadata_task(scraped_pages, job)
|
||||
|
||||
@@ -471,7 +569,7 @@ async def firecrawl_to_r2r_flow(
|
||||
return 0, len(eligible_urls)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(85, f"Storing {len(documents)} documents in R2R...")
|
||||
progress_callback(90, f"Storing {len(documents)} documents in R2R...")
|
||||
|
||||
processed, failed = await upsert_r2r_documents_task(r2r_storage, documents, resolved_collection)
|
||||
|
||||
@@ -481,7 +579,7 @@ async def firecrawl_to_r2r_flow(
|
||||
|
||||
|
||||
@task(name="update_job_status", tags=["tracking"])
|
||||
def update_job_status_task(
|
||||
async def update_job_status_task(
|
||||
job: IngestionJob,
|
||||
status: IngestionStatus,
|
||||
processed: int = 0,
|
||||
@@ -575,7 +673,7 @@ async def create_ingestion_flow(
|
||||
# Update status to in progress
|
||||
if progress_callback:
|
||||
progress_callback(20, "Initializing storage...")
|
||||
job = update_job_status_task(job, IngestionStatus.IN_PROGRESS)
|
||||
job = await update_job_status_task(job, IngestionStatus.IN_PROGRESS)
|
||||
|
||||
# Run ingestion
|
||||
if progress_callback:
|
||||
@@ -601,7 +699,7 @@ async def create_ingestion_flow(
|
||||
else:
|
||||
final_status = IngestionStatus.COMPLETED
|
||||
|
||||
job = update_job_status_task(job, final_status, processed=processed, _failed=failed)
|
||||
job = await update_job_status_task(job, final_status, processed=processed, _failed=failed)
|
||||
|
||||
print(f"Ingestion completed: {processed} processed, {failed} failed")
|
||||
|
||||
@@ -610,7 +708,7 @@ async def create_ingestion_flow(
|
||||
error_messages.append(str(e))
|
||||
|
||||
# Don't reset counts - keep whatever was processed before the error
|
||||
job = update_job_status_task(
|
||||
job = await update_job_status_task(
|
||||
job, IngestionStatus.FAILED, processed=processed, _failed=failed, error=str(e)
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Literal, Protocol, cast
|
||||
from prefect import serve
|
||||
from prefect.deployments.runner import RunnerDeployment
|
||||
from prefect.schedules import Cron, Interval
|
||||
from prefect.variables import Variable
|
||||
|
||||
from ..core.models import IngestionSource, StorageBackend
|
||||
from .ingestion import SourceTypeLike, StorageBackendLike, create_ingestion_flow
|
||||
@@ -30,11 +31,13 @@ def create_scheduled_deployment(
|
||||
storage_backend: StorageBackendLike = StorageBackend.WEAVIATE,
|
||||
schedule_type: Literal["cron", "interval"] = "interval",
|
||||
cron_expression: str | None = None,
|
||||
interval_minutes: int = 60,
|
||||
interval_minutes: int | None = None,
|
||||
tags: list[str] | None = None,
|
||||
storage_block_name: str | None = None,
|
||||
ingestor_config_block_name: str | None = None,
|
||||
) -> RunnerDeployment:
|
||||
"""
|
||||
Create a scheduled deployment for ingestion.
|
||||
Create a scheduled deployment for ingestion with block support.
|
||||
|
||||
Args:
|
||||
name: Deployment name
|
||||
@@ -43,12 +46,28 @@ def create_scheduled_deployment(
|
||||
storage_backend: Storage backend
|
||||
schedule_type: Type of schedule
|
||||
cron_expression: Cron expression if using cron
|
||||
interval_minutes: Interval in minutes if using interval
|
||||
interval_minutes: Interval in minutes (uses Variable if None)
|
||||
tags: Optional tags for deployment
|
||||
storage_block_name: Optional storage block name
|
||||
ingestor_config_block_name: Optional ingestor config block name
|
||||
|
||||
Returns:
|
||||
Deployment configuration
|
||||
"""
|
||||
# Use Variable for interval if not provided
|
||||
if interval_minutes is None:
|
||||
try:
|
||||
interval_var = Variable.get("default_schedule_interval", default="60")
|
||||
# Convert Variable result to int, handling various types
|
||||
if isinstance(interval_var, int):
|
||||
interval_minutes = interval_var
|
||||
elif isinstance(interval_var, (str, float)):
|
||||
interval_minutes = int(float(str(interval_var)))
|
||||
else:
|
||||
interval_minutes = 60
|
||||
except Exception:
|
||||
interval_minutes = 60
|
||||
|
||||
# Create schedule
|
||||
if schedule_type == "cron" and cron_expression:
|
||||
schedule = Cron(cron_expression, timezone="UTC")
|
||||
@@ -62,18 +81,27 @@ def create_scheduled_deployment(
|
||||
if tags is None:
|
||||
tags = [source_enum.value, backend_enum.value]
|
||||
|
||||
# Create deployment parameters with block support
|
||||
parameters = {
|
||||
"source_url": source_url,
|
||||
"source_type": source_enum.value,
|
||||
"storage_backend": backend_enum.value,
|
||||
"validate_first": True,
|
||||
}
|
||||
|
||||
# Add block names if provided
|
||||
if storage_block_name:
|
||||
parameters["storage_block_name"] = storage_block_name
|
||||
if ingestor_config_block_name:
|
||||
parameters["ingestor_config_block_name"] = ingestor_config_block_name
|
||||
|
||||
# Create deployment
|
||||
# The flow decorator adds the to_deployment method at runtime
|
||||
to_deployment = create_ingestion_flow.to_deployment
|
||||
deployment = to_deployment(
|
||||
name=name,
|
||||
schedule=schedule,
|
||||
parameters={
|
||||
"source_url": source_url,
|
||||
"source_type": source_enum.value,
|
||||
"storage_backend": backend_enum.value,
|
||||
"validate_first": True,
|
||||
},
|
||||
parameters=parameters,
|
||||
tags=tags,
|
||||
description=f"Scheduled ingestion from {source_url}",
|
||||
)
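A hypothetical registration using the new block-aware parameters; the leading arguments (name, source_url, source_type) sit above this hunk, so their exact names are assumed here:

deployment = create_scheduled_deployment(
    name="docs-site-nightly",
    source_url="https://docs.example.com",
    source_type="web",
    storage_backend="r2r",
    schedule_type="interval",  # interval_minutes left unset, so the default_schedule_interval Variable applies
    storage_block_name="prod-r2r",
    ingestor_config_block_name="default-firecrawl",
)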
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,10 +1,16 @@
|
||||
"""Base ingestor interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..core.models import Document, IngestionJob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..storage.base import BaseStorage
|
||||
|
||||
|
||||
class BaseIngestor(ABC):
|
||||
"""Abstract base class for all ingestors."""
|
||||
@@ -47,3 +53,30 @@ class BaseIngestor(ABC):
|
||||
Estimated number of documents
|
||||
"""
|
||||
pass # pragma: no cover
|
||||
|
||||
async def ingest_with_dedup(
|
||||
self,
|
||||
job: IngestionJob,
|
||||
storage_client: BaseStorage,
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
stale_after_days: int = 30,
|
||||
) -> AsyncGenerator[Document, None]:
|
||||
"""
|
||||
Ingest documents with duplicate detection (optional optimization).
|
||||
|
||||
Default implementation falls back to regular ingestion.
|
||||
Subclasses can override to provide optimized deduplication.
|
||||
|
||||
Args:
|
||||
job: The ingestion job configuration
|
||||
storage_client: Storage client to check for existing documents
|
||||
collection_name: Collection to check for duplicates
|
||||
stale_after_days: Consider documents stale after this many days
|
||||
|
||||
Yields:
|
||||
Documents from the source (with deduplication if implemented)
|
||||
"""
|
||||
# Default implementation: fall back to regular ingestion
|
||||
async for document in self.ingest(job):
|
||||
yield document
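A sketch of how a caller might drive this generator; the ingestor, job and storage objects are assumed to already exist:

# Hypothetical consumer; falls back to plain ingest() behaviour when a subclass
# does not override ingest_with_dedup.
docs: list[Document] = []
async for doc in ingestor.ingest_with_dedup(job, storage, collection_name="docs_web"):
    docs.append(doc)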
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
from uuid import NAMESPACE_URL, UUID, uuid5
|
||||
|
||||
@@ -23,8 +24,11 @@ from ..core.models import (
|
||||
)
|
||||
from .base import BaseIngestor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..storage.base import BaseStorage
|
||||
|
||||
class FirecrawlError(IngestionError): # type: ignore[misc]
|
||||
|
||||
class FirecrawlError(IngestionError):
|
||||
"""Base exception for Firecrawl-related errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int | None = None) -> None:
|
||||
@@ -161,6 +165,55 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
for page in pages:
|
||||
yield self.create_document(page, job)
|
||||
|
||||
async def ingest_with_dedup(
|
||||
self,
|
||||
job: IngestionJob,
|
||||
storage_client: "BaseStorage",
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
stale_after_days: int = 30,
|
||||
) -> AsyncGenerator[Document, None]:
|
||||
"""
|
||||
Ingest documents with duplicate detection to avoid unnecessary scraping.
|
||||
|
||||
Args:
|
||||
job: The ingestion job configuration
|
||||
storage_client: Storage client to check for existing documents
|
||||
collection_name: Collection to check for duplicates
|
||||
stale_after_days: Consider documents stale after this many days
|
||||
|
||||
Yields:
|
||||
Documents from the web source (only new/stale ones)
|
||||
"""
|
||||
url = str(job.source_url)
|
||||
|
||||
# First, map the site to understand its structure
|
||||
site_map = await self.map_site(url) or [url]
|
||||
|
||||
# Filter out URLs that already exist in storage and are fresh
|
||||
eligible_urls: list[str] = []
|
||||
for check_url in site_map:
|
||||
document_id = str(self.compute_document_id(check_url))
|
||||
exists = await storage_client.check_exists(
|
||||
document_id,
|
||||
collection_name=collection_name,
|
||||
stale_after_days=stale_after_days
|
||||
)
|
||||
if not exists:
|
||||
eligible_urls.append(check_url)
|
||||
|
||||
if not eligible_urls:
|
||||
return # No new documents to scrape
|
||||
|
||||
# Process eligible pages in batches
|
||||
batch_size = 10
|
||||
for i in range(0, len(eligible_urls), batch_size):
|
||||
batch_urls = eligible_urls[i : i + batch_size]
|
||||
pages = await self.scrape_pages(batch_urls)
|
||||
|
||||
for page in pages:
|
||||
yield self.create_document(page, job)
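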
|
||||
|
||||
async def map_site(self, url: str) -> list[str]:
|
||||
"""Public wrapper for mapping a site."""
|
||||
|
||||
@@ -391,7 +444,7 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
else:
|
||||
avg_sentence_length = words / sentences
|
||||
# Simplified readability score (0-100, higher is more readable)
|
||||
readability_score = max(0, min(100, 100 - (avg_sentence_length - 15) * 2))
|
||||
readability_score = max(0.0, min(100.0, 100.0 - (avg_sentence_length - 15.0) * 2.0))
|
||||
|
||||
# Completeness score based on structure
|
||||
completeness_factors = 0
|
||||
@@ -478,11 +531,15 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
"domain": domain_info["domain"],
|
||||
"site_name": domain_info["site_name"],
|
||||
# Document structure
|
||||
"heading_hierarchy": structure_info["heading_hierarchy"],
|
||||
"section_depth": structure_info["section_depth"],
|
||||
"has_code_blocks": structure_info["has_code_blocks"],
|
||||
"has_images": structure_info["has_images"],
|
||||
"has_links": structure_info["has_links"],
|
||||
"heading_hierarchy": (
|
||||
list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else []
|
||||
),
|
||||
"section_depth": (
|
||||
int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0
|
||||
),
|
||||
"has_code_blocks": bool(structure_info.get("has_code_blocks", False)),
|
||||
"has_images": bool(structure_info.get("has_images", False)),
|
||||
"has_links": bool(structure_info.get("has_links", False)),
|
||||
# Processing metadata
|
||||
"extraction_method": "firecrawl",
|
||||
"last_modified": datetime.fromisoformat(page.sitemap_last_modified)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Repomix ingestor for Git repositories."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
@@ -400,16 +400,13 @@ class RepomixIngestor(BaseIngestor):
|
||||
|
||||
# Handle different URL formats
|
||||
if 'github.com' in repo_url or 'gitlab.com' in repo_url:
|
||||
# Extract from URLs like https://github.com/org/repo.git
|
||||
path_match = re.search(r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url)
|
||||
if path_match:
|
||||
org_name = path_match.group(1)
|
||||
repo_name = path_match.group(2)
|
||||
else:
|
||||
# Try to extract from generic git URLs
|
||||
path_match = re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url)
|
||||
if path_match:
|
||||
repo_name = path_match.group(1)
|
||||
if path_match := re.search(
|
||||
r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url
|
||||
):
|
||||
org_name = path_match[1]
|
||||
repo_name = path_match[2]
|
||||
elif path_match := re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url):
|
||||
repo_name = path_match[1]
|
||||
|
||||
return {
|
||||
'repository_name': repo_name or 'unknown',
|
||||
@@ -485,28 +482,18 @@ class RepomixIngestor(BaseIngestor):
|
||||
|
||||
# Build rich metadata
|
||||
metadata: DocumentMetadata = {
|
||||
# Core required fields
|
||||
"source_url": str(job.source_url),
|
||||
"timestamp": datetime.now(UTC),
|
||||
"content_type": content_type,
|
||||
"word_count": len(content.split()),
|
||||
"char_count": len(content),
|
||||
|
||||
# Basic fields
|
||||
"title": f"{file_path}" + (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
|
||||
"title": f"{file_path}"
|
||||
+ (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
|
||||
"description": f"Repository file: {file_path}",
|
||||
|
||||
# Content categorization
|
||||
"category": "source_code" if programming_language else "documentation",
|
||||
"language": programming_language or "text",
|
||||
|
||||
# Document structure from code analysis
|
||||
"has_code_blocks": True if programming_language else False,
|
||||
|
||||
# Processing metadata
|
||||
"has_code_blocks": bool(programming_language),
|
||||
"extraction_method": "repomix",
|
||||
|
||||
# Repository-specific fields
|
||||
"file_path": file_path,
|
||||
"programming_language": programming_language,
|
||||
}
|
||||
|
||||
0
ingest_pipeline/py.typed
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -75,6 +75,40 @@ class BaseStorage(ABC):
|
||||
"""
|
||||
raise NotImplementedError(f"{self.__class__.__name__} doesn't support document retrieval")
|
||||
|
||||
async def check_exists(
|
||||
self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a document exists and is not stale.
|
||||
|
||||
Args:
|
||||
document_id: Document ID to check
|
||||
collection_name: Collection to check in
|
||||
stale_after_days: Consider document stale after this many days
|
||||
|
||||
Returns:
|
||||
True if document exists and is not stale, False otherwise
|
||||
"""
|
||||
try:
|
||||
document = await self.retrieve(document_id, collection_name=collection_name)
|
||||
if document is None:
|
||||
return False
|
||||
|
||||
# Check staleness if timestamp is available
|
||||
if "timestamp" in document.metadata:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
timestamp_obj = document.metadata["timestamp"]
|
||||
if isinstance(timestamp_obj, datetime):
|
||||
timestamp = timestamp_obj
|
||||
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
|
||||
return timestamp >= cutoff
|
||||
|
||||
# If no timestamp, assume it exists and is valid
|
||||
return True
|
||||
except Exception:
|
||||
# Backend doesn't support retrieval, assume doesn't exist
|
||||
return False
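With retrieve() implemented by a backend, this default gives the dedup paths a uniform freshness test. A hypothetical check mirroring what the Firecrawl tasks do per URL (the URL and collection name are placeholders):

document_id = str(FirecrawlIngestor.compute_document_id("https://docs.example.com/page"))
needs_scrape = not await storage.check_exists(
    document_id,
    collection_name="docs_web",
    stale_after_days=30,
)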
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
@@ -130,6 +164,15 @@ class BaseStorage(ABC):
|
||||
"""
|
||||
return []
|
||||
|
||||
async def describe_collections(self) -> list[dict[str, object]]:
|
||||
"""
|
||||
Describe available collections with metadata (if supported by backend).
|
||||
|
||||
Returns:
|
||||
List of collection metadata dictionaries, empty list if not supported
|
||||
"""
|
||||
return []
|
||||
|
||||
async def list_documents(
|
||||
self,
|
||||
limit: int = 100,
|
||||
@@ -159,4 +202,5 @@ class BaseStorage(ABC):
|
||||
|
||||
Default implementation does nothing.
|
||||
"""
|
||||
pass
|
||||
# Default implementation - storage backends can override to clean up connections
|
||||
return None
|
||||
|
||||
@@ -274,6 +274,25 @@ class OpenWebUIStorage(BaseStorage):
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to store batch: {e}") from e
|
||||
|
||||
@override
|
||||
async def retrieve(
|
||||
self, document_id: str, *, collection_name: str | None = None
|
||||
) -> Document | None:
|
||||
"""
|
||||
OpenWebUI doesn't support document retrieval by ID.
|
||||
|
||||
Args:
|
||||
document_id: File ID (not supported)
|
||||
collection_name: Collection name (not used)
|
||||
|
||||
Raises:
NotImplementedError: always raised, since retrieval by ID is not supported
|
||||
"""
|
||||
# OpenWebUI uses file-based storage without direct document retrieval
|
||||
# This will cause the base check_exists method to return False,
|
||||
# which means documents will always be re-scraped for OpenWebUI
|
||||
raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID")
|
||||
|
||||
@override
|
||||
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
|
||||
"""
|
||||
@@ -414,69 +433,57 @@ class OpenWebUIStorage(BaseStorage):
|
||||
count: int
|
||||
size_mb: float
|
||||
|
||||
async def describe_collections(self) -> list[CollectionSummary]:
|
||||
|
||||
async def _get_knowledge_base_count(self, kb: dict[str, object]) -> int:
|
||||
"""Get the file count for a knowledge base."""
|
||||
kb_id = kb.get("id")
|
||||
name = kb.get("name", "Unknown")
|
||||
|
||||
if not kb_id:
|
||||
return self._count_files_from_basic_info(kb)
|
||||
|
||||
return await self._count_files_from_detailed_info(str(kb_id), str(name), kb)
|
||||
|
||||
def _count_files_from_basic_info(self, kb: dict[str, object]) -> int:
|
||||
"""Count files from basic knowledge base info."""
|
||||
files = kb.get("files", [])
|
||||
return len(files) if isinstance(files, list) and files is not None else 0
|
||||
|
||||
async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: dict[str, object]) -> int:
|
||||
"""Count files by fetching detailed knowledge base info."""
|
||||
try:
|
||||
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
|
||||
detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}")
|
||||
detail_response.raise_for_status()
|
||||
detailed_kb = detail_response.json()
|
||||
|
||||
files = detailed_kb.get("files", [])
|
||||
count = len(files) if isinstance(files, list) and files is not None else 0
|
||||
|
||||
LOGGER.info(f"Knowledge base '{name}' (ID: {kb_id}): found {count} files")
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"Failed to get detailed info for KB '{name}' (ID: {kb_id}): {e}")
|
||||
return self._count_files_from_basic_info(kb)
|
||||
|
||||
async def describe_collections(self) -> list[dict[str, object]]:
|
||||
"""Return metadata about each knowledge base."""
|
||||
try:
|
||||
# First get the list of knowledge bases
|
||||
response = await self.client.get("/api/v1/knowledge/")
|
||||
response.raise_for_status()
|
||||
knowledge_bases = response.json()
|
||||
knowledge_bases = await self._fetch_knowledge_bases()
|
||||
collections: list[dict[str, object]] = []
|
||||
|
||||
LOGGER.info(f"OpenWebUI returned {len(knowledge_bases)} knowledge bases")
|
||||
LOGGER.debug(f"Knowledge bases structure: {knowledge_bases}")
|
||||
|
||||
collections: list[OpenWebUIStorage.CollectionSummary] = []
|
||||
for kb in knowledge_bases:
|
||||
if not isinstance(kb, dict):
|
||||
continue
|
||||
|
||||
kb_id = kb.get("id")
|
||||
count = await self._get_knowledge_base_count(kb)
|
||||
name = kb.get("name", "Unknown")
|
||||
|
||||
LOGGER.info(f"Processing knowledge base: '{name}' (ID: {kb_id})")
|
||||
LOGGER.debug(f"KB structure: {kb}")
|
||||
|
||||
if not kb_id:
|
||||
# If no ID, fall back to basic count from list response
|
||||
files = kb.get("files", [])
|
||||
if files is None:
|
||||
files = []
|
||||
count = len(files) if isinstance(files, list) else 0
|
||||
else:
|
||||
# Get detailed knowledge base information using the correct endpoint
|
||||
try:
|
||||
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
|
||||
detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}")
|
||||
detail_response.raise_for_status()
|
||||
detailed_kb = detail_response.json()
|
||||
|
||||
LOGGER.debug(f"Detailed KB response: {detailed_kb}")
|
||||
|
||||
files = detailed_kb.get("files", [])
|
||||
if files is None:
|
||||
files = []
|
||||
count = len(files) if isinstance(files, list) else 0
|
||||
|
||||
# Debug logging
|
||||
LOGGER.info(f"Knowledge base '{name}' (ID: {kb_id}): found {count} files")
|
||||
if count > 0 and len(files) > 0:
|
||||
LOGGER.debug(f"First file structure: {files[0] if files else 'No files'}")
|
||||
elif count == 0:
|
||||
LOGGER.warning(f"Knowledge base '{name}' has 0 files. Files field type: {type(files)}, value: {files}")
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"Failed to get detailed info for KB '{name}' (ID: {kb_id}): {e}")
|
||||
# Fallback to basic files list if detailed fetch fails
|
||||
files = kb.get("files", [])
|
||||
if files is None:
|
||||
files = []
|
||||
count = len(files) if isinstance(files, list) else 0
|
||||
LOGGER.info(f"Fallback count for KB '{name}': {count}")
|
||||
|
||||
size_mb = count * 0.5 # rough heuristic
|
||||
summary: OpenWebUIStorage.CollectionSummary = {
|
||||
|
||||
summary: dict[str, object] = {
|
||||
"name": str(name),
|
||||
"count": int(count),
|
||||
"count": count,
|
||||
"size_mb": float(size_mb),
|
||||
}
|
||||
collections.append(summary)
|
||||
@@ -500,7 +507,10 @@ class OpenWebUIStorage(BaseStorage):
|
||||
# If no collection name provided, return total across all collections
|
||||
try:
|
||||
collections = await self.describe_collections()
|
||||
return sum(collection["count"] for collection in collections)
|
||||
return sum(
|
||||
int(collection["count"]) if isinstance(collection["count"], (int, str)) else 0
|
||||
for collection in collections
|
||||
)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
Binary file not shown.
@@ -62,7 +62,8 @@ class R2RCollections:
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to create collection '{name}': {e}") from e
|
||||
|
||||
@@ -80,7 +81,8 @@ class R2RCollections:
|
||||
"""
|
||||
try:
|
||||
response = await self.client.collections.retrieve(str(collection_id))
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to retrieve collection {collection_id}: {e}") from e
|
||||
|
||||
@@ -109,7 +111,8 @@ class R2RCollections:
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to update collection {collection_id}: {e}") from e
|
||||
|
||||
@@ -153,7 +156,8 @@ class R2RCollections:
|
||||
limit=limit,
|
||||
owner_only=owner_only,
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to list collections: {e}") from e
|
||||
|
||||
@@ -202,7 +206,8 @@ class R2RCollections:
|
||||
id=str(collection_id),
|
||||
document_id=str(document_id),
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(
|
||||
f"Failed to add document {document_id} to collection {collection_id}: {e}"
|
||||
@@ -254,7 +259,8 @@ class R2RCollections:
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(
|
||||
f"Failed to list documents in collection {collection_id}: {e}"
|
||||
@@ -278,7 +284,8 @@ class R2RCollections:
|
||||
id=str(collection_id),
|
||||
user_id=str(user_id),
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(
|
||||
f"Failed to add user {user_id} to collection {collection_id}: {e}"
|
||||
@@ -330,7 +337,8 @@ class R2RCollections:
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to list users for collection {collection_id}: {e}") from e
|
||||
|
||||
@@ -359,7 +367,8 @@ class R2RCollections:
|
||||
run_with_orchestration=run_with_orchestration,
|
||||
settings=cast(dict[str, object], settings or {}),
|
||||
)
|
||||
return cast(JsonData, response.results.model_dump())
|
||||
# response.results is a list, not a model with model_dump()
|
||||
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
||||
except Exception as e:
|
||||
raise StorageError(
|
||||
f"Failed to extract entities from collection {collection_id}: {e}"
|
||||
|
||||
@@ -9,10 +9,15 @@ from datetime import UTC, datetime
|
||||
from typing import Self, TypeVar, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import httpx
|
||||
from r2r import R2RAsyncClient, R2RException
|
||||
from r2r import R2RAsyncClient
|
||||
from typing_extensions import override
|
||||
|
||||
# Direct imports for runtime and type checking
|
||||
# Note: Some type checkers (basedpyright/Pyrefly) may report import issues
|
||||
# but these work correctly at runtime and with mypy
|
||||
from httpx import AsyncClient, HTTPStatusError
|
||||
from r2r import R2RException
|
||||
|
||||
from ...core.exceptions import StorageError
|
||||
from ...core.models import Document, DocumentMetadata, IngestionSource, StorageConfig
|
||||
from ..base import BaseStorage
|
||||
@@ -83,7 +88,7 @@ class R2RStorage(BaseStorage):
|
||||
try:
|
||||
# Ensure we have an event loop
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
_ = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop running, this should not happen in async context
|
||||
# but let's be defensive
|
||||
@@ -93,7 +98,7 @@ class R2RStorage(BaseStorage):
|
||||
|
||||
# Test connection using direct HTTP call to v3 API
|
||||
endpoint = self.endpoint
|
||||
client = httpx.AsyncClient()
|
||||
client = AsyncClient()
|
||||
try:
|
||||
response = await client.get(f"{endpoint}/v3/collections")
|
||||
response.raise_for_status()
|
||||
@@ -107,7 +112,7 @@ class R2RStorage(BaseStorage):
|
||||
"""Get or create collection by name."""
|
||||
try:
|
||||
endpoint = self.endpoint
|
||||
client = httpx.AsyncClient()
|
||||
client = AsyncClient()
|
||||
try:
|
||||
# List collections and find by name
|
||||
response = await client.get(f"{endpoint}/v3/collections")
|
||||
@@ -203,7 +208,7 @@ class R2RStorage(BaseStorage):
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
async with AsyncClient() as http_client:
|
||||
# Use files parameter but with string values for multipart/form-data
|
||||
# This matches the cURL -F behavior more closely
|
||||
metadata = self._build_metadata(document)
|
||||
@@ -264,7 +269,7 @@ class R2RStorage(BaseStorage):
|
||||
doc_response = response.json()
|
||||
break # Success - exit retry loop
|
||||
|
||||
except httpx.TimeoutException:
|
||||
except (OSError, asyncio.TimeoutError):
|
||||
if attempt < max_retries - 1:
|
||||
print(
|
||||
f"Timeout for document {requested_id}, retrying in {retry_delay}s..."
|
||||
@@ -274,7 +279,7 @@ class R2RStorage(BaseStorage):
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
except httpx.HTTPStatusError as e:
|
||||
except HTTPStatusError as e:
|
||||
if e.response.status_code >= 500 and attempt < max_retries - 1:
|
||||
print(
|
||||
f"Server error {e.response.status_code} for document {requested_id}, retrying in {retry_delay}s..."
|
||||
@@ -470,7 +475,7 @@ class R2RStorage(BaseStorage):
|
||||
elif description := metadata_map.get("description"):
|
||||
metadata["description"] = cast(str | None, description)
|
||||
if tags := metadata_map.get("tags"):
|
||||
metadata["tags"] = _as_sequence(tags) if isinstance(tags, list) else []
|
||||
metadata["tags"] = [str(tag) for tag in tags] if isinstance(tags, list) else []
|
||||
if category := metadata_map.get("category"):
|
||||
metadata["category"] = str(category)
|
||||
if section := metadata_map.get("section"):
|
||||
@@ -502,13 +507,15 @@ class R2RStorage(BaseStorage):
|
||||
if last_modified := metadata_map.get("last_modified"):
|
||||
metadata["last_modified"] = _as_datetime(last_modified)
|
||||
if readability_score := metadata_map.get("readability_score"):
|
||||
metadata["readability_score"] = (
|
||||
float(readability_score) if readability_score is not None else None
|
||||
)
|
||||
try:
|
||||
metadata["readability_score"] = float(str(readability_score))
|
||||
except (ValueError, TypeError):
|
||||
metadata["readability_score"] = None
|
||||
if completeness_score := metadata_map.get("completeness_score"):
|
||||
metadata["completeness_score"] = (
|
||||
float(completeness_score) if completeness_score is not None else None
|
||||
)
|
||||
try:
|
||||
metadata["completeness_score"] = float(str(completeness_score))
|
||||
except (ValueError, TypeError):
|
||||
metadata["completeness_score"] = None
|
||||
|
||||
source_value = str(metadata_map.get("ingestion_source", IngestionSource.WEB.value))
|
||||
try:
|
||||
@@ -609,7 +616,7 @@ class R2RStorage(BaseStorage):
|
||||
"""Get document count in collection."""
|
||||
try:
|
||||
endpoint = self.endpoint
|
||||
client = httpx.AsyncClient()
|
||||
client = AsyncClient()
|
||||
try:
|
||||
# Get collections and find the count for the specific collection
|
||||
response = await client.get(f"{endpoint}/v3/collections")
|
||||
@@ -677,7 +684,7 @@ class R2RStorage(BaseStorage):
|
||||
"""List all available collections."""
|
||||
try:
|
||||
endpoint = self.endpoint
|
||||
client = httpx.AsyncClient()
|
||||
client = AsyncClient()
|
||||
try:
|
||||
response = await client.get(f"{endpoint}/v3/collections")
|
||||
response.raise_for_status()
|
||||
@@ -777,7 +784,7 @@ class R2RStorage(BaseStorage):
|
||||
collection_id = await self._ensure_collection(collection_name)
|
||||
# Use the collections API to list documents in a specific collection
|
||||
endpoint = self.endpoint
|
||||
client = httpx.AsyncClient()
|
||||
client = AsyncClient()
|
||||
try:
|
||||
params = {"offset": offset, "limit": limit}
|
||||
response = await client.get(
|
||||
|
||||
@@ -22,7 +22,6 @@ from ..core.models import Document, DocumentMetadata, IngestionSource, StorageCo
|
||||
from ..utils.vectorizer import Vectorizer
|
||||
from .base import BaseStorage
|
||||
|
||||
|
||||
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
|
||||
|
||||
|
||||
@@ -591,13 +590,13 @@ class WeaviateStorage(BaseStorage):
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to list collections: {e}") from e
|
||||
|
||||
async def describe_collections(self) -> list[dict[str, str | int | float]]:
|
||||
async def describe_collections(self) -> list[dict[str, object]]:
|
||||
"""Return metadata for each Weaviate collection."""
|
||||
if not self.client:
|
||||
raise StorageError("Weaviate client not initialized")
|
||||
|
||||
try:
|
||||
collections: list[dict[str, str | int | float]] = []
|
||||
collections: list[dict[str, object]] = []
|
||||
for name in self.client.collections.list_all():
|
||||
collection_obj = self.client.collections.get(name)
|
||||
if not collection_obj:
|
||||
@@ -808,7 +807,7 @@ class WeaviateStorage(BaseStorage):
|
||||
offset: int = 0,
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
) -> list[dict[str, str | int]]:
|
||||
) -> list[dict[str, object]]:
|
||||
"""
|
||||
List documents in the collection with pagination.
|
||||
|
||||
@@ -830,7 +829,7 @@ class WeaviateStorage(BaseStorage):
|
||||
limit=limit, offset=offset, return_metadata=["creation_time"]
|
||||
)
|
||||
|
||||
documents = []
|
||||
documents: list[dict[str, object]] = []
|
||||
for obj in response.objects:
|
||||
props = self._coerce_properties(
|
||||
obj.properties,
|
||||
@@ -849,7 +848,7 @@ class WeaviateStorage(BaseStorage):
|
||||
else:
|
||||
word_count = 0
|
||||
|
||||
doc_info: dict[str, str | int] = {
|
||||
doc_info: dict[str, object] = {
|
||||
"id": str(obj.uuid),
|
||||
"title": str(props.get("title", "Untitled")),
|
||||
"source_url": str(props.get("source_url", "")),
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import TypedDict, cast
|
||||
from typing import Protocol, TypedDict, cast
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -10,6 +10,41 @@ from ..core.exceptions import IngestionError
|
||||
from ..core.models import Document
|
||||
|
||||
|
||||
class HttpResponse(Protocol):
|
||||
"""Protocol for HTTP response."""
|
||||
|
||||
def raise_for_status(self) -> None: ...
|
||||
def json(self) -> dict[str, object]: ...
|
||||
|
||||
|
||||
class AsyncHttpClient(Protocol):
|
||||
"""Protocol for async HTTP client."""
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
json: dict[str, object] | None = None
|
||||
) -> HttpResponse: ...
|
||||
|
||||
async def aclose(self) -> None: ...
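Since the tagger now depends on this structural protocol rather than on httpx directly, a test double only needs to match the method shapes. A hypothetical stub:

class FakeResponse:
    def raise_for_status(self) -> None:
        return None

    def json(self) -> dict[str, object]:
        # Illustrative payload shaped like an OpenAI-style chat completion.
        return {"choices": [{"message": {"content": "{}"}}]}


class FakeClient:
    async def post(self, url: str, *, json: dict[str, object] | None = None) -> FakeResponse:
        return FakeResponse()

    async def aclose(self) -> None:
        return None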
|
||||
|
||||
|
||||
class LlmResponse(TypedDict):
|
||||
"""Type for LLM API response structure."""
|
||||
choices: list[dict[str, object]]
|
||||
|
||||
|
||||
class LlmChoice(TypedDict):
|
||||
"""Type for individual choice in LLM response."""
|
||||
message: dict[str, object]
|
||||
|
||||
|
||||
class LlmMessage(TypedDict):
|
||||
"""Type for message in LLM choice."""
|
||||
content: str
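These TypedDicts describe the OpenAI-style chat-completion shape the tagger expects; a hedged helper for pulling the message content out of such a payload:

def _extract_content(payload: LlmResponse) -> str:
    # Sketch only; assumes the response already matches the LlmResponse shape.
    choices = payload.get("choices") or []
    first = choices[0] if choices else {}
    message = first.get("message") if isinstance(first, dict) else None
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    return content if isinstance(content, str) else ""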
|
||||
|
||||
|
||||
class DocumentMetadata(TypedDict, total=False):
|
||||
"""Structured metadata for documents."""
|
||||
|
||||
@@ -27,7 +62,7 @@ class MetadataTagger:
|
||||
|
||||
endpoint: str
|
||||
model: str
|
||||
client: httpx.AsyncClient
|
||||
client: AsyncHttpClient
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -60,7 +95,10 @@ class MetadataTagger:
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
self.client = httpx.AsyncClient(timeout=60.0, headers=headers)
|
||||
# Create client with proper typing - httpx.AsyncClient implements AsyncHttpClient protocol
|
||||
AsyncClientClass = getattr(httpx, "AsyncClient")
|
||||
raw_client = AsyncClientClass(timeout=60.0, headers=headers)
|
||||
self.client = cast(AsyncHttpClient, raw_client)
|
||||
|
||||
async def tag_document(
|
||||
self, document: Document, custom_instructions: str | None = None
|
||||
@@ -87,29 +125,52 @@ class MetadataTagger:
|
||||
)
|
||||
|
||||
# Merge with existing metadata - preserve ALL existing fields and add LLM-generated ones
|
||||
from typing import cast
|
||||
|
||||
from ..core.models import DocumentMetadata as CoreDocumentMetadata
|
||||
|
||||
# Start with a copy of existing metadata to preserve all fields
|
||||
# Cast to avoid TypedDict key errors during manipulation
|
||||
updated_metadata = cast(dict[str, object], dict(document.metadata))
|
||||
updated_metadata = dict(document.metadata)
|
||||
|
||||
# Update/enhance with LLM-generated metadata, preserving existing values when new ones are empty
|
||||
if metadata.get("title") and not updated_metadata.get("title"):
|
||||
updated_metadata["title"] = str(metadata["title"]) # type: ignore[typeddict-item]
|
||||
if metadata.get("summary") and not updated_metadata.get("description"):
|
||||
updated_metadata["description"] = str(metadata["summary"])
|
||||
if title_val := metadata.get("title"):
|
||||
if not updated_metadata.get("title") and isinstance(title_val, str):
|
||||
updated_metadata["title"] = title_val
|
||||
if summary_val := metadata.get("summary"):
|
||||
if not updated_metadata.get("description"):
|
||||
updated_metadata["description"] = str(summary_val)
|
||||
|
||||
# Ensure required fields have values
|
||||
updated_metadata.setdefault("source_url", "")
|
||||
updated_metadata.setdefault("timestamp", datetime.now(UTC))
|
||||
updated_metadata.setdefault("content_type", "text/plain")
|
||||
updated_metadata.setdefault("word_count", len(document.content.split()))
|
||||
updated_metadata.setdefault("char_count", len(document.content))
|
||||
_ = updated_metadata.setdefault("source_url", "")
|
||||
_ = updated_metadata.setdefault("timestamp", datetime.now(UTC))
|
||||
_ = updated_metadata.setdefault("content_type", "text/plain")
|
||||
_ = updated_metadata.setdefault("word_count", len(document.content.split()))
|
||||
_ = updated_metadata.setdefault("char_count", len(document.content))
|
||||
|
||||
# Cast to the expected type since we're preserving all fields from the original metadata
|
||||
document.metadata = cast(CoreDocumentMetadata, updated_metadata)
|
||||
# Build a proper DocumentMetadata instance with only valid keys
|
||||
new_metadata: CoreDocumentMetadata = {
|
||||
"source_url": str(updated_metadata.get("source_url", "")),
|
||||
"timestamp": (
|
||||
lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC)
|
||||
)(updated_metadata.get("timestamp", datetime.now(UTC))),
|
||||
"content_type": str(updated_metadata.get("content_type", "text/plain")),
|
||||
"word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)),
|
||||
"char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)),
|
||||
}
|
||||
|
||||
# Add optional fields if they exist
|
||||
if "title" in updated_metadata and updated_metadata["title"]:
|
||||
new_metadata["title"] = str(updated_metadata["title"])
|
||||
if "description" in updated_metadata and updated_metadata["description"]:
|
||||
new_metadata["description"] = str(updated_metadata["description"])
|
||||
if "tags" in updated_metadata and isinstance(updated_metadata["tags"], list):
|
||||
tags_list = cast(list[object], updated_metadata["tags"])
|
||||
new_metadata["tags"] = [str(tag) for tag in tags_list if tag is not None]
|
||||
if "category" in updated_metadata and updated_metadata["category"]:
|
||||
new_metadata["category"] = str(updated_metadata["category"])
|
||||
if "language" in updated_metadata and updated_metadata["language"]:
|
||||
new_metadata["language"] = str(updated_metadata["language"])
|
||||
|
||||
document.metadata = new_metadata
|
||||
|
||||
return document
|
||||
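The merge above preserves every existing metadata field, lets LLM-generated values fill only empty title/description slots, backfills the required counters from the content, and then rebuilds a clean DocumentMetadata so stray keys never land on the document. A rough standalone sketch of that precedence (plain dicts here; the real TypedDict lives in ..core.models):

from datetime import UTC, datetime


def merge_metadata(existing: dict[str, object], llm: dict[str, object], content: str) -> dict[str, object]:
    merged = dict(existing)  # never drop fields the pipeline already set

    # LLM values only fill gaps; existing title/description win.
    if (title := llm.get("title")) and not merged.get("title") and isinstance(title, str):
        merged["title"] = title
    if (summary := llm.get("summary")) and not merged.get("description"):
        merged["description"] = str(summary)

    # Required fields are backfilled from the content itself.
    merged.setdefault("source_url", "")
    merged.setdefault("timestamp", datetime.now(UTC))
    merged.setdefault("content_type", "text/plain")
    merged.setdefault("word_count", len(content.split()))
    merged.setdefault("char_count", len(content))
    return merged


print(merge_metadata({"title": "Kept"}, {"title": "Ignored", "summary": "Used"}, "two words"))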
|
||||
@@ -199,27 +260,27 @@ Return a JSON object with the following structure:
|
||||
if not isinstance(result_raw, dict):
|
||||
raise IngestionError("Invalid response format from LLM")
|
||||
|
||||
result = cast(dict[str, object], result_raw)
|
||||
result = cast(LlmResponse, result_raw)
|
||||
|
||||
# Extract content from response
|
||||
choices = result.get("choices", [])
|
||||
if not choices or not isinstance(choices, list):
|
||||
raise IngestionError("No response from LLM")
|
||||
|
||||
first_choice_raw = cast(object, choices[0])
|
||||
first_choice_raw = choices[0]
|
||||
if not isinstance(first_choice_raw, dict):
|
||||
raise IngestionError("Invalid choice format")
|
||||
|
||||
first_choice = cast(dict[str, object], first_choice_raw)
|
||||
first_choice = cast(LlmChoice, first_choice_raw)
|
||||
message_raw = first_choice.get("message", {})
|
||||
if not isinstance(message_raw, dict):
|
||||
raise IngestionError("Invalid message format")
|
||||
|
||||
message = cast(dict[str, object], message_raw)
|
||||
message = cast(LlmMessage, message_raw)
|
||||
content_str = str(message.get("content", "{}"))
|
||||
|
||||
try:
|
||||
raw_metadata = json.loads(content_str)
|
||||
raw_metadata = cast(dict[str, object], json.loads(content_str))
|
||||
except json.JSONDecodeError as e:
|
||||
raise IngestionError(f"Failed to parse LLM response: {e}") from e
|
||||
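The parsing change above narrows the raw json.loads result in stages (response, then choice, then message) and raises IngestionError at each malformed step. A condensed sketch of the same narrowing with plain exceptions, so it runs outside the pipeline:

import json


def extract_content(result_raw: object) -> dict[str, object]:
    # Pull the JSON payload out of an OpenAI-style chat completion dict.
    if not isinstance(result_raw, dict):
        raise ValueError("Invalid response format from LLM")

    choices = result_raw.get("choices", [])
    if not isinstance(choices, list) or not choices:
        raise ValueError("No response from LLM")

    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        raise ValueError("Invalid choice format")

    message = first_choice.get("message", {})
    if not isinstance(message, dict):
        raise ValueError("Invalid message format")

    content_str = str(message.get("content", "{}"))
    try:
        return json.loads(content_str)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Failed to parse LLM response: {exc}") from exc


print(extract_content({"choices": [{"message": {"content": '{"title": "Doc"}'}}]}))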
|
||||
|
||||
@@ -51,7 +51,7 @@ class Vectorizer:
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

self.client = httpx.AsyncClient(timeout=60.0, headers=headers) # type: ignore[attr-defined]
self.client: httpx.AsyncClient = httpx.AsyncClient(timeout=60.0, headers=headers)

async def vectorize(self, text: str) -> list[float]:
"""

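The vectorize body is not shown in this hunk, so the following is only a guess at its shape: it assumes an OpenAI-compatible embeddings endpoint and the {"data": [{"embedding": [...]}]} payload that the test fixtures below synthesize. Treat the URL path and field names as assumptions, not the repo's actual implementation.

import httpx


async def vectorize(client: httpx.AsyncClient, endpoint: str, model: str, text: str) -> list[float]:
    # Hypothetical request shape; the real Vectorizer may differ.
    response = await client.post(
        f"{endpoint.rstrip('/')}/v1/embeddings",
        json={"model": model, "input": text},
    )
    response.raise_for_status()
    payload = response.json()
    return [float(value) for value in payload["data"][0]["embedding"]]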
26
logs/tui.log
@@ -846,3 +846,29 @@ metadata.author
|
||||
2025-09-18 08:29:04 | INFO | httpx | HTTP Request: POST http://prefect.lab/api/flow_runs/3d6f8223-0b7e-43fd-a1f4-c102f3fc8919/set_state "HTTP/1.1 201 Created"
|
||||
2025-09-18 08:29:04 | INFO | prefect.flow_runs | Finished in state Completed()
|
||||
2025-09-18 08:29:06 | INFO | httpx | HTTP Request: POST http://prefect.lab/api/logs/ "HTTP/1.1 201 Created"
|
||||
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | Initializing collection management TUI
|
||||
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | Scanning available storage backends
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/.well-known/openid-configuration "HTTP/1.1 404 Not Found"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/meta "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET https://pypi.org/pypi/weaviate-client/json "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/schema "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://chat.lab/api/v1/knowledge/list "HTTP/1.1 401 Unauthorized"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | weaviate connected successfully
|
||||
2025-09-18 22:21:09 | WARNING | ingest_pipeline.cli.tui.utils.runners | open_webui connection failed
|
||||
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | r2r connected successfully
|
||||
2025-09-18 22:21:09 | INFO | ingest_pipeline.cli.tui.utils.runners | Launching TUI with 2 backend(s): weaviate, r2r
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://weaviate.yo/v1/schema "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: POST http://weaviate.yo/v1/graphql "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: POST http://weaviate.yo/v1/graphql "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: POST http://weaviate.yo/v1/graphql "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 22:21:09 | INFO | httpx | HTTP Request: GET http://r2r.lab/v3/collections "HTTP/1.1 200 OK"
|
||||
2025-09-18 23:06:37 | INFO | ingest_pipeline.cli.tui.utils.runners | Shutting down storage connections
|
||||
2025-09-18 23:06:37 | INFO | ingest_pipeline.cli.tui.utils.runners | All storage connections closed gracefully
|
||||
|
||||
154
prefect.yaml
Normal file
@@ -0,0 +1,154 @@
|
||||
# Prefect deployment configuration for RAG Manager
|
||||
name: rag-manager
|
||||
prefect-version: 3.0.0
|
||||
|
||||
# Build steps - prepare the environment
|
||||
build:
|
||||
- prefect.deployments.steps.run_shell_script:
|
||||
id: prepare-environment
|
||||
script: |
|
||||
echo "Preparing RAG Manager environment..."
|
||||
# Ensure virtual environment is activated
|
||||
source .venv/bin/activate || echo "Virtual environment not found"
|
||||
# Install dependencies
|
||||
uv sync --frozen
|
||||
# Register custom blocks
|
||||
prefect block register --module ingest_pipeline.core.models
|
||||
|
||||
# Push steps - handle deployment artifacts
|
||||
push: null
|
||||
|
||||
# Work pool configuration
|
||||
work_pool:
|
||||
name: "{{ prefect.variables.work_pool_name | default('default') }}"
|
||||
work_queue_name: "{{ prefect.variables.work_queue_name | default('default') }}"
|
||||
job_variables:
|
||||
env:
|
||||
# Prefect configuration
|
||||
PREFECT_API_URL: "{{ $PREFECT_API_URL }}"
|
||||
PREFECT_API_KEY: "{{ $PREFECT_API_KEY }}"
|
||||
|
||||
# Application configuration from variables
|
||||
DEFAULT_BATCH_SIZE: "{{ prefect.variables.default_batch_size | default('50') }}"
|
||||
MAX_CRAWL_DEPTH: "{{ prefect.variables.max_crawl_depth | default('5') }}"
|
||||
MAX_CRAWL_PAGES: "{{ prefect.variables.max_crawl_pages | default('100') }}"
|
||||
MAX_CONCURRENT_TASKS: "{{ prefect.variables.max_concurrent_tasks | default('5') }}"
|
||||
|
||||
# Service endpoints from variables
|
||||
LLM_ENDPOINT: "{{ prefect.variables.llm_endpoint | default('http://llm.lab') }}"
|
||||
WEAVIATE_ENDPOINT: "{{ prefect.variables.weaviate_endpoint | default('http://weaviate.yo') }}"
|
||||
OPENWEBUI_ENDPOINT: "{{ prefect.variables.openwebui_endpoint | default('http://chat.lab') }}"
|
||||
FIRECRAWL_ENDPOINT: "{{ prefect.variables.firecrawl_endpoint | default('http://crawl.lab:30002') }}"
|
||||
|
||||
# Deployment definitions
|
||||
deployments:
|
||||
# Web ingestion deployment
|
||||
- name: web-ingestion
|
||||
version: "{{ prefect.variables.deployment_version | default('1.0.0') }}"
|
||||
tags:
|
||||
- "{{ prefect.variables.environment | default('development') }}"
|
||||
- web
|
||||
- ingestion
|
||||
description: "Automated web content ingestion using Firecrawl"
|
||||
entrypoint: ingest_pipeline/flows/ingestion.py:create_ingestion_flow
|
||||
parameters:
|
||||
source_type: web
|
||||
storage_backend: "{{ prefect.variables.default_storage_backend | default('weaviate') }}"
|
||||
validate_first: true
|
||||
storage_block_name: "{{ prefect.variables.default_storage_block }}"
|
||||
ingestor_config_block_name: "{{ prefect.variables.default_firecrawl_block }}"
|
||||
schedule:
|
||||
interval: "{{ prefect.variables.default_schedule_interval | default(3600) }}"
|
||||
timezone: UTC
|
||||
work_pool:
|
||||
name: "{{ prefect.variables.web_work_pool | default('default') }}"
|
||||
job_variables:
|
||||
env:
|
||||
INGESTION_TYPE: "web"
|
||||
MAX_PAGES: "{{ prefect.variables.max_crawl_pages | default('100') }}"
|
||||
|
||||
# Repository ingestion deployment
|
||||
- name: repository-ingestion
|
||||
version: "{{ prefect.variables.deployment_version | default('1.0.0') }}"
|
||||
tags:
|
||||
- "{{ prefect.variables.environment | default('development') }}"
|
||||
- repository
|
||||
- ingestion
|
||||
description: "Automated repository content ingestion using Repomix"
|
||||
entrypoint: ingest_pipeline/flows/ingestion.py:create_ingestion_flow
|
||||
parameters:
|
||||
source_type: repository
|
||||
storage_backend: "{{ prefect.variables.default_storage_backend | default('weaviate') }}"
|
||||
validate_first: true
|
||||
storage_block_name: "{{ prefect.variables.default_storage_block }}"
|
||||
ingestor_config_block_name: "{{ prefect.variables.default_repomix_block }}"
|
||||
schedule: null # Manual trigger only
|
||||
work_pool:
|
||||
name: "{{ prefect.variables.repo_work_pool | default('default') }}"
|
||||
job_variables:
|
||||
env:
|
||||
INGESTION_TYPE: "repository"
|
||||
|
||||
# R2R specialized deployment
|
||||
- name: firecrawl-to-r2r
|
||||
version: "{{ prefect.variables.deployment_version | default('1.0.0') }}"
|
||||
tags:
|
||||
- "{{ prefect.variables.environment | default('development') }}"
|
||||
- firecrawl
|
||||
- r2r
|
||||
- specialized
|
||||
description: "Optimized Firecrawl to R2R ingestion flow"
|
||||
entrypoint: ingest_pipeline/flows/ingestion.py:firecrawl_to_r2r_flow
|
||||
parameters:
|
||||
storage_block_name: "{{ prefect.variables.r2r_storage_block }}"
|
||||
schedule:
|
||||
cron: "{{ prefect.variables.r2r_cron_schedule | default('0 2 * * *') }}"
|
||||
timezone: UTC
|
||||
work_pool:
|
||||
name: "{{ prefect.variables.r2r_work_pool | default('default') }}"
|
||||
job_variables:
|
||||
env:
|
||||
INGESTION_TYPE: "r2r"
|
||||
SPECIALIZED_FLOW: "true"
|
||||
|
||||
# Automation definitions (commented out - would be created via API)
|
||||
# automations:
|
||||
# - name: Cancel Long Running Flows
|
||||
# description: Cancels flows running longer than 30 minutes
|
||||
# trigger:
|
||||
# type: event
|
||||
# posture: Proactive
|
||||
# expect: [prefect.flow-run.Running]
|
||||
# match_related:
|
||||
# prefect.resource.role: flow
|
||||
# prefect.resource.name: ingestion_pipeline
|
||||
# threshold: 1
|
||||
# within: 1800
|
||||
# actions:
|
||||
# - type: cancel-flow-run
|
||||
# source: inferred
|
||||
# enabled: true
|
||||
|
||||
# Variables that should be set for optimal operation
|
||||
# Use: prefect variable set <name> <value>
|
||||
# Required variables:
|
||||
# - default_storage_backend: weaviate|open_webui|r2r
|
||||
# - llm_endpoint: URL for LLM service
|
||||
# - weaviate_endpoint: URL for Weaviate instance
|
||||
# - openwebui_endpoint: URL for OpenWebUI instance
|
||||
# - firecrawl_endpoint: URL for Firecrawl service
|
||||
#
|
||||
# Optional variables with defaults:
|
||||
# - default_batch_size: 50
|
||||
# - max_crawl_depth: 5
|
||||
# - max_crawl_pages: 100
|
||||
# - max_concurrent_tasks: 5
|
||||
# - default_schedule_interval: 3600 (1 hour)
|
||||
# - deployment_version: 1.0.0
|
||||
# - environment: development
|
||||
|
||||
# Block types that should be registered:
|
||||
# - storage-config: Storage backend configurations
|
||||
# - firecrawl-config: Firecrawl scraping parameters
|
||||
# - repomix-config: Repository processing settings
|
||||
# - r2r-config: R2R-specific chunking and graph settings
|
||||
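Each "{{ prefect.variables.x | default(y) }}" placeholder above falls back to its literal default when the variable is unset, so a fresh workspace resolves the web-ingestion deployment roughly to the values sketched here (illustrative only; actual values depend on which Prefect variables exist):

# Effective parameters/env for the web-ingestion deployment when no
# Prefect variables have been set (defaults taken from prefect.yaml above).
web_ingestion_defaults = {
    "parameters": {
        "source_type": "web",
        "storage_backend": "weaviate",
        "validate_first": True,
    },
    "schedule": {"interval": 3600, "timezone": "UTC"},
    "env": {
        "INGESTION_TYPE": "web",
        "MAX_PAGES": "100",
    },
}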
@@ -65,7 +65,7 @@ ignore = [
"ingest_pipeline/cli/main.py" = ["B008"] # Typer uses function calls in defaults

[tool.mypy]
python_version = "3.11"
python_version = "3.12"
strict = true
warn_return_any = true
warn_unused_configs = true
@@ -73,6 +73,18 @@ ignore_missing_imports = true
# Allow AsyncGenerator types in overrides
disable_error_code = ["override"]

[tool.basedpyright]
include = ["ingest_pipeline"]
exclude = ["**/__pycache__", "**/.pytest_cache", "**/node_modules", ".venv", "build", "dist"]
pythonVersion = "3.12"
venvPath = "."
venv = ".venv"
typeCheckingMode = "standard"
useLibraryCodeForTypes = true
reportMissingTypeStubs = "none"
reportMissingModuleSource = "none"
reportAttributeAccessIssue = "warning"

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

45
pyrightconfig.json
Normal file
@@ -0,0 +1,45 @@
|
||||
{
|
||||
"include": [
|
||||
"ingest_pipeline"
|
||||
],
|
||||
"exclude": [
|
||||
"**/__pycache__",
|
||||
"**/.pytest_cache",
|
||||
"**/node_modules",
|
||||
".venv",
|
||||
"build",
|
||||
"dist"
|
||||
],
|
||||
"pythonVersion": "3.12",
|
||||
"venvPath": ".",
|
||||
"venv": ".venv",
|
||||
"typeCheckingMode": "standard",
|
||||
"useLibraryCodeForTypes": true,
|
||||
"stubPath": "./typings",
|
||||
"reportCallInDefaultInitializer": "none",
|
||||
"reportUnknownVariableType": "warning",
|
||||
"reportUnknownMemberType": "warning",
|
||||
"reportUnknownArgumentType": "warning",
|
||||
"reportUnknownLambdaType": "warning",
|
||||
"reportUnknownParameterType": "warning",
|
||||
"reportMissingParameterType": "warning",
|
||||
"reportUnannotatedClassAttribute": "warning",
|
||||
"reportMissingTypeStubs": "none",
|
||||
"reportMissingModuleSource": "none",
|
||||
"reportImportCycles": "none",
|
||||
"reportUnusedImport": "warning",
|
||||
"reportUnusedClass": "warning",
|
||||
"reportUnusedFunction": "warning",
|
||||
"reportUnusedVariable": "warning",
|
||||
"reportDuplicateImport": "warning",
|
||||
"reportWildcardImportFromLibrary": "warning",
|
||||
"reportAny": "warning",
|
||||
"reportUnusedCallResult": "none",
|
||||
"reportUnnecessaryIsInstance": "none",
|
||||
"reportImplicitOverride": "none",
|
||||
"reportDeprecated": "warning",
|
||||
"reportIncompatibleMethodOverride": "error",
|
||||
"reportIncompatibleVariableOverride": "error",
|
||||
"reportInconsistentConstructor": "none",
|
||||
"analyzeUnannotatedFunctions": true
|
||||
}
|
||||
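Several of the warning-level rules above target unannotated or unknown-typed values; the storage diff's change from documents = [] to an annotated list is the kind of fix they push for. A small before/after sketch (diagnostic names as configured above; the exact checker message wording may differ):

# Under reportUnknownVariableType = "warning", an untyped empty literal
# leaves the element type unknown to the checker:
documents = []  # element type is unknown

# Annotating the variable (as the WeaviateStorage diff above does) resolves it:
typed_documents: list[dict[str, object]] = []
typed_documents.append({"id": "doc-1", "word_count": 42})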
3448
repomix-output.xml
File diff suppressed because it is too large
0
tests/__init__.py
Normal file
BIN
tests/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/conftest.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/openapi_mocks.cpython-312.pyc
Normal file
Binary file not shown.
545
tests/conftest.py
Normal file
@@ -0,0 +1,545 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Protocol, TypedDict, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import HttpUrl
|
||||
|
||||
from ingest_pipeline.core.models import (
|
||||
Document,
|
||||
DocumentMetadata,
|
||||
IngestionSource,
|
||||
StorageBackend,
|
||||
StorageConfig,
|
||||
VectorConfig,
|
||||
)
|
||||
from typings import EmbeddingData, EmbeddingResponse
|
||||
|
||||
from .openapi_mocks import (
|
||||
FirecrawlMockService,
|
||||
OpenAPISpec,
|
||||
OpenWebUIMockService,
|
||||
R2RMockService,
|
||||
)
|
||||
|
||||
# Type aliases for mock responses
|
||||
MockResponseData = dict[str, object]
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
class RequestRecord(TypedDict):
|
||||
"""Captured HTTP request payload."""
|
||||
|
||||
method: str
|
||||
url: str
|
||||
json_body: dict[str, object] | None
|
||||
params: dict[str, object] | None
|
||||
files: object | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StubbedResponse:
|
||||
"""In-memory HTTP response for httpx client mocking."""
|
||||
|
||||
payload: object
|
||||
status_code: int = 200
|
||||
|
||||
def json(self) -> object:
|
||||
return self.payload
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self.status_code < 400:
|
||||
return
|
||||
# Create a minimal exception for testing - we don't need full httpx objects for tests
|
||||
class TestHTTPError(Exception):
|
||||
request: object | None
|
||||
response: object | None
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.request = None
|
||||
self.response = None
|
||||
|
||||
raise TestHTTPError(f"Stubbed HTTP error with status {self.status_code}")
|
||||
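StubbedResponse mimics just the two httpx response methods the pipeline touches; statuses at or above 400 raise from raise_for_status(). A quick usage example against the class as defined above:

ok = StubbedResponse(payload={"results": []})
assert ok.json() == {"results": []}
ok.raise_for_status()  # no-op for 2xx/3xx

boom = StubbedResponse(payload={"detail": "nope"}, status_code=500)
try:
    boom.raise_for_status()
except Exception as exc:  # the stub raises its own minimal error type
    print(exc)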
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AsyncClientStub:
|
||||
"""Replacement for httpx.AsyncClient used in tests."""
|
||||
|
||||
responses: deque[StubbedResponse]
|
||||
requests: list[RequestRecord]
|
||||
owner: HttpxStub
|
||||
timeout: object | None = None
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
base_url: str = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
responses: deque[StubbedResponse],
|
||||
requests: list[RequestRecord],
|
||||
timeout: object | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
base_url: str | None = None,
|
||||
owner: HttpxStub,
|
||||
**_: object,
|
||||
) -> None:
|
||||
self.responses = responses
|
||||
self.requests = requests
|
||||
self.timeout = timeout
|
||||
self.headers = dict(headers or {})
|
||||
self.base_url = str(base_url or "")
|
||||
self.owner = owner
|
||||
|
||||
def _normalize_url(self, url: str) -> str:
|
||||
if url.startswith("http://") or url.startswith("https://"):
|
||||
return url
|
||||
if not self.base_url:
|
||||
return url
|
||||
prefix = self.base_url.rstrip("/")
|
||||
suffix = url.lstrip("/")
|
||||
return f"{prefix}/{suffix}" if suffix else prefix
|
||||
|
||||
def _consume(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
url: str,
|
||||
json: dict[str, object] | None,
|
||||
params: dict[str, object] | None,
|
||||
files: object | None,
|
||||
) -> StubbedResponse:
|
||||
if self.responses:
|
||||
return self.responses.popleft()
|
||||
dispatched = self.owner.dispatch(
|
||||
method=method,
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
files=files,
|
||||
)
|
||||
if dispatched is not None:
|
||||
return dispatched
|
||||
raise AssertionError(f"No stubbed response for {method} {url}")
|
||||
|
||||
def _record(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
url: str,
|
||||
json: dict[str, object] | None,
|
||||
params: dict[str, object] | None,
|
||||
files: object | None,
|
||||
) -> str:
|
||||
normalized = self._normalize_url(url)
|
||||
record: RequestRecord = {
|
||||
"method": method,
|
||||
"url": normalized,
|
||||
"json_body": json,
|
||||
"params": params,
|
||||
"files": files,
|
||||
}
|
||||
self.requests.append(record)
|
||||
return normalized
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
json: dict[str, object] | None = None,
|
||||
files: object | None = None,
|
||||
params: dict[str, object] | None = None,
|
||||
) -> StubbedResponse:
|
||||
normalized = self._record(
|
||||
method="POST",
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
files=files,
|
||||
)
|
||||
return self._consume(
|
||||
method="POST",
|
||||
url=normalized,
|
||||
json=json,
|
||||
params=params,
|
||||
files=files,
|
||||
)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: dict[str, object] | None = None,
|
||||
) -> StubbedResponse:
|
||||
normalized = self._record(
|
||||
method="GET",
|
||||
url=url,
|
||||
json=None,
|
||||
params=params,
|
||||
files=None,
|
||||
)
|
||||
return self._consume(
|
||||
method="GET",
|
||||
url=normalized,
|
||||
json=None,
|
||||
params=params,
|
||||
files=None,
|
||||
)
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: dict[str, object] | None = None,
|
||||
json: dict[str, object] | None = None,
|
||||
) -> StubbedResponse:
|
||||
normalized = self._record(
|
||||
method="DELETE",
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
files=None,
|
||||
)
|
||||
return self._consume(
|
||||
method="DELETE",
|
||||
url=normalized,
|
||||
json=json,
|
||||
params=params,
|
||||
files=None,
|
||||
)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> AsyncClientStub:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object | None,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HttpxStub:
|
||||
"""Helper exposing queued responses and captured requests."""
|
||||
|
||||
responses: deque[StubbedResponse] = field(default_factory=deque)
|
||||
requests: list[RequestRecord] = field(default_factory=list)
|
||||
clients: list[AsyncClientStub] = field(default_factory=list)
|
||||
services: dict[str, MockService] = field(default_factory=dict)
|
||||
|
||||
def queue_json(self, payload: object, status_code: int = 200) -> None:
|
||||
self.responses.append(StubbedResponse(payload=payload, status_code=status_code))
|
||||
|
||||
def register_service(self, base_url: str, service: MockService) -> None:
|
||||
normalized = base_url.rstrip("/") or base_url
|
||||
self.services[normalized] = service
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
url: str,
|
||||
json: Mapping[str, object] | None,
|
||||
params: Mapping[str, object] | None,
|
||||
files: object | None,
|
||||
) -> StubbedResponse | None:
|
||||
parsed = urlparse(url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else ""
|
||||
base = base.rstrip("/") if base else base
|
||||
path = parsed.path or "/"
|
||||
service = self.services.get(base)
|
||||
if service is None:
|
||||
return None
|
||||
status, payload = service.handle(
|
||||
method=method,
|
||||
path=path,
|
||||
json=json,
|
||||
params=params,
|
||||
files=files,
|
||||
)
|
||||
return StubbedResponse(payload=payload, status_code=status)
|
||||
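HttpxStub.dispatch strips the URL down to scheme plus host, looks up a registered mock service for that base, and wraps whatever the service returns in a StubbedResponse; explicitly queued responses always win inside AsyncClientStub._consume. A tiny illustration with a throwaway service (EchoService is illustrative, not part of the repo):

class EchoService:
    # Minimal object that structurally satisfies the MockService protocol below.
    def handle(self, *, method: str, path: str, json: object, params: object, files: object) -> tuple[int, object]:
        return 200, {"method": method, "path": path}


stub = HttpxStub()
stub.register_service("http://weaviate.yo", EchoService())

hit = stub.dispatch(method="GET", url="http://weaviate.yo/v1/schema", json=None, params=None, files=None)
assert hit is not None and hit.json() == {"method": "GET", "path": "/v1/schema"}

miss = stub.dispatch(method="GET", url="http://unknown.lab/x", json=None, params=None, files=None)
assert miss is None  # unregistered host: falls back to queued responses or an assertion in _consume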
|
||||
|
||||
class DocumentFactory(Protocol):
|
||||
"""Callable protocol for building Document instances."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
content: str,
|
||||
metadata_updates: dict[str, object] | None = None,
|
||||
) -> Document:
|
||||
"""Create Document for testing."""
|
||||
...
|
||||
|
||||
|
||||
class VectorConfigFactory(Protocol):
|
||||
"""Callable protocol for building VectorConfig instances."""
|
||||
|
||||
def __call__(self, *, model: str, dimension: int, endpoint: str) -> VectorConfig:
|
||||
"""Create VectorConfig for testing."""
|
||||
...
|
||||
|
||||
|
||||
class EmbeddingPayloadFactory(Protocol):
|
||||
"""Callable protocol for synthesizing embedding responses."""
|
||||
|
||||
def __call__(self, *, dimension: int) -> EmbeddingResponse:
|
||||
"""Create embedding payload with the requested dimension."""
|
||||
...
|
||||
|
||||
|
||||
class MockService(Protocol):
|
||||
"""Protocol for mock services that can handle HTTP requests."""
|
||||
|
||||
def handle(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
path: str,
|
||||
json: Mapping[str, object] | None,
|
||||
params: Mapping[str, object] | None,
|
||||
files: object | None,
|
||||
) -> tuple[int, object]:
|
||||
"""Handle HTTP request and return status code and response payload."""
|
||||
...
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def base_metadata() -> DocumentMetadata:
|
||||
"""Provide reusable base metadata for document fixtures."""
|
||||
|
||||
required = {
|
||||
"source_url": "https://example.com/article",
|
||||
"timestamp": datetime.now(UTC),
|
||||
"content_type": "text/plain",
|
||||
"word_count": 100,
|
||||
"char_count": 500,
|
||||
}
|
||||
return cast(DocumentMetadata, cast(object, required))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def document_factory(base_metadata: DocumentMetadata) -> DocumentFactory:
|
||||
"""Build Document models with deterministic defaults."""
|
||||
|
||||
def _factory(
|
||||
*,
|
||||
content: str,
|
||||
metadata_updates: dict[str, object] | None = None,
|
||||
) -> Document:
|
||||
metadata_dict = dict(base_metadata)
|
||||
if metadata_updates:
|
||||
metadata_dict.update(metadata_updates)
|
||||
return Document(
|
||||
content=content,
|
||||
metadata=cast(DocumentMetadata, cast(object, metadata_dict)),
|
||||
source=IngestionSource.WEB,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vector_config_factory() -> VectorConfigFactory:
|
||||
"""Construct VectorConfig instances for tests."""
|
||||
|
||||
def _factory(*, model: str, dimension: int, endpoint: str) -> VectorConfig:
|
||||
return VectorConfig(model=model, dimension=dimension, embedding_endpoint=HttpUrl(endpoint))
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def storage_config() -> StorageConfig:
|
||||
"""Provide canonical storage configuration for adapters."""
|
||||
|
||||
return StorageConfig(
|
||||
backend=StorageBackend.WEAVIATE,
|
||||
endpoint=HttpUrl("http://storage.local"),
|
||||
collection_name="documents",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def r2r_storage_config() -> StorageConfig:
|
||||
"""Provide storage configuration for R2R adapter tests."""
|
||||
|
||||
return StorageConfig(
|
||||
backend=StorageBackend.R2R,
|
||||
endpoint=HttpUrl("http://r2r.local"),
|
||||
collection_name="documents",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def openwebui_spec() -> OpenAPISpec:
|
||||
"""Load OpenWebUI OpenAPI specification."""
|
||||
|
||||
return OpenAPISpec.from_file(PROJECT_ROOT / "chat.json")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def r2r_spec() -> OpenAPISpec:
|
||||
"""Load R2R OpenAPI specification."""
|
||||
|
||||
return OpenAPISpec.from_file(PROJECT_ROOT / "r2r.json")
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def firecrawl_spec() -> OpenAPISpec:
|
||||
"""Load Firecrawl OpenAPI specification."""
|
||||
|
||||
return OpenAPISpec.from_file(PROJECT_ROOT / "firecrawl.json")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def httpx_stub(monkeypatch: pytest.MonkeyPatch) -> HttpxStub:
|
||||
"""Replace httpx.AsyncClient with an in-memory stub."""
|
||||
|
||||
stub = HttpxStub()
|
||||
|
||||
def _client_factory(**kwargs: object) -> AsyncClientStub:
|
||||
client = AsyncClientStub(
|
||||
responses=stub.responses,
|
||||
requests=stub.requests,
|
||||
timeout=kwargs.get("timeout"),
|
||||
headers=cast(dict[str, str] | None, kwargs.get("headers")),
|
||||
base_url=cast(str | None, kwargs.get("base_url")),
|
||||
owner=stub,
|
||||
)
|
||||
stub.clients.append(client)
|
||||
return client
|
||||
|
||||
monkeypatch.setattr(httpx, "AsyncClient", _client_factory)
|
||||
monkeypatch.delenv("LLM_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
return stub
|
||||
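Because the fixture monkeypatches httpx.AsyncClient itself, any production code that constructs a client inside the test transparently receives the stub; queued payloads are served in order and every request is recorded. A hypothetical test showing the flow (the URL and payload values here are stand-ins, not real services):

import httpx


async def test_records_posts(httpx_stub: HttpxStub) -> None:
    httpx_stub.queue_json({"status": "ok"})

    # httpx.AsyncClient is patched, so this is actually an AsyncClientStub.
    client = httpx.AsyncClient(base_url="http://llm.lab", headers={"X-Demo": "1"})
    response = await client.post("/v1/chat/completions", json={"model": "demo"})

    assert response.json() == {"status": "ok"}
    assert httpx_stub.requests[0]["url"] == "http://llm.lab/v1/chat/completions"
    assert httpx_stub.requests[0]["json_body"] == {"model": "demo"}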
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def openwebui_service(
|
||||
httpx_stub: HttpxStub,
|
||||
openwebui_spec: OpenAPISpec,
|
||||
storage_config: StorageConfig,
|
||||
) -> OpenWebUIMockService:
|
||||
"""Stateful mock for OpenWebUI APIs."""
|
||||
|
||||
service = OpenWebUIMockService(
|
||||
base_url=str(storage_config.endpoint),
|
||||
spec=openwebui_spec,
|
||||
)
|
||||
httpx_stub.register_service(service.base_url, service)
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def r2r_service(
|
||||
httpx_stub: HttpxStub,
|
||||
r2r_spec: OpenAPISpec,
|
||||
r2r_storage_config: StorageConfig,
|
||||
) -> R2RMockService:
|
||||
"""Stateful mock for R2R APIs."""
|
||||
|
||||
service = R2RMockService(
|
||||
base_url=str(r2r_storage_config.endpoint),
|
||||
spec=r2r_spec,
|
||||
)
|
||||
httpx_stub.register_service(service.base_url, service)
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def firecrawl_service(
|
||||
httpx_stub: HttpxStub,
|
||||
firecrawl_spec: OpenAPISpec,
|
||||
) -> FirecrawlMockService:
|
||||
"""Stateful mock for Firecrawl APIs."""
|
||||
|
||||
service = FirecrawlMockService(
|
||||
base_url="http://crawl.lab:30002",
|
||||
spec=firecrawl_spec,
|
||||
)
|
||||
httpx_stub.register_service(service.base_url, service)
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def firecrawl_client_stub(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
firecrawl_service: FirecrawlMockService,
|
||||
) -> object:
|
||||
"""Patch AsyncFirecrawl to use the mock service."""
|
||||
|
||||
class AsyncFirecrawlStub:
|
||||
_service: FirecrawlMockService
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
self._service = firecrawl_service
|
||||
|
||||
async def map(self, url: str, limit: int | None = None, **_: object) -> SimpleNamespace:
|
||||
payload = cast(MockResponseData, self._service.map_response(url, limit))
|
||||
links_data = cast(list[MockResponseData], payload.get("links", []))
|
||||
links = [SimpleNamespace(url=cast(str, item.get("url", ""))) for item in links_data]
|
||||
return SimpleNamespace(success=payload.get("success", True), links=links)
|
||||
|
||||
async def scrape(self, url: str, formats: list[str] | None = None, **_: object) -> SimpleNamespace:
|
||||
payload = cast(MockResponseData, self._service.scrape_response(url, formats))
|
||||
data = cast(MockResponseData, payload.get("data", {}))
|
||||
metadata_payload = cast(MockResponseData, data.get("metadata", {}))
|
||||
metadata_obj = SimpleNamespace(**metadata_payload)
|
||||
return SimpleNamespace(
|
||||
markdown=data.get("markdown"),
|
||||
html=data.get("html"),
|
||||
rawHtml=data.get("rawHtml"),
|
||||
links=data.get("links", []),
|
||||
metadata=metadata_obj,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> AsyncFirecrawlStub:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object | None,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ingest_pipeline.ingestors.firecrawl.AsyncFirecrawl",
|
||||
AsyncFirecrawlStub,
|
||||
)
|
||||
return AsyncFirecrawlStub
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def embedding_payload_factory() -> EmbeddingPayloadFactory:
|
||||
"""Provide embedding payloads tailored to the requested dimension."""
|
||||
|
||||
def _factory(*, dimension: int) -> EmbeddingResponse:
|
||||
vector = [float(index) for index in range(dimension)]
|
||||
data_entry: EmbeddingData = {"embedding": vector}
|
||||
return {"data": [data_entry]}
|
||||
|
||||
return _factory
|
||||
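Combined with httpx_stub, the factory lets a test fake an embeddings call of any dimension without touching a real endpoint. For instance (a sketch; the code under test is omitted):

def test_embedding_payload_shape(
    httpx_stub: HttpxStub,
    embedding_payload_factory: EmbeddingPayloadFactory,
) -> None:
    payload = embedding_payload_factory(dimension=3)
    assert payload["data"][0]["embedding"] == [0.0, 1.0, 2.0]

    # Queue it so the next stubbed POST (e.g. from a Vectorizer) returns this vector.
    httpx_stub.queue_json(payload)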
807
tests/openapi_mocks.py
Normal file
@@ -0,0 +1,807 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json as json_module
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class OpenAPISpec:
|
||||
"""Utility for generating example payloads from an OpenAPI document."""
|
||||
|
||||
def __init__(self, raw: Mapping[str, Any]):
|
||||
self._raw = raw
|
||||
self._paths = raw.get("paths", {})
|
||||
self._components = raw.get("components", {})
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path: Path) -> OpenAPISpec:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
raw = json_module.load(handle)
|
||||
return cls(raw)
|
||||
|
||||
def resolve_ref(self, ref: str) -> Mapping[str, Any]:
|
||||
if not ref.startswith("#/"):
|
||||
raise ValueError(f"Unsupported $ref format: {ref}")
|
||||
parts = ref.lstrip("#/").split("/")
|
||||
node: Any = self._raw
|
||||
for part in parts:
|
||||
if not isinstance(node, Mapping) or part not in node:
|
||||
raise KeyError(f"Unable to resolve reference: {ref}")
|
||||
node = node[part]
|
||||
if not isinstance(node, Mapping):
|
||||
raise TypeError(f"Reference {ref} did not resolve to an object")
|
||||
return node
|
||||
|
||||
def generate(self, schema: Mapping[str, Any] | None, name: str = "value") -> Any:
|
||||
if schema is None:
|
||||
return None
|
||||
if "$ref" in schema:
|
||||
resolved = self.resolve_ref(schema["$ref"])
|
||||
return self.generate(resolved, name=name)
|
||||
if "example" in schema:
|
||||
return copy.deepcopy(schema["example"])
|
||||
if "default" in schema and schema["default"] is not None:
|
||||
return copy.deepcopy(schema["default"])
|
||||
if "enum" in schema:
|
||||
enum_values = schema["enum"]
|
||||
if enum_values:
|
||||
return copy.deepcopy(enum_values[0])
|
||||
if "anyOf" in schema:
|
||||
for option in schema["anyOf"]:
|
||||
candidate = self.generate(option, name=name)
|
||||
if candidate is not None:
|
||||
return candidate
|
||||
if "oneOf" in schema:
|
||||
return self.generate(schema["oneOf"][0], name=name)
|
||||
if "allOf" in schema:
|
||||
result: dict[str, Any] = {}
|
||||
for option in schema["allOf"]:
|
||||
fragment = self.generate(option, name=name)
|
||||
if isinstance(fragment, Mapping):
|
||||
result.update(fragment)
|
||||
return result
|
||||
|
||||
type_name = schema.get("type")
|
||||
if type_name == "object":
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
result: dict[str, Any] = {}
|
||||
keys = list(properties.keys()) or list(required)
|
||||
for key in keys:
|
||||
prop_schema = properties.get(key, {})
|
||||
result[key] = self.generate(prop_schema, name=key)
|
||||
additional = schema.get("additionalProperties")
|
||||
if additional and isinstance(additional, Mapping) and not properties:
|
||||
result["key"] = self.generate(additional, name="key")
|
||||
return result
|
||||
if type_name == "array":
|
||||
item_schema = schema.get("items", {})
|
||||
item = self.generate(item_schema, name=name)
|
||||
return [] if item is None else [item]
|
||||
if type_name == "string":
|
||||
format_hint = schema.get("format")
|
||||
if format_hint == "date-time":
|
||||
return "2024-01-01T00:00:00+00:00"
|
||||
if format_hint == "date":
|
||||
return "2024-01-01"
|
||||
if format_hint == "uuid":
|
||||
return "00000000-0000-4000-8000-000000000000"
|
||||
if format_hint == "uri":
|
||||
return "https://example.com"
|
||||
return f"{name}-value"
|
||||
if type_name == "integer":
|
||||
minimum = schema.get("minimum")
|
||||
if minimum is not None:
|
||||
return int(minimum)
|
||||
return 1
|
||||
if type_name == "number":
|
||||
return 1.0
|
||||
if type_name == "boolean":
|
||||
return True
|
||||
if type_name == "null":
|
||||
return None
|
||||
return {}
|
||||
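generate walks a schema depth-first, preferring explicit example/default/enum values and otherwise falling back to deterministic placeholders per type (fixed dates and UUIDs, "name-value" strings, the minimum or 1 for integers). A quick illustration against an inline schema:

spec = OpenAPISpec({"paths": {}, "components": {}})

schema = {
    "type": "object",
    "properties": {
        "id": {"type": "string", "format": "uuid"},
        "name": {"type": "string"},
        "count": {"type": "integer", "minimum": 5},
        "tags": {"type": "array", "items": {"type": "string"}},
    },
}

assert spec.generate(schema) == {
    "id": "00000000-0000-4000-8000-000000000000",
    "name": "name-value",
    "count": 5,
    "tags": ["tags-value"],
}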
|
||||
def _split_path(self, path: str) -> list[str]:
|
||||
if path in {"", "/"}:
|
||||
return []
|
||||
return [segment for segment in path.strip("/").split("/") if segment]
|
||||
|
||||
def find_operation(self, method: str, path: str) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
|
||||
method = method.lower()
|
||||
normalized = "/" if path in {"", "/"} else "/" + path.strip("/")
|
||||
actual_segments = self._split_path(normalized)
|
||||
for template, operations in self._paths.items():
|
||||
operation = operations.get(method)
|
||||
if operation is None:
|
||||
continue
|
||||
template_path = template if template else "/"
|
||||
template_segments = self._split_path(template_path)
|
||||
if len(template_segments) != len(actual_segments):
|
||||
continue
|
||||
params: dict[str, str] = {}
|
||||
matched = True
|
||||
for template_part, actual_part in zip(template_segments, actual_segments, strict=False):
|
||||
if template_part.startswith("{") and template_part.endswith("}"):
|
||||
params[template_part[1:-1]] = actual_part
|
||||
elif template_part != actual_part:
|
||||
matched = False
|
||||
break
|
||||
if matched:
|
||||
return operation, params
|
||||
return None, {}
|
||||
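find_operation matches a concrete request path against the spec's templated paths segment by segment, capturing {param} segments as values. A compact example with a two-path inline spec:

tiny_spec = OpenAPISpec(
    {
        "paths": {
            "/v3/collections": {"get": {"responses": {"200": {}}}},
            "/v3/documents/{id}": {"delete": {"responses": {"200": {}}}},
        }
    }
)

operation, params = tiny_spec.find_operation("DELETE", "/v3/documents/abc-123")
assert operation is not None and params == {"id": "abc-123"}

missing, _ = tiny_spec.find_operation("POST", "/v3/collections")
assert missing is None  # no POST operation registered on that path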
|
||||
def build_response(
|
||||
self,
|
||||
operation: Mapping[str, Any],
|
||||
status: str | None = None,
|
||||
) -> tuple[int, Any]:
|
||||
responses = operation.get("responses", {})
|
||||
target_status = status or (
|
||||
"200" if "200" in responses else next(iter(responses), "200")
|
||||
)
|
||||
status_code = 200
|
||||
try:
|
||||
status_code = int(target_status)
|
||||
except ValueError:
|
||||
pass
|
||||
response_entry = responses.get(target_status, {})
|
||||
content = response_entry.get("content", {})
|
||||
media = None
|
||||
for candidate in ("application/json", "application/problem+json", "text/plain"):
|
||||
if candidate in content:
|
||||
media = content[candidate]
|
||||
break
|
||||
schema = media.get("schema") if media else None
|
||||
payload = self.generate(schema, name="response")
|
||||
return status_code, payload
|
||||
|
||||
def generate_from_ref(
|
||||
self,
|
||||
ref: str,
|
||||
overrides: Mapping[str, Any] | None = None,
|
||||
) -> Any:
|
||||
base = self.generate({"$ref": ref})
|
||||
if overrides is None or not isinstance(base, Mapping):
|
||||
return copy.deepcopy(base)
|
||||
merged = copy.deepcopy(base)
|
||||
for key, value in overrides.items():
|
||||
if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
|
||||
merged[key] = self.generate_from_mapping(
|
||||
merged[key],
|
||||
value,
|
||||
)
|
||||
else:
|
||||
merged[key] = copy.deepcopy(value)
|
||||
return merged
|
||||
|
||||
def generate_from_mapping(
|
||||
self,
|
||||
base: Mapping[str, Any],
|
||||
overrides: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
result = copy.deepcopy(base)
|
||||
for key, value in overrides.items():
|
||||
if isinstance(value, Mapping) and isinstance(result.get(key), Mapping):
|
||||
result[key] = self.generate_from_mapping(result[key], value)
|
||||
else:
|
||||
result[key] = copy.deepcopy(value)
|
||||
return result
|
||||
|
||||
|
||||
class OpenAPIMockService:
|
||||
"""Base class for stateful mock services backed by an OpenAPI spec."""
|
||||
|
||||
def __init__(self, base_url: str, spec: OpenAPISpec):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.spec = spec
|
||||
|
||||
@staticmethod
|
||||
def _normalize_path(path: str) -> str:
|
||||
if not path or path == "/":
|
||||
return "/"
|
||||
return "/" + path.strip("/")
|
||||
|
||||
def handle(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
path: str,
|
||||
json: Mapping[str, Any] | None,
|
||||
params: Mapping[str, Any] | None,
|
||||
files: Any,
|
||||
) -> tuple[int, Any]:
|
||||
operation, _ = self.spec.find_operation(method, path)
|
||||
if operation is None:
|
||||
return 404, {"detail": f"Unhandled {method.upper()} {path}"}
|
||||
return self.spec.build_response(operation)
|
||||
|
||||
|
||||
class OpenWebUIMockService(OpenAPIMockService):
|
||||
"""Stateful mock capturing OpenWebUI knowledge and file operations."""
|
||||
|
||||
def __init__(self, base_url: str, spec: OpenAPISpec):
|
||||
super().__init__(base_url, spec)
|
||||
self._knowledge: dict[str, dict[str, Any]] = {}
|
||||
self._files: dict[str, dict[str, Any]] = {}
|
||||
self._knowledge_counter = 1
|
||||
self._file_counter = 1
|
||||
self._default_user = "user-1"
|
||||
|
||||
@staticmethod
|
||||
def _timestamp() -> int:
|
||||
return int(time.time())
|
||||
|
||||
def ensure_knowledge(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
description: str = "",
|
||||
knowledge_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
identifier = knowledge_id or f"kb-{self._knowledge_counter}"
|
||||
self._knowledge_counter += 1
|
||||
entry = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
|
||||
ts = self._timestamp()
|
||||
entry.update(
|
||||
{
|
||||
"id": identifier,
|
||||
"user_id": self._default_user,
|
||||
"name": name,
|
||||
"description": description or entry.get("description") or "",
|
||||
"created_at": ts,
|
||||
"updated_at": ts,
|
||||
"data": entry.get("data") or {},
|
||||
"meta": entry.get("meta") or {},
|
||||
"access_control": entry.get("access_control") or {},
|
||||
"files": [],
|
||||
}
|
||||
)
|
||||
self._knowledge[identifier] = entry
|
||||
return copy.deepcopy(entry)
|
||||
|
||||
def create_file(
|
||||
self,
|
||||
*,
|
||||
filename: str,
|
||||
user_id: str | None = None,
|
||||
file_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
identifier = file_id or f"file-{self._file_counter}"
|
||||
self._file_counter += 1
|
||||
entry = self.spec.generate_from_ref("#/components/schemas/FileModelResponse")
|
||||
ts = self._timestamp()
|
||||
entry.update(
|
||||
{
|
||||
"id": identifier,
|
||||
"user_id": user_id or self._default_user,
|
||||
"filename": filename,
|
||||
"meta": entry.get("meta") or {},
|
||||
"created_at": ts,
|
||||
"updated_at": ts,
|
||||
}
|
||||
)
|
||||
self._files[identifier] = entry
|
||||
return copy.deepcopy(entry)
|
||||
|
||||
def _build_file_metadata(self, file_id: str) -> dict[str, Any]:
|
||||
metadata = self.spec.generate_from_ref("#/components/schemas/FileMetadataResponse")
|
||||
source = self._files.get(file_id)
|
||||
ts = self._timestamp()
|
||||
metadata.update(
|
||||
{
|
||||
"id": file_id,
|
||||
"meta": (source or {}).get("meta", {}),
|
||||
"created_at": ts,
|
||||
"updated_at": ts,
|
||||
}
|
||||
)
|
||||
return metadata
|
||||
|
||||
def get_knowledge(self, knowledge_id: str) -> dict[str, Any] | None:
|
||||
entry = self._knowledge.get(knowledge_id)
|
||||
return copy.deepcopy(entry) if entry is not None else None
|
||||
|
||||
def find_knowledge_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
|
||||
for identifier, entry in self._knowledge.items():
|
||||
if entry.get("name") == name:
|
||||
return identifier, copy.deepcopy(entry)
|
||||
return None
|
||||
|
||||
def attach_existing_file(self, knowledge_id: str, file_id: str) -> None:
|
||||
if knowledge_id not in self._knowledge:
|
||||
raise KeyError(f"Knowledge {knowledge_id} not found")
|
||||
knowledge = self._knowledge[knowledge_id]
|
||||
metadata = self._build_file_metadata(file_id)
|
||||
knowledge.setdefault("files", [])
|
||||
knowledge["files"] = [item for item in knowledge["files"] if item.get("id") != file_id]
|
||||
knowledge["files"].append(metadata)
|
||||
knowledge["updated_at"] = self._timestamp()
|
||||
|
||||
def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
|
||||
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
|
||||
payload.update({
|
||||
"id": knowledge["id"],
|
||||
"user_id": knowledge["user_id"],
|
||||
"name": knowledge["name"],
|
||||
"description": knowledge.get("description", ""),
|
||||
"data": copy.deepcopy(knowledge.get("data", {})),
|
||||
"meta": copy.deepcopy(knowledge.get("meta", {})),
|
||||
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
|
||||
"created_at": knowledge.get("created_at", self._timestamp()),
|
||||
"updated_at": knowledge.get("updated_at", self._timestamp()),
|
||||
"files": copy.deepcopy(knowledge.get("files", [])),
|
||||
})
|
||||
return payload
|
||||
|
||||
def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
|
||||
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeFilesResponse")
|
||||
payload.update(
|
||||
{
|
||||
"id": knowledge["id"],
|
||||
"user_id": knowledge["user_id"],
|
||||
"name": knowledge["name"],
|
||||
"description": knowledge.get("description", ""),
|
||||
"data": copy.deepcopy(knowledge.get("data", {})),
|
||||
"meta": copy.deepcopy(knowledge.get("meta", {})),
|
||||
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
|
||||
"created_at": knowledge.get("created_at", self._timestamp()),
|
||||
"updated_at": knowledge.get("updated_at", self._timestamp()),
|
||||
"files": copy.deepcopy(knowledge.get("files", [])),
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
def handle(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
path: str,
|
||||
json: Mapping[str, Any] | None,
|
||||
params: Mapping[str, Any] | None,
|
||||
files: Any,
|
||||
) -> tuple[int, Any]:
|
||||
normalized = self._normalize_path(path)
|
||||
segments = [segment for segment in normalized.strip("/").split("/") if segment]
|
||||
if segments[:3] != ["api", "v1", "knowledge"]:
|
||||
return super().handle(method=method, path=path, json=json, params=params, files=files)
|
||||
|
||||
method_upper = method.upper()
|
||||
segment_count = len(segments)
|
||||
|
||||
if method_upper == "GET" and segment_count == 4 and segments[3] == "list":
|
||||
body = [self._knowledge_user_response(entry) for entry in self._knowledge.values()]
|
||||
return 200, body
|
||||
|
||||
if method_upper == "GET" and segment_count == 3:
|
||||
body = [self._knowledge_user_response(entry) for entry in self._knowledge.values()]
|
||||
return 200, body
|
||||
|
||||
if method_upper == "POST" and segment_count == 4 and segments[3] == "create":
|
||||
payload = json or {}
|
||||
entry = self.ensure_knowledge(
|
||||
name=str(payload.get("name", "knowledge")),
|
||||
description=str(payload.get("description", "")),
|
||||
)
|
||||
return 200, entry
|
||||
|
||||
if segment_count >= 4:
|
||||
knowledge_id = segments[3]
|
||||
knowledge = self._knowledge.get(knowledge_id)
|
||||
if knowledge is None:
|
||||
return 404, {"detail": f"Knowledge {knowledge_id} not found"}
|
||||
|
||||
if method_upper == "GET" and segment_count == 4:
|
||||
return 200, self._knowledge_files_response(knowledge)
|
||||
|
||||
if method_upper == "POST" and segment_count == 6 and segments[4:] == ["file", "add"]:
|
||||
payload = json or {}
|
||||
file_id = str(payload.get("file_id", ""))
|
||||
if not file_id:
|
||||
return 422, {"detail": "file_id is required"}
|
||||
if file_id not in self._files:
|
||||
self.create_file(filename=f"{file_id}.txt")
|
||||
metadata = self._build_file_metadata(file_id)
|
||||
knowledge.setdefault("files", [])
|
||||
knowledge["files"] = [item for item in knowledge["files"] if item.get("id") != file_id]
|
||||
knowledge["files"].append(metadata)
|
||||
knowledge["updated_at"] = self._timestamp()
|
||||
return 200, self._knowledge_files_response(knowledge)
|
||||
|
||||
if method_upper == "POST" and segment_count == 6 and segments[4:] == ["file", "remove"]:
|
||||
payload = json or {}
|
||||
file_id = str(payload.get("file_id", ""))
|
||||
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
|
||||
knowledge["updated_at"] = self._timestamp()
|
||||
return 200, self._knowledge_files_response(knowledge)
|
||||
|
||||
if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete":
|
||||
self._knowledge.pop(knowledge_id, None)
|
||||
return 200, True
|
||||
|
||||
if method_upper == "POST" and segments == ["api", "v1", "files"]:
|
||||
filename = "uploaded.txt"
|
||||
if files and isinstance(files, Mapping):
|
||||
file_entry = files.get("file")
|
||||
if isinstance(file_entry, tuple) and len(file_entry) >= 1:
|
||||
filename = str(file_entry[0]) or filename
|
||||
entry = self.create_file(filename=filename)
|
||||
return 200, entry
|
||||
|
||||
if method_upper == "DELETE" and segments[:3] == ["api", "v1", "files"] and len(segments) == 4:
|
||||
file_id = segments[3]
|
||||
self._files.pop(file_id, None)
|
||||
for knowledge in self._knowledge.values():
|
||||
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
|
||||
return 200, {"deleted": True}
|
||||
|
||||
return super().handle(method=method, path=path, json=json, params=params, files=files)
|
||||
|
||||
|
||||
class R2RMockService(OpenAPIMockService):
|
||||
"""Stateful mock of core R2R collection endpoints."""
|
||||
|
||||
def __init__(self, base_url: str, spec: OpenAPISpec):
|
||||
super().__init__(base_url, spec)
|
||||
self._collections: dict[str, dict[str, Any]] = {}
|
||||
self._documents: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(tz=UTC).isoformat()
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
description: str = "",
|
||||
collection_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
identifier = collection_id or str(uuid4())
|
||||
entry = self.spec.generate_from_ref("#/components/schemas/CollectionResponse")
|
||||
timestamp = self._iso_now()
|
||||
entry.update(
|
||||
{
|
||||
"id": identifier,
|
||||
"owner_id": entry.get("owner_id") or str(uuid4()),
|
||||
"name": name,
|
||||
"description": description or entry.get("description") or "",
|
||||
"graph_cluster_status": entry.get("graph_cluster_status") or "idle",
|
||||
"graph_sync_status": entry.get("graph_sync_status") or "synced",
|
||||
"created_at": timestamp,
|
||||
"updated_at": timestamp,
|
||||
"user_count": entry.get("user_count", 1) or 1,
|
||||
"document_count": entry.get("document_count", 0) or 0,
|
||||
"documents": entry.get("documents", []),
|
||||
}
|
||||
)
|
||||
self._collections[identifier] = entry
|
||||
return copy.deepcopy(entry)
|
||||
|
||||
def get_collection(self, collection_id: str) -> dict[str, Any] | None:
|
||||
entry = self._collections.get(collection_id)
|
||||
return copy.deepcopy(entry) if entry is not None else None
|
||||
|
||||
def find_collection_by_name(self, name: str) -> tuple[str, dict[str, Any]] | None:
|
||||
for identifier, entry in self._collections.items():
|
||||
if entry.get("name") == name:
|
||||
return identifier, copy.deepcopy(entry)
|
||||
return None
|
||||
|
||||
def _set_collection_document_ids(self, collection_id: str, document_id: str, *, add: bool) -> None:
|
||||
collection = self._collections.get(collection_id)
|
||||
if collection is None:
|
||||
collection = self.create_collection(name=f"Collection {collection_id}", collection_id=collection_id)
|
||||
documents = collection.setdefault("documents", [])
|
||||
if add:
|
||||
if document_id not in documents:
|
||||
documents.append(document_id)
|
||||
else:
|
||||
documents[:] = [doc for doc in documents if doc != document_id]
|
||||
collection["document_count"] = len(documents)
|
||||
|
||||
    def create_document(
        self,
        *,
        document_id: str,
        content: str,
        metadata: dict[str, Any],
        collection_ids: list[str],
    ) -> dict[str, Any]:
        entry = self.spec.generate_from_ref("#/components/schemas/DocumentResponse")
        timestamp = self._iso_now()
        entry.update(
            {
                "id": document_id,
                "owner_id": entry.get("owner_id") or str(uuid4()),
                "collection_ids": collection_ids,
                "metadata": copy.deepcopy(metadata),
                "document_type": entry.get("document_type") or "text",
                "version": entry.get("version") or "1.0",
                "size_in_bytes": metadata.get("char_count") or len(content.encode("utf-8")),
                "created_at": entry.get("created_at") or timestamp,
                "updated_at": timestamp,
                "summary": entry.get("summary"),
                "summary_embedding": entry.get("summary_embedding") or None,
                "chunks": entry.get("chunks") or [],
                "content": content,
            }
        )
        entry["__content"] = content
        self._documents[document_id] = entry
        for collection_id in collection_ids:
            self._set_collection_document_ids(collection_id, document_id, add=True)
        return copy.deepcopy(entry)

    def get_document(self, document_id: str) -> dict[str, Any] | None:
        entry = self._documents.get(document_id)
        if entry is None:
            return None
        return copy.deepcopy(entry)

    def delete_document(self, document_id: str) -> bool:
        entry = self._documents.pop(document_id, None)
        if entry is None:
            return False
        for collection_id in entry.get("collection_ids", []):
            self._set_collection_document_ids(collection_id, document_id, add=False)
        return True

    def append_document_metadata(self, document_id: str, metadata_list: list[dict[str, Any]]) -> dict[str, Any] | None:
        entry = self._documents.get(document_id)
        if entry is None:
            return None
        metadata = entry.setdefault("metadata", {})
        for item in metadata_list:
            for key, value in item.items():
                metadata[key] = value
        entry["updated_at"] = self._iso_now()
        return copy.deepcopy(entry)

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, Any] | None,
        params: Mapping[str, Any] | None,
        files: Any,
    ) -> tuple[int, Any]:
        normalized = self._normalize_path(path)
        method_upper = method.upper()

        if normalized == "/v3/collections" and method_upper == "GET":
            payload = self.spec.generate_from_ref(
                "#/components/schemas/PaginatedR2RResult_list_CollectionResponse__"
            )
            results = []
            for _identifier, entry in self._collections.items():
                clone = copy.deepcopy(entry)
                clone["document_count"] = len(clone.get("documents", []))
                results.append(clone)
            payload.update(
                {
                    "results": results,
                    "total_entries": len(results),
                }
            )
            return 200, payload

        if normalized == "/v3/collections" and method_upper == "POST":
            body = json or {}
            entry = self.create_collection(
                name=str(body.get("name", "Collection")),
                description=str(body.get("description", "")),
            )
            payload = self.spec.generate_from_ref("#/components/schemas/R2RResults_CollectionResponse_")
            payload.update({"results": entry})
            return 200, payload

        segments = [segment for segment in normalized.strip("/").split("/") if segment]
        if segments[:2] == ["v3", "documents"]:
            if method_upper == "POST" and len(segments) == 2:
                metadata_raw = {}
                content = ""
                doc_id = str(uuid4())
                collection_ids: list[str] = []
                if isinstance(files, Mapping):
                    if "metadata" in files:
                        metadata_entry = files["metadata"]
                        if isinstance(metadata_entry, tuple) and len(metadata_entry) >= 2:
                            try:
                                metadata_raw = json_module.loads(metadata_entry[1] or "{}")
                            except json_module.JSONDecodeError:
                                metadata_raw = {}
                    if "raw_text" in files:
                        raw_text_entry = files["raw_text"]
                        if isinstance(raw_text_entry, tuple) and len(raw_text_entry) >= 2 and raw_text_entry[1] is not None:
                            content = str(raw_text_entry[1])
                    if "id" in files:
                        id_entry = files["id"]
                        if isinstance(id_entry, tuple) and len(id_entry) >= 2 and id_entry[1]:
                            doc_id = str(id_entry[1])
                    if "collection_ids" in files:
                        coll_entry = files["collection_ids"]
                        if isinstance(coll_entry, tuple) and len(coll_entry) >= 2 and coll_entry[1]:
                            try:
                                parsed = json_module.loads(coll_entry[1])
                                if isinstance(parsed, list):
                                    collection_ids = [str(item) for item in parsed]
                            except json_module.JSONDecodeError:
                                collection_ids = []
                if not collection_ids:
                    name = metadata_raw.get("collection_name") or metadata_raw.get("collection")
                    if isinstance(name, str):
                        located = self.find_collection_by_name(name)
                        if located:
                            collection_ids = [located[0]]
                if not collection_ids and self._collections:
                    collection_ids = [next(iter(self._collections))]
                document = self.create_document(
                    document_id=doc_id,
                    content=content,
                    metadata=metadata_raw,
                    collection_ids=collection_ids,
                )
                response_payload = self.spec.generate_from_ref(
                    "#/components/schemas/R2RResults_IngestionResponse_"
                )
                ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse")
                ingestion.update(
                    {
                        "message": ingestion.get("message") or "Ingestion task queued successfully.",
                        "document_id": document["id"],
                        "task_id": ingestion.get("task_id") or str(uuid4()),
                    }
                )
                response_payload["results"] = ingestion
                return 202, response_payload

            if method_upper == "GET" and len(segments) == 3:
                doc_id = segments[2]
                document = self.get_document(doc_id)
                if document is None:
                    return 404, {"detail": f"Document {doc_id} not found"}
                response_payload = self.spec.generate_from_ref(
                    "#/components/schemas/R2RResults_DocumentResponse_"
                )
                response_payload["results"] = document
                return 200, response_payload

            if method_upper == "DELETE" and len(segments) == 3:
                doc_id = segments[2]
                success = self.delete_document(doc_id)
                response_payload = self.spec.generate_from_ref(
                    "#/components/schemas/R2RResults_GenericBooleanResponse_"
                )
                results = response_payload.get("results")
                if isinstance(results, Mapping):
                    results = copy.deepcopy(results)
                    results.update({"success": success})
                else:
                    results = {"success": success}
                response_payload["results"] = results
                status = 200 if success else 404
                return status, response_payload

            if method_upper == "PATCH" and len(segments) == 4 and segments[3] == "metadata":
                doc_id = segments[2]
                metadata_list = []
                if isinstance(json, list):
                    metadata_list = [dict(item) for item in json]
                document = self.append_document_metadata(doc_id, metadata_list)
                if document is None:
                    return 404, {"detail": f"Document {doc_id} not found"}
                response_payload = self.spec.generate_from_ref(
                    "#/components/schemas/R2RResults_DocumentResponse_"
                )
                response_payload["results"] = document
                return 200, response_payload

        return super().handle(method=method, path=path, json=json, params=params, files=files)


class FirecrawlMockService(OpenAPIMockService):
    """Stateful mock for Firecrawl map and scrape endpoints."""

    def __init__(self, base_url: str, spec: OpenAPISpec) -> None:
        super().__init__(base_url, spec)
        self._maps: dict[str, list[str]] = {}
        self._pages: dict[str, dict[str, Any]] = {}

    def register_map_result(self, origin: str, links: list[str]) -> None:
        self._maps[origin] = list(links)

    def register_page(self, url: str, *, markdown: str | None = None, html: str | None = None, metadata: Mapping[str, Any] | None = None, links: list[str] | None = None) -> None:
        self._pages[url] = {
            "markdown": markdown,
            "html": html,
            "metadata": dict(metadata or {}),
            "links": list(links or []),
        }

    def _build_map_payload(self, target_url: str, limit: int | None) -> dict[str, Any]:
        payload = self.spec.generate_from_ref("#/components/schemas/MapResponse")
        links = self._maps.get(target_url, [target_url])
        if limit is not None:
            try:
                limit_value = int(limit)
                if limit_value >= 0:
                    links = links[:limit_value]
            except (TypeError, ValueError):
                pass
        payload["success"] = True
        payload["links"] = [{"url": link} for link in links]
        return payload

    def _default_metadata(self, url: str) -> dict[str, Any]:
        metadata = self.spec.generate_from_ref("#/components/schemas/ScrapeMetadata")
        metadata.update(
            {
                "url": url,
                "sourceURL": url,
                "scrapeId": metadata.get("scrapeId") or str(uuid4()),
                "statusCode": metadata.get("statusCode", 200) or 200,
                "contentType": metadata.get("contentType", "text/html") or "text/html",
                "creditsUsed": metadata.get("creditsUsed", 1) or 1,
            }
        )
        return metadata

    def _build_scrape_payload(self, target_url: str, formats: list[str] | None) -> dict[str, Any]:
        payload = self.spec.generate_from_ref("#/components/schemas/ScrapeResponse")
        payload["success"] = True
        page_info = self._pages.get(target_url, {})
        data = payload.get("data", {})
        markdown = page_info.get("markdown") or f"# Content for {target_url}\n"
        html = page_info.get("html") or f"<h1>Content for {target_url}</h1>"
        data.update(
            {
                "markdown": markdown,
                "html": html,
                "rawHtml": page_info.get("rawHtml", html),
                "links": page_info.get("links", []),
            }
        )
        metadata_payload = self._default_metadata(target_url)
        metadata_payload.update(page_info.get("metadata", {}))
        data["metadata"] = metadata_payload
        payload["data"] = data
        return payload

    def map_response(self, url: str, limit: int | None = None) -> dict[str, Any]:
        return self._build_map_payload(url, limit)

    def scrape_response(self, url: str, formats: list[str] | None = None) -> dict[str, Any]:
        return self._build_scrape_payload(url, formats)

    def handle(
        self,
        *,
        method: str,
        path: str,
        json: Mapping[str, Any] | None,
        params: Mapping[str, Any] | None,
        files: Any,
    ) -> tuple[int, Any]:
        normalized = self._normalize_path(path)
        method_upper = method.upper()

        if normalized == "/v2/map" and method_upper == "POST":
            body = json or {}
            target_url = str(body.get("url", ""))
            limit = body.get("limit")
            return 200, self._build_map_payload(target_url, limit)

        if normalized == "/v2/scrape" and method_upper == "POST":
            body = json or {}
            target_url = str(body.get("url", ""))
            formats = body.get("formats")
            return 200, self._build_scrape_payload(target_url, formats)

        return super().handle(method=method, path=path, json=json, params=params, files=files)
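Note: the tests below consume these mocks through pytest fixtures such as firecrawl_service, r2r_service, openwebui_service, and httpx_stub. The conftest wiring is not part of this diff; a minimal sketch, assuming a spec-loading fixture named firecrawl_spec and a placeholder base URL, might look like:

# conftest.py (sketch only; fixture names beyond firecrawl_service and the base URL are assumptions)
import pytest

@pytest.fixture
def firecrawl_service(firecrawl_spec: OpenAPISpec) -> FirecrawlMockService:
    # Expose the stateful Firecrawl mock to the test suite.
    return FirecrawlMockService("https://api.firecrawl.dev", firecrawl_spec)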
0
tests/unit/__init__.py
Normal file
BIN
tests/unit/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
1
tests/unit/cli/__init__.py
Normal file
@@ -0,0 +1 @@
"""CLI tests package."""
BIN
tests/unit/cli/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
147
tests/unit/cli/test_main_cli.py
Normal file
@@ -0,0 +1,147 @@
from __future__ import annotations

from types import SimpleNamespace
from uuid import uuid4

import pytest

from ingest_pipeline.cli import main
from ingest_pipeline.core.models import (
    IngestionResult,
    IngestionSource,
    IngestionStatus,
    StorageBackend,
)


@pytest.mark.parametrize(
    ("collection_arg", "expected"),
    (
        (None, "docs_example_com_web"),
        ("explicit_collection", "explicit_collection"),
    ),
)
@pytest.mark.asyncio
async def test_run_ingestion_collection_resolution(monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str) -> None:
    recorded: dict[str, object] = {}

    async def fake_flow(**kwargs: object) -> IngestionResult:
        recorded.update(kwargs)
        return IngestionResult(
            job_id=uuid4(),
            status=IngestionStatus.COMPLETED,
            documents_processed=7,
            documents_failed=0,
            duration_seconds=1.5,
            error_messages=[],
        )

    monkeypatch.setattr(main, "create_ingestion_flow", fake_flow)

    result = await main.run_ingestion(
        url="https://docs.example.com/guide",
        source_type=IngestionSource.WEB,
        storage_backend=StorageBackend.WEAVIATE,
        collection_name=collection_arg,
        validate_first=False,
    )

    assert recorded["collection_name"] == expected
    assert result.documents_processed == 7


@pytest.mark.parametrize(
    ("cron_value", "serve_now", "expected_type", "serve_expected"),
    (
        ("0 * * * *", False, "cron", False),
        (None, True, "interval", True),
    ),
)
def test_schedule_creates_deployment(
    monkeypatch: pytest.MonkeyPatch,
    cron_value: str | None,
    serve_now: bool,
    expected_type: str,
    serve_expected: bool,
) -> None:
    recorded: dict[str, object] = {}
    served = {"called": False}

    def fake_create_scheduled_deployment(**kwargs: object) -> str:
        recorded.update(kwargs)
        return "deployment"

    def fake_serve_deployments(deployments: list[str]) -> None:
        served["called"] = True
        assert deployments == ["deployment"]

    monkeypatch.setattr(main, "create_scheduled_deployment", fake_create_scheduled_deployment)
    monkeypatch.setattr(main, "serve_deployments", fake_serve_deployments)

    main.schedule(
        name="nightly",
        source_url="https://example.com",
        source_type=IngestionSource.WEB,
        storage=StorageBackend.WEAVIATE,
        cron=cron_value,
        interval=30,
        serve_now=serve_now,
    )

    assert recorded["schedule_type"] == expected_type
    assert served["called"] is serve_expected


def test_serve_tui_launch(monkeypatch: pytest.MonkeyPatch) -> None:
    invoked = {"count": 0}

    def fake_dashboard() -> None:
        invoked["count"] = invoked["count"] + 1

    monkeypatch.setattr("ingest_pipeline.cli.tui.dashboard", fake_dashboard)

    main.serve(ui="tui")

    assert invoked["count"] == 1


@pytest.mark.asyncio
async def test_run_search_collects_results(monkeypatch: pytest.MonkeyPatch) -> None:
    messages: list[object] = []

    class DummyConsole:
        def print(self, message: object) -> None:
            messages.append(message)

    class WeaviateStub:
        def __init__(self, config: object) -> None:
            self.config = config
            self.initialized = False

        async def initialize(self) -> None:
            self.initialized = True

        async def search(
            self,
            query: str,
            limit: int = 10,
            threshold: float = 0.7,
            *,
            collection_name: str | None = None,
        ) -> object:
            yield SimpleNamespace(title="Title", content="Body text", score=0.91)

    dummy_settings = SimpleNamespace(
        weaviate_endpoint="http://weaviate.local",
        weaviate_api_key="token",
        openwebui_endpoint="http://chat.local",
        openwebui_api_key=None,
    )

    monkeypatch.setattr(main, "console", DummyConsole())
    monkeypatch.setattr(main, "get_settings", lambda: dummy_settings)
    monkeypatch.setattr("ingest_pipeline.storage.weaviate.WeaviateStorage", WeaviateStub)

    await main.run_search("query", collection=None, backend="weaviate", limit=1)

    assert messages[-1] == "\n✅ [green]Found 1 results[/green]"
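The first test pins down the default collection name (docs_example_com_web) that run_ingestion presumably derives from the source URL host and the source type when no explicit name is given. A minimal sketch of that derivation, written as a standalone helper for illustration (the real logic lives in ingest_pipeline.cli.main and may differ):

from urllib.parse import urlparse

def derive_collection_name(url: str, source_type: str) -> str:
    # "https://docs.example.com/guide" + "web" -> "docs_example_com_web"
    host = urlparse(url).netloc.replace(".", "_").replace("-", "_")
    return f"{host}_{source_type}"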
1
tests/unit/flows/__init__.py
Normal file
@@ -0,0 +1 @@
"""Flow tests package."""
BIN
tests/unit/flows/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
131
tests/unit/flows/test_ingestion_flow.py
Normal file
@@ -0,0 +1,131 @@
from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from ingest_pipeline.core.models import (
    IngestionJob,
    IngestionSource,
    StorageBackend,
)
from ingest_pipeline.flows import ingestion
from ingest_pipeline.flows.ingestion import (
    FirecrawlIngestor,
    _chunk_urls,
    _deduplicate_urls,
    filter_existing_documents_task,
    ingest_documents_task,
)


@pytest.mark.parametrize(
    ("urls", "size", "expected"),
    (
        (["a", "b", "c", "d"], 2, [["a", "b"], ["c", "d"]]),
        (["solo"], 3, [["solo"]]),
    ),
)
def test_chunk_urls_batches_urls(urls: list[str], size: int, expected: list[list[str]]) -> None:
    assert _chunk_urls(urls, size) == expected


def test_chunk_urls_rejects_zero() -> None:
    with pytest.raises(ValueError):
        _chunk_urls(["item"], 0)


def test_deduplicate_urls_preserves_order() -> None:
    urls = ["https://example.com", "https://example.com", "https://other.com"]

    unique = _deduplicate_urls(urls)

    assert unique == ["https://example.com", "https://other.com"]


@pytest.mark.asyncio
async def test_filter_existing_documents_task_filters_known_urls(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    existing_url = "https://keep.example.com"
    new_url = "https://new.example.com"
    existing_id = str(FirecrawlIngestor.compute_document_id(existing_url))

    class StubStorage:
        display_name = "stub-storage"

        async def check_exists(
            self,
            document_id: str,
            *,
            collection_name: str | None = None,
            stale_after_days: int,
        ) -> bool:
            return document_id == existing_id

    monkeypatch.setattr(
        ingestion,
        "get_run_logger",
        lambda: SimpleNamespace(info=lambda *args, **kwargs: None),
    )

    eligible = await filter_existing_documents_task.fn(
        [existing_url, new_url],
        StubStorage(),
        stale_after_days=30,
    )

    assert eligible == [new_url]


@pytest.mark.asyncio
async def test_ingest_documents_task_invokes_helpers(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    job = IngestionJob(
        source_url="https://docs.example.com",
        source_type=IngestionSource.WEB,
        storage_backend=StorageBackend.WEAVIATE,
    )
    ingestor_sentinel = object()
    storage_sentinel = object()
    progress_events: list[tuple[int, str]] = []

    def record_progress(percent: int, message: str) -> None:
        progress_events.append((percent, message))

    async def fake_create_storage(
        job_arg: IngestionJob,
        collection_name: str | None,
        storage_block_name: str | None = None,
    ) -> object:
        return storage_sentinel

    async def fake_create_ingestor(job_arg: IngestionJob, config_block_name: str | None = None) -> object:
        return ingestor_sentinel

    monkeypatch.setattr(ingestion, "_create_ingestor", fake_create_ingestor)
    monkeypatch.setattr(ingestion, "_create_storage", fake_create_storage)
    fake_process = AsyncMock(return_value=(5, 1))
    monkeypatch.setattr(ingestion, "_process_documents", fake_process)

    result = await ingest_documents_task.fn(
        job,
        collection_name="unit",
        progress_callback=record_progress,
    )

    assert result == (5, 1)
    assert progress_events == [
        (35, "Creating ingestor and storage clients..."),
        (40, "Starting document processing..."),
    ]
    fake_process.assert_awaited_once_with(
        ingestor_sentinel,
        storage_sentinel,
        job,
        50,
        "unit",
        record_progress,
    )
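The first three tests fix the contract of the URL helpers: batches of at most size items with a ValueError on non-positive sizes, and order-preserving deduplication. A minimal sketch consistent with those assertions (the shipped implementations in ingest_pipeline.flows.ingestion may differ in detail):

def _chunk_urls(urls: list[str], size: int) -> list[list[str]]:
    # Reject non-positive batch sizes, then slice the list into consecutive batches.
    if size <= 0:
        raise ValueError("size must be positive")
    return [urls[i : i + size] for i in range(0, len(urls), size)]

def _deduplicate_urls(urls: list[str]) -> list[str]:
    # Keep the first occurrence of each URL while preserving input order.
    seen: set[str] = set()
    return [url for url in urls if not (url in seen or seen.add(url))]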
72
tests/unit/flows/test_scheduler.py
Normal file
@@ -0,0 +1,72 @@
from __future__ import annotations

from datetime import timedelta
from types import SimpleNamespace

import pytest

from ingest_pipeline.flows import scheduler


def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> None:
    captured: dict[str, object] = {}

    class DummyFlow:
        def to_deployment(self, **kwargs: object) -> SimpleNamespace:
            captured.update(kwargs)
            return SimpleNamespace(**kwargs)

    monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())

    deployment = scheduler.create_scheduled_deployment(
        name="cron-ingestion",
        source_url="https://example.com",
        source_type="web",
        schedule_type="cron",
        cron_expression="0 * * * *",
    )

    assert captured["schedule"].cron == "0 * * * *"
    assert captured["parameters"]["source_type"] == "web"
    assert deployment.tags == ["web", "weaviate"]


def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -> None:
    captured: dict[str, object] = {}

    class DummyFlow:
        def to_deployment(self, **kwargs: object) -> SimpleNamespace:
            captured.update(kwargs)
            return SimpleNamespace(**kwargs)

    monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())

    deployment = scheduler.create_scheduled_deployment(
        name="interval-ingestion",
        source_url="https://repo.example.com",
        source_type="repository",
        storage_backend="open_webui",
        schedule_type="interval",
        interval_minutes=15,
        tags=["custom"],
    )

    assert captured["schedule"].interval == timedelta(minutes=15)
    assert captured["tags"] == ["custom"]
    assert deployment.parameters["storage_backend"] == "open_webui"


def test_serve_deployments_invokes_prefect(monkeypatch: pytest.MonkeyPatch) -> None:
    called: dict[str, object] = {}

    def fake_serve(*deployments: object, limit: int) -> None:
        called["deployments"] = deployments
        called["limit"] = limit

    monkeypatch.setattr(scheduler, "serve", fake_serve)

    deployment = SimpleNamespace(name="only")
    scheduler.serve_deployments([deployment])

    assert called["deployments"] == (deployment,)
    assert called["limit"] == 10
0
tests/unit/ingestors/__init__.py
Normal file
BIN
tests/unit/ingestors/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
64
tests/unit/ingestors/test_firecrawl_ingestor.py
Normal file
@@ -0,0 +1,64 @@
from __future__ import annotations

import pytest

from ingest_pipeline.core.models import IngestionJob, IngestionSource, StorageBackend
from ingest_pipeline.ingestors.firecrawl import FirecrawlIngestor


@pytest.mark.asyncio
async def test_firecrawl_ingest_flows(
    firecrawl_service,
    firecrawl_client_stub,
) -> None:
    base_url = "https://example.com"
    page_urls = [f"{base_url}/docs", f"{base_url}/about"]
    firecrawl_service.register_map_result(base_url, page_urls)

    for index, url in enumerate(page_urls, start=1):
        firecrawl_service.register_page(
            url,
            markdown=f"# Page {index}\nContent body",
            metadata={
                "title": f"Page {index}",
                "description": f"Description {index}",
                "statusCode": 200,
                "language": "en",
                "sourceURL": url,
            },
        )

    ingestor = FirecrawlIngestor()
    job = IngestionJob(
        source_type=IngestionSource.WEB,
        source_url=base_url,
        storage_backend=StorageBackend.WEAVIATE,
    )

    documents = [document async for document in ingestor.ingest(job)]

    assert {doc.metadata["title"] for doc in documents} == {"Page 1", "Page 2"}
    assert all(doc.content.startswith("# Page") for doc in documents)


@pytest.mark.asyncio
async def test_firecrawl_validate_source(
    firecrawl_service,
    firecrawl_client_stub,
) -> None:
    target_url = "https://validate.example.com"
    firecrawl_service.register_page(
        target_url,
        markdown="# Validation Page\n",
        metadata={
            "title": "Validation Page",
            "statusCode": 200,
            "language": "en",
            "sourceURL": target_url,
        },
    )

    ingestor = FirecrawlIngestor()

    assert await ingestor.validate_source(target_url) is True
    assert await ingestor.estimate_size(target_url) == 1
85
tests/unit/ingestors/test_repomix_ingestor.py
Normal file
@@ -0,0 +1,85 @@
from __future__ import annotations

import pytest

from ingest_pipeline.core.models import IngestionJob, IngestionSource, StorageBackend
from ingest_pipeline.ingestors.repomix import RepomixIngestor


@pytest.mark.parametrize(
    ("content", "expected_keys"),
    (
        ("## File: src/app.py\nprint(\"hi\")", ["src/app.py"]),
        ("plain content without markers", ["repository"]),
    ),
)
def test_split_by_files_detects_file_markers(content: str, expected_keys: list[str]) -> None:
    ingestor = RepomixIngestor()

    results = ingestor._split_by_files(content)

    assert list(results) == expected_keys


@pytest.mark.parametrize(
    ("content", "chunk_size", "expected"),
    (
        ("line-one\nline-two\nline-three", 9, ["line-one", "line-two", "line-three"]),
        ("single-line", 50, ["single-line"]),
    ),
)
def test_chunk_content_respects_max_size(
    content: str,
    chunk_size: int,
    expected: list[str],
) -> None:
    ingestor = RepomixIngestor()

    chunks = ingestor._chunk_content(content, chunk_size=chunk_size)

    assert chunks == expected


@pytest.mark.parametrize(
    ("file_path", "content", "expected"),
    (
        ("src/app.py", "def feature():\n return True", "python"),
        ("scripts/run", "#!/usr/bin/env python\nprint(\"ok\")", "python"),
        ("documentation.md", "# Title", "markdown"),
        ("unknown.ext", "text", None),
    ),
)
def test_detect_programming_language_infers_extension(
    file_path: str,
    content: str,
    expected: str | None,
) -> None:
    ingestor = RepomixIngestor()

    detected = ingestor._detect_programming_language(file_path, content)

    assert detected == expected


def test_create_document_enriches_metadata() -> None:
    ingestor = RepomixIngestor()
    job = IngestionJob(
        source_url="https://example.com/repo.git",
        source_type=IngestionSource.REPOSITORY,
        storage_backend=StorageBackend.WEAVIATE,
    )

    document = ingestor._create_document(
        "src/module.py",
        "def alpha():\n return 42\n",
        job,
        chunk_index=1,
        git_metadata={"branch_name": "main", "commit_hash": "deadbeef"},
        repo_info={"repository_name": "demo"},
    )

    assert document.metadata["repository_name"] == "demo"
    assert document.metadata["branch_name"] == "main"
    assert document.metadata["commit_hash"] == "deadbeef"
    assert document.metadata["title"].endswith("(chunk 1)")
    assert document.collection == job.storage_backend.value
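The language-detection cases above cover both file-extension and shebang inference. A minimal standalone sketch matching just those expectations (the real RepomixIngestor._detect_programming_language almost certainly supports many more extensions):

def detect_programming_language(file_path: str, content: str) -> str | None:
    # Hypothetical stand-in for RepomixIngestor._detect_programming_language.
    extension_map = {".py": "python", ".md": "markdown"}
    for extension, language in extension_map.items():
        if file_path.endswith(extension):
            return language
    first_line = content.splitlines()[0] if content else ""
    if first_line.startswith("#!") and "python" in first_line:
        return "python"
    return None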
0
tests/unit/storage/__init__.py
Normal file
BIN
tests/unit/storage/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
102
tests/unit/storage/test_base_storage.py
Normal file
@@ -0,0 +1,102 @@
from __future__ import annotations

from datetime import UTC, datetime, timedelta

import pytest

from ingest_pipeline.core.models import Document, StorageConfig
from ingest_pipeline.storage.base import BaseStorage


class StubStorage(BaseStorage):
    """Concrete BaseStorage for exercising shared logic."""

    def __init__(
        self,
        config: StorageConfig,
        *,
        result: Document | None = None,
        error: BaseException | None = None,
    ) -> None:
        super().__init__(config)
        self._result = result
        self._error = error

    async def initialize(self) -> None:  # pragma: no cover - not used in tests
        return None

    async def store(
        self,
        document: Document,
        *,
        collection_name: str | None = None,
    ) -> str:  # pragma: no cover
        return ""

    async def store_batch(
        self,
        documents: list[Document],
        *,
        collection_name: str | None = None,
    ) -> list[str]:  # pragma: no cover
        return []

    async def delete(
        self,
        document_id: str,
        *,
        collection_name: str | None = None,
    ) -> bool:  # pragma: no cover
        return False

    async def retrieve(
        self,
        document_id: str,
        *,
        collection_name: str | None = None,
    ) -> Document | None:
        if self._error is not None:
            raise self._error
        return self._result


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "delta_days",
    [pytest.param(0, id="fresh"), pytest.param(40, id="stale")],
)
async def test_check_exists_uses_staleness(document_factory, storage_config, delta_days) -> None:
    """Return True only when document timestamp is within freshness window."""

    timestamp = datetime.now(UTC) - timedelta(days=delta_days)
    document = document_factory(content=f"doc-{delta_days}", metadata_updates={"timestamp": timestamp})
    storage = StubStorage(storage_config, result=document)

    outcome = await storage.check_exists("identifier", stale_after_days=30)

    assert outcome is (delta_days <= 30)


@pytest.mark.asyncio
async def test_check_exists_returns_false_for_missing(storage_config) -> None:
    """Return False when retrieve yields no document."""

    storage = StubStorage(storage_config, result=None)

    assert await storage.check_exists("missing") is False


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_factory",
    [
        pytest.param(lambda: NotImplementedError("unsupported"), id="not-implemented"),
        pytest.param(lambda: RuntimeError("unexpected"), id="generic-error"),
    ],
)
async def test_check_exists_swallows_errors(storage_config, error_factory) -> None:
    """Return False when retrieval raises errors."""

    storage = StubStorage(storage_config, error=error_factory())

    assert await storage.check_exists("identifier") is False
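These tests pin down the shared BaseStorage.check_exists behaviour: missing documents and retrieval errors both report False, and a stored document only counts as existing while its timestamp metadata is within stale_after_days. A minimal method-body sketch of that logic, assuming the timestamp lives in the document metadata and a default of 30 days (the actual method in ingest_pipeline.storage.base may differ):

from datetime import UTC, datetime, timedelta

async def check_exists(self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30) -> bool:
    # Sketch only: swallow lookup errors and treat stale documents as missing.
    try:
        document = await self.retrieve(document_id, collection_name=collection_name)
    except Exception:
        return False
    if document is None:
        return False
    metadata = document.metadata
    timestamp = metadata.get("timestamp") if isinstance(metadata, dict) else getattr(metadata, "timestamp", None)
    if timestamp is None:
        return True
    return datetime.now(UTC) - timestamp <= timedelta(days=stale_after_days)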
136
tests/unit/storage/test_openwebui.py
Normal file
@@ -0,0 +1,136 @@
from __future__ import annotations

from typing import Any

import pytest

from ingest_pipeline.core.models import StorageConfig
from ingest_pipeline.storage.openwebui import OpenWebUIStorage


def _make_storage(config: StorageConfig) -> OpenWebUIStorage:
    return OpenWebUIStorage(config)


@pytest.mark.asyncio
async def test_get_knowledge_id_returns_existing(
    storage_config: StorageConfig,
    openwebui_service,
    httpx_stub,
) -> None:
    """Return cached identifier when knowledge base already exists."""

    openwebui_service.ensure_knowledge(
        name=storage_config.collection_name,
        knowledge_id="kb-123",
    )
    storage = _make_storage(storage_config)

    knowledge_id = await storage._get_knowledge_id(None, create=False)

    assert knowledge_id == "kb-123"
    urls = [request["url"] for request in httpx_stub.requests]
    assert "http://storage.local/api/v1/knowledge/list" in urls
    await storage.client.aclose()


@pytest.mark.asyncio
async def test_get_knowledge_id_creates_when_missing(
    storage_config: StorageConfig,
    openwebui_service,
    httpx_stub,
) -> None:
    """Create the knowledge base when missing and cache the result."""

    derived_config = storage_config.model_copy(update={"collection_name": "custom"})
    storage = _make_storage(derived_config)

    knowledge_id = await storage._get_knowledge_id("custom", create=True)

    assert knowledge_id is not None
    urls = [request["url"] for request in httpx_stub.requests]
    assert "http://storage.local/api/v1/knowledge/list" in urls
    assert any(
        url.startswith("http://storage.local/api/v1/knowledge/") and url.endswith("/create")
        for url in urls
    )
    await storage.client.aclose()


@pytest.mark.asyncio
async def test_store_uploads_and_attaches_document(
    document_factory,
    storage_config: StorageConfig,
    openwebui_service,
    httpx_stub,
) -> None:
    """Upload file then attach it to the knowledge base."""

    storage = _make_storage(storage_config)
    document = document_factory(content="hello world", metadata_updates={"title": "example"})

    file_id = await storage.store(document)

    assert file_id is not None
    urls = [request["url"] for request in httpx_stub.requests]
    assert any(url.endswith("/api/v1/files/") for url in urls)
    assert any("/file/add" in url for url in urls)
    knowledge_entry = openwebui_service.find_knowledge_by_name(storage_config.collection_name)
    assert knowledge_entry is not None
    _, knowledge = knowledge_entry
    assert len(knowledge.get("files", [])) == 1
    assert knowledge["files"][0]["id"] == file_id
    await storage.client.aclose()


@pytest.mark.asyncio
async def test_store_batch_handles_multiple_documents(
    document_factory,
    storage_config: StorageConfig,
    openwebui_service,
    httpx_stub,
) -> None:
    """Store documents in batch and return collected file identifiers."""

    storage = _make_storage(storage_config)
    first = document_factory(content="alpha", metadata_updates={"title": "alpha"})
    second = document_factory(content="beta", metadata_updates={"title": "beta"})

    file_ids = await storage.store_batch([first, second])

    assert len(file_ids) == 2
    files_payloads: list[Any] = [request["files"] for request in httpx_stub.requests if request["method"] == "POST"]
    assert any(payload is not None for payload in files_payloads)
    knowledge_entry = openwebui_service.find_knowledge_by_name(storage_config.collection_name)
    assert knowledge_entry is not None
    _, knowledge = knowledge_entry
    assert {meta["id"] for meta in knowledge.get("files", [])} == set(file_ids)
    await storage.client.aclose()


@pytest.mark.asyncio
async def test_delete_removes_file(
    storage_config: StorageConfig,
    openwebui_service,
    httpx_stub,
) -> None:
    """Remove file from knowledge base and delete the uploaded resource."""

    openwebui_service.ensure_knowledge(
        name=storage_config.collection_name,
        knowledge_id="kb-55",
    )
    openwebui_service.create_file(filename="to-delete.txt", file_id="file-xyz")
    openwebui_service.attach_existing_file("kb-55", "file-xyz")
    storage = _make_storage(storage_config)

    result = await storage.delete("file-xyz")

    assert result is True
    urls = [request["url"] for request in httpx_stub.requests]
    assert "http://storage.local/api/v1/knowledge/kb-55/file/remove" in urls
    assert "http://storage.local/api/v1/files/file-xyz" in urls
    knowledge = openwebui_service.get_knowledge("kb-55")
    assert knowledge is not None
    assert knowledge.get("files", []) == []
    await storage.client.aclose()
343
tests/unit/storage/test_r2r_helpers.py
Normal file
@@ -0,0 +1,343 @@
from __future__ import annotations

from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, Self

import pytest

from ingest_pipeline.core.models import StorageConfig
from ingest_pipeline.storage.r2r.storage import (
    R2RStorage,
    _as_datetime,
    _as_int,
    _as_mapping,
    _as_sequence,
    _extract_id,
)


@pytest.fixture
def r2r_client_stub(
    monkeypatch: pytest.MonkeyPatch,
    r2r_service,
) -> object:
    class DummyR2RException(Exception):
        def __init__(self, message: str, status_code: int | None = None) -> None:
            super().__init__(message)
            self.status_code = status_code

    class MockResponse:
        def __init__(self, json_data: dict[str, Any], status_code: int = 200) -> None:
            self._json_data = json_data
            self.status_code = status_code

        def json(self) -> dict[str, Any]:
            return self._json_data

        def raise_for_status(self) -> None:
            if self.status_code >= 400:
                from httpx import HTTPStatusError

                raise HTTPStatusError("HTTP error", request=None, response=self)

    class MockAsyncClient:
        def __init__(self, service: Any) -> None:
            self._service = service

        async def get(self, url: str) -> MockResponse:
            if "/v3/collections" in url:
                # Return existing collections
                collections = []
                for collection_id, collection_data in self._service._collections.items():
                    collections.append({
                        "id": collection_id,
                        "name": collection_data["name"],
                        "description": collection_data.get("description", ""),
                    })
                return MockResponse({"results": collections})
            return MockResponse({})

        async def post(self, url: str, *, json: dict[str, Any] | None = None, files: dict[str, Any] | None = None) -> MockResponse:
            if "/v3/collections" in url and json:
                # Create new collection
                new_collection_id = f"col-{len(self._service._collections) + 1}"
                self._service.create_collection(
                    name=json["name"],
                    collection_id=new_collection_id,
                    description=json.get("description", ""),
                )
                return MockResponse({
                    "results": {
                        "id": new_collection_id,
                        "name": json["name"],
                        "description": json.get("description", ""),
                    }
                })
            elif "/v3/documents" in url and files:
                # Create document
                import json as json_lib

                document_id = files.get("id", (None, f"doc-{len(self._service._documents) + 1}"))[1]
                content = files.get("raw_text", (None, ""))[1]
                metadata_str = files.get("metadata", (None, "{}"))[1]
                metadata = json_lib.loads(metadata_str) if metadata_str else {}

                # Store document in mock service
                document_data = {
                    "id": document_id,
                    "content": content,
                    "metadata": metadata,
                }
                self._service._documents[document_id] = document_data

                # Update collection document count if specified
                collection_ids = files.get("collection_ids")
                if collection_ids:
                    # collection_ids is passed as a tuple (None, "[collection_id]") in the files format
                    collection_ids_str = collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids
                    try:
                        collection_list = json_lib.loads(collection_ids_str) if isinstance(collection_ids_str, str) else collection_ids_str
                        if isinstance(collection_list, list) and len(collection_list) > 0:
                            # Extract the collection ID - it could be a string or dict
                            first_collection = collection_list[0]
                            if isinstance(first_collection, dict) and "id" in first_collection:
                                collection_id = first_collection["id"]
                            elif isinstance(first_collection, str):
                                collection_id = first_collection
                            else:
                                collection_id = None

                            if collection_id and collection_id in self._service._collections:
                                # Update the collection's document count to match the actual number of documents
                                total_docs = len(self._service._documents)
                                self._service._collections[collection_id]["document_count"] = total_docs
                    except (json_lib.JSONDecodeError, TypeError, KeyError):
                        pass  # Ignore parsing errors

                return MockResponse({
                    "results": {
                        "document_id": document_id,
                        "message": "Document created successfully",
                    }
                })
            return MockResponse({})

        async def aclose(self) -> None:
            return None

        async def __aenter__(self) -> Self:
            return self

        async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            return None

    class DocumentsAPI:
        def __init__(self, service: Any) -> None:
            self._service = service

        async def retrieve(self, document_id: str) -> dict[str, Any]:
            document = self._service.get_document(document_id)
            if document is None:
                raise DummyR2RException("Not found", status_code=404)
            return {"results": document}

        async def delete(self, document_id: str) -> dict[str, Any]:
            if not self._service.delete_document(document_id):
                raise DummyR2RException("Not found", status_code=404)
            return {"results": {"success": True}}

        async def append_metadata(self, id: str, metadata: list[dict[str, Any]]) -> dict[str, Any]:
            document = self._service.append_document_metadata(id, metadata)
            if document is None:
                raise DummyR2RException("Not found", status_code=404)
            return {"results": document}

    class RetrievalAPI:
        def __init__(self, service: Any) -> None:
            self._service = service

        async def search(self, query: str, search_settings: Mapping[str, Any]) -> dict[str, Any]:
            results = [
                {"document_id": doc_id, "score": 1.0}
                for doc_id in self._service._documents
            ]
            return {"results": results}

    class DummyClient:
        def __init__(self, service: Any) -> None:
            self.documents = DocumentsAPI(service)
            self.retrieval = RetrievalAPI(service)

        async def aclose(self) -> None:
            return None

        async def close(self) -> None:
            return None

    # Mock the AsyncClient that R2RStorage uses internally
    mock_async_client = MockAsyncClient(r2r_service)
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.AsyncClient",
        lambda: mock_async_client,
    )

    client = DummyClient(r2r_service)
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.R2RAsyncClient",
        lambda endpoint: client,
    )
    monkeypatch.setattr(
        "ingest_pipeline.storage.r2r.storage.R2RException",
        DummyR2RException,
    )
    return client


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        pytest.param({"a": 1}, {"a": 1}, id="mapping"),
        pytest.param(SimpleNamespace(a=2), {"a": 2}, id="namespace"),
        pytest.param(5, {}, id="other"),
    ],
)
def test_as_mapping_normalizes(value, expected) -> None:
    """Convert inputs into dictionaries where possible."""

    assert _as_mapping(value) == expected


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        pytest.param([1, 2], (1, 2), id="list"),
        pytest.param((3, 4), (3, 4), id="tuple"),
        pytest.param("ab", ("a", "b"), id="string"),
        pytest.param(7, (), id="non-iterable"),
    ],
)
def test_as_sequence_coerces_iterables(value, expected) -> None:
    """Represent values as tuples for downstream iteration."""

    assert _as_sequence(value) == expected


@pytest.mark.parametrize(
    ("source", "fallback", "expected"),
    [
        pytest.param({"id": "abc"}, "x", "abc", id="mapping"),
        pytest.param(SimpleNamespace(id=123), "x", "123", id="attribute"),
        pytest.param({}, "fallback", "fallback", id="fallback"),
    ],
)
def test_extract_id_falls_back(source, fallback, expected) -> None:
    """Prefer embedded identifier values and fall back otherwise."""

    assert _extract_id(source, fallback) == expected


@pytest.mark.parametrize(
    ("value", "expected_year"),
    [
        pytest.param(datetime(2024, 1, 1, tzinfo=UTC), 2024, id="datetime"),
        pytest.param("2024-02-01T00:00:00+00:00", 2024, id="iso"),
        pytest.param("invalid", datetime.now(UTC).year, id="fallback"),
    ],
)
def test_as_datetime_recognizes_formats(value, expected_year) -> None:
    """Produce timezone-aware datetime objects."""

    assert _as_datetime(value).year == expected_year


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        pytest.param(True, 1, id="bool"),
        pytest.param(5, 5, id="int"),
        pytest.param(3.9, 3, id="float"),
        pytest.param("7", 7, id="string"),
        pytest.param("8.2", 8, id="float-string"),
        pytest.param("bad", 2, id="default"),
    ],
)
def test_as_int_handles_numeric_coercions(value, expected) -> None:
    """Convert assorted numeric representations."""

    assert _as_int(value, default=2) == expected


@pytest.mark.asyncio
async def test_ensure_collection_finds_existing(
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Return collection identifier when already present."""

    r2r_service.create_collection(
        name=r2r_storage_config.collection_name,
        collection_id="col-1",
    )
    storage = R2RStorage(r2r_storage_config)

    collection_id = await storage._ensure_collection(r2r_storage_config.collection_name)

    assert collection_id == "col-1"
    assert storage.default_collection_id == "col-1"
    await storage.client.aclose()


@pytest.mark.asyncio
async def test_ensure_collection_creates_when_missing(
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Create collection via POST when absent."""

    storage = R2RStorage(r2r_storage_config)

    collection_id = await storage._ensure_collection("alternate")

    assert collection_id is not None
    located = r2r_service.find_collection_by_name("alternate")
    assert located is not None
    identifier, _ = located
    assert identifier == collection_id
    await storage.client.aclose()


@pytest.mark.asyncio
async def test_store_batch_creates_documents(
    document_factory,
    r2r_storage_config: StorageConfig,
    r2r_service,
    httpx_stub,
    r2r_client_stub,
) -> None:
    """Store documents and persist them via the R2R mock service."""

    storage = R2RStorage(r2r_storage_config)
    documents = [
        document_factory(content="first document", metadata_updates={"title": "First"}),
        document_factory(content="second document", metadata_updates={"title": "Second"}),
    ]

    stored_ids = await storage.store_batch(documents)

    assert len(stored_ids) == 2
    for doc_id, original in zip(stored_ids, documents, strict=False):
        stored = r2r_service.get_document(doc_id)
        assert stored is not None
        assert stored["metadata"]["source_url"] == original.metadata["source_url"]

    collection = r2r_service.find_collection_by_name(r2r_storage_config.collection_name)
    assert collection is not None
    _, collection_payload = collection
    assert collection_payload["document_count"] == 2

    await storage.client.aclose()
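The _as_int cases above (True -> 1, 3.9 -> 3, "8.2" -> 8, "bad" -> default) imply integer truncation with numeric strings parsed through float. A minimal sketch consistent with those expectations (the shipped helper in ingest_pipeline.storage.r2r.storage may differ):

def _as_int(value: object, *, default: int = 0) -> int:
    # Sketch: accept bools/ints/floats directly, parse numeric strings, otherwise fall back to default.
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, (int, float)):
        return int(value)
    if isinstance(value, str):
        try:
            return int(float(value))
        except ValueError:
            return default
    return default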
140
tests/unit/storage/test_weaviate_helpers.py
Normal file
@@ -0,0 +1,140 @@
from __future__ import annotations

from datetime import UTC, datetime
from typing import cast

import pytest

from ingest_pipeline.core.exceptions import StorageError
from ingest_pipeline.core.models import IngestionSource, StorageConfig
from ingest_pipeline.storage.weaviate import WeaviateStorage
from ingest_pipeline.utils.vectorizer import Vectorizer


def _build_storage(config: StorageConfig) -> WeaviateStorage:
    storage = WeaviateStorage.__new__(WeaviateStorage)
    storage.config = config
    storage.client = None
    storage.vectorizer = cast(Vectorizer, None)
    storage._default_collection = config.collection_name
    return storage


@pytest.mark.parametrize(
    "raw, expected",
    [
        pytest.param([1, 2, 3], [1.0, 2.0, 3.0], id="list"),
        pytest.param({"default": [4, 5]}, [4.0, 5.0], id="mapping"),
        pytest.param([[6, 7]], [6.0, 7.0], id="nested"),
        pytest.param("invalid", None, id="invalid"),
    ],
)
def test_extract_vector_normalizes(raw, expected) -> None:
    """Normalize various vector payload structures."""

    assert WeaviateStorage._extract_vector(raw) == expected


@pytest.mark.parametrize(
    "value, expected",
    [
        pytest.param(IngestionSource.WEB, IngestionSource.WEB, id="enum"),
        pytest.param("documentation", IngestionSource.DOCUMENTATION, id="string"),
        pytest.param("unknown", IngestionSource.WEB, id="fallback"),
        pytest.param(42, IngestionSource.WEB, id="non-string"),
    ],
)
def test_parse_source_normalizes(value, expected) -> None:
    """Ensure ingestion source strings coerce into enums."""

    assert WeaviateStorage._parse_source(value) is expected


def test_coerce_properties_returns_mapping() -> None:
    """Pass through mapping when provided."""

    payload = {"key": "value"}

    assert WeaviateStorage._coerce_properties(payload, context="test") is payload


def test_coerce_properties_allows_missing() -> None:
    """Return None when allow_missing is True and payload absent."""

    assert WeaviateStorage._coerce_properties(None, context="test", allow_missing=True) is None


@pytest.mark.parametrize(
    "payload",
    [pytest.param(None, id="none"), pytest.param([1, 2, 3], id="sequence")],
)
def test_coerce_properties_raises_on_invalid(payload) -> None:
    """Raise storage error when payload is not a mapping."""

    with pytest.raises(StorageError):
        WeaviateStorage._coerce_properties(payload, context="invalid")


@pytest.mark.parametrize(
    "input_name, expected",
    [
        pytest.param(None, "Documents", id="default"),
        pytest.param(" custom ", "Custom", id="trimmed"),
    ],
)
def test_normalize_collection_name(storage_config, input_name, expected) -> None:
    """Normalize with capitalization and fall back to the config value."""

    storage = _build_storage(storage_config)

    assert storage._normalize_collection_name(input_name) == expected


def test_normalize_collection_name_rejects_empty(storage_config) -> None:
    """Raise when the provided name is blank."""

    storage = _build_storage(storage_config)

    with pytest.raises(StorageError):
        storage._normalize_collection_name(" ")


@pytest.mark.parametrize(
    "value, expected",
    [
        pytest.param(5, 5, id="int"),
        pytest.param(7.9, 7, id="float"),
        pytest.param("12", 12, id="string"),
        pytest.param(None, 0, id="none"),
    ],
)
def test_safe_convert_count(storage_config, value, expected) -> None:
    """Convert assorted count representations into integers."""

    storage = _build_storage(storage_config)

    assert storage._safe_convert_count(value) == expected


def test_build_document_metadata(storage_config) -> None:
    """Build metadata mapping with type coercions."""

    storage = _build_storage(storage_config)
    props = {
        "source_url": "https://example.com",
        "title": "Doc",
        "description": "Details",
        "timestamp": "2024-01-01T00:00:00+00:00",
        "content_type": "text/plain",
        "word_count": "5",
        "char_count": 128.0,
    }

    metadata = storage._build_document_metadata(props)

    assert metadata["source_url"] == "https://example.com"
    assert metadata["title"] == "Doc"
    assert metadata["description"] == "Details"
    assert metadata["timestamp"] == datetime(2024, 1, 1, tzinfo=UTC)
    assert metadata["word_count"] == 5
    assert metadata["char_count"] == 128
Some files were not shown because too many files have changed in this diff