This file is a merged representation of a subset of the codebase, containing specifically included files, combined into a single document by Repomix.

<file_summary>
This section contains a summary of this file.

<purpose>
This file contains a packed representation of a subset of the repository's contents that is considered the most important context. It is designed to be easily consumable by AI systems for analysis, code review, or other automated processes.
</purpose>

<file_format>
The content is organized as follows:
1. This summary section
2. Repository information
3. Directory structure
4. Repository files (if enabled)
5. Multiple file entries, each consisting of:
  - File path as an attribute
  - Full contents of the file
</file_format>

<usage_guidelines>
- This file should be treated as read-only. Any changes should be made to the original repository files, not this packed version.
- When processing this file, use the file path to distinguish between different files in the repository.
- Be aware that this file may contain sensitive information. Handle it with the same level of security as you would the original repository.
</usage_guidelines>

<notes>
- Some files may have been excluded based on .gitignore rules and Repomix's configuration
- Binary files are not included in this packed representation. Please refer to the Repository Structure section for a complete list of file paths, including binary files
- Only files matching these patterns are included: ingest_pipeline/
- Files matching patterns in .gitignore are excluded
- Files matching default ignore patterns are excluded
- Files are sorted by Git change count (files with more changes are at the bottom)
</notes>
</file_summary>

<directory_structure>
ingest_pipeline/
  automations/
    __init__.py
  cli/
    tui/
      screens/
        __init__.py
        base.py
        dashboard.py
        dialogs.py
        documents.py
        help.py
        ingestion.py
        search.py
      utils/
        __init__.py
        runners.py
        storage_manager.py
      widgets/
        __init__.py
        cards.py
        firecrawl_config.py
        indicators.py
        r2r_widgets.py
        tables.py
      __init__.py
      app.py
      layouts.py
      models.py
      styles.py
    __init__.py
    main.py
  config/
    __init__.py
    settings.py
  core/
    __init__.py
    exceptions.py
    models.py
  flows/
    __init__.py
    ingestion.py
    scheduler.py
  ingestors/
    __init__.py
    base.py
    firecrawl.py
  storage/
    r2r/
      __init__.py
      collections.py
      storage.py
    __init__.py
    base.py
    openwebui.py
    types.py
    weaviate.py
  utils/
    __init__.py
    async_helpers.py
    metadata_tagger.py
    vectorizer.py
  __main__.py
</directory_structure>

<files>
This section contains the contents of the repository's files.

<file path="ingest_pipeline/utils/async_helpers.py">
|
|
"""Async utilities for task management and backpressure control."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
|
|
from contextlib import asynccontextmanager
|
|
from typing import Final, TypeVar
|
|
|
|
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
class AsyncTaskManager:
|
|
"""Manages concurrent tasks with backpressure control."""
|
|
|
|
def __init__(self, max_concurrent: int = 10):
|
|
"""
|
|
Initialize task manager.
|
|
|
|
Args:
|
|
max_concurrent: Maximum number of concurrent tasks
|
|
"""
|
|
self.semaphore = asyncio.Semaphore(max_concurrent)
|
|
self.max_concurrent = max_concurrent
|
|
|
|
@asynccontextmanager
|
|
async def acquire(self) -> AsyncGenerator[None, None]:
|
|
"""Acquire a slot for task execution."""
|
|
async with self.semaphore:
|
|
yield
|
|
|
|
async def run_tasks(
|
|
self, tasks: Iterable[Awaitable[T]], return_exceptions: bool = False
|
|
) -> list[T | BaseException]:
|
|
"""
|
|
Run multiple tasks with backpressure control.
|
|
|
|
Args:
|
|
tasks: Iterable of awaitable tasks
|
|
return_exceptions: Whether to return exceptions or raise them
|
|
|
|
Returns:
|
|
List of task results or exceptions
|
|
"""
|
|
|
|
async def _controlled_task(task: Awaitable[T]) -> T:
|
|
async with self.acquire():
|
|
return await task
|
|
|
|
controlled_tasks = [_controlled_task(task) for task in tasks]
|
|
|
|
if return_exceptions:
|
|
results = await asyncio.gather(*controlled_tasks, return_exceptions=True)
|
|
return list(results)
|
|
else:
|
|
results = await asyncio.gather(*controlled_tasks)
|
|
return list(results)
|
|
|
|
async def map_async(
|
|
self, func: Callable[[T], Awaitable[T]], items: Iterable[T], return_exceptions: bool = False
|
|
) -> list[T | BaseException]:
|
|
"""
|
|
Apply async function to items with backpressure control.
|
|
|
|
Args:
|
|
func: Async function to apply
|
|
items: Items to process
|
|
return_exceptions: Whether to return exceptions or raise them
|
|
|
|
Returns:
|
|
List of processed results or exceptions
|
|
"""
|
|
tasks = [func(item) for item in items]
|
|
return await self.run_tasks(tasks, return_exceptions=return_exceptions)
|
|
|
|
|
|
async def run_with_semaphore(semaphore: asyncio.Semaphore, coro: Awaitable[T]) -> T:
|
|
"""Run coroutine with semaphore-controlled concurrency."""
|
|
async with semaphore:
|
|
return await coro
|
|
|
|
|
|
async def batch_process(
|
|
items: list[T],
|
|
processor: Callable[[T], Awaitable[T]],
|
|
batch_size: int = 50,
|
|
max_concurrent: int = 5,
|
|
) -> list[T]:
|
|
"""
|
|
Process items in batches with controlled concurrency.
|
|
|
|
Args:
|
|
items: Items to process
|
|
processor: Async function to process each item
|
|
batch_size: Number of items per batch
|
|
max_concurrent: Maximum concurrent tasks per batch
|
|
|
|
Returns:
|
|
List of processed results
|
|
"""
|
|
task_manager = AsyncTaskManager(max_concurrent)
|
|
results: list[T] = []
|
|
|
|
for i in range(0, len(items), batch_size):
|
|
batch = items[i : i + batch_size]
|
|
LOGGER.debug(
|
|
"Processing batch %d-%d of %d items", i, min(i + batch_size, len(items)), len(items)
|
|
)
|
|
|
|
batch_results = await task_manager.map_async(processor, batch, return_exceptions=False)
|
|
# If return_exceptions=False, exceptions would have been raised, so all results are successful
|
|
# Type checker doesn't know this, so we need to cast
|
|
successful_results: list[T] = [r for r in batch_results if not isinstance(r, BaseException)]
|
|
results.extend(successful_results)
|
|
|
|
return results
|
|
</file>
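
Usage sketch for the helpers above (illustrative only; the `fetch_page` coroutine and URL list are hypothetical, not repository files):

import asyncio

from ingest_pipeline.utils.async_helpers import AsyncTaskManager, batch_process


async def fetch_page(url: str) -> str:
    """Hypothetical I/O-bound worker standing in for a real ingestion step."""
    await asyncio.sleep(0.1)
    return url.upper()


async def main() -> None:
    urls = [f"https://example.com/page/{i}" for i in range(200)]

    # batch_process slices the list into batches of 50 and runs at most
    # five coroutines concurrently inside each batch.
    results = await batch_process(urls, fetch_page, batch_size=50, max_concurrent=5)

    # AsyncTaskManager can also be used directly when batching is not needed.
    manager = AsyncTaskManager(max_concurrent=10)
    mixed = await manager.map_async(fetch_page, urls[:10], return_exceptions=True)
    print(len(results), len(mixed))


if __name__ == "__main__":
    asyncio.run(main())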

<file path="ingest_pipeline/automations/__init__.py">
"""Prefect Automations for ingestion pipeline monitoring and management."""

# Automation configurations as YAML-ready dictionaries
AUTOMATION_TEMPLATES = {
    "cancel_long_running": """
name: Cancel Long Running Ingestion Flows
description: Cancels ingestion flows running longer than 30 minutes
trigger:
  type: event
  posture: Proactive
  expect: [prefect.flow-run.Running]
  match_related:
    prefect.resource.role: flow
    prefect.resource.name: ingestion_pipeline
  threshold: 1
  within: 1800
actions:
  - type: cancel-flow-run
    source: inferred
enabled: true
""",
    "retry_failed": """
name: Retry Failed Ingestion Flows
description: Retries failed ingestion flows with original parameters
trigger:
  type: event
  posture: Reactive
  expect: [prefect.flow-run.Failed]
  match_related:
    prefect.resource.role: flow
    prefect.resource.name: ingestion_pipeline
  threshold: 1
  within: 0
actions:
  - type: run-deployment
    source: inferred
    parameters:
      validate_first: false
enabled: true
""",
    "resource_monitoring": """
name: Manage Work Pool Based on Resources
description: Pauses work pool when system resources are constrained
trigger:
  type: event
  posture: Reactive
  expect: [system.resource.high_usage]
  threshold: 1
  within: 120
actions:
  - type: pause-work-pool
    work_pool_name: default
enabled: true
""",
}


def get_automation_yaml_templates() -> dict[str, str]:
    """Get automation templates as YAML strings."""
    return AUTOMATION_TEMPLATES.copy()
</file>
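
A minimal sketch of consuming the templates above: each string parses to a plain dict that mirrors a Prefect automation. This assumes PyYAML is available, which is not shown in the packed sources:

import yaml

from ingest_pipeline.automations import get_automation_yaml_templates

for key, template in get_automation_yaml_templates().items():
    spec = yaml.safe_load(template)
    # e.g. "cancel_long_running" -> "Cancel Long Running Ingestion Flows" ['cancel-flow-run']
    print(key, "->", spec["name"], [action["type"] for action in spec["actions"]])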

<file path="ingest_pipeline/cli/tui/screens/help.py">
"""Help screen with keyboard shortcuts and usage information."""

from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, ScrollableContainer
from textual.screen import ModalScreen
from textual.widgets import Button, Markdown, Rule, Static
from typing_extensions import override


class HelpScreen(ModalScreen[None]):
    """Modern help screen with comprehensive keyboard shortcuts."""

    help_content: str

    BINDINGS = [
        Binding("escape", "app.pop_screen", "Close"),
        Binding("q", "app.pop_screen", "Close"),
        Binding("enter", "app.pop_screen", "Close"),
        Binding("f1", "app.pop_screen", "Close"),
    ]

    def __init__(self, help_content: str):
        super().__init__()
        self.help_content = help_content

    @override
    def compose(self) -> ComposeResult:
        with Container(classes="modal-container"):
            yield Static("📚 Help & Keyboard Shortcuts", classes="title")
            yield Static("Enhanced navigation and productivity features", classes="subtitle")
            yield Rule(line_style="heavy")

            with ScrollableContainer():
                yield Markdown(self.help_content)

            yield Container(
                Button("✅ Got it! (Press Escape or Enter)", id="close_btn", variant="primary"),
                classes="action_buttons center",
            )

    def on_mount(self) -> None:
        """Initialize the help screen."""
        # Focus the close button
        self.query_one("#close_btn").focus()

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Close help screen."""
        if event.button.id == "close_btn":
            self.app.pop_screen()
</file>
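
A sketch of how a host application might push HelpScreen. DemoApp and HELP_TEXT are hypothetical; only HelpScreen comes from the repository:

from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.widgets import Footer

from ingest_pipeline.cli.tui.screens.help import HelpScreen

HELP_TEXT = "## Shortcuts\n- `f1` opens this help screen\n- `escape` or `q` closes it"


class DemoApp(App[None]):
    """Hypothetical host app used only to illustrate pushing the modal."""

    BINDINGS = [Binding("f1", "show_help", "Help")]

    def compose(self) -> ComposeResult:
        yield Footer()

    def action_show_help(self) -> None:
        # HelpScreen renders the Markdown string it receives.
        self.push_screen(HelpScreen(HELP_TEXT))


if __name__ == "__main__":
    DemoApp().run()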

<file path="ingest_pipeline/cli/tui/utils/__init__.py">
"""Utility functions for the TUI."""

from .runners import dashboard, run_textual_tui

__all__ = ["dashboard", "run_textual_tui"]
</file>

<file path="ingest_pipeline/cli/tui/widgets/__init__.py">
"""Enhanced widgets with keyboard navigation support."""

from .cards import MetricsCard
from .indicators import EnhancedProgressBar, StatusIndicator
from .tables import EnhancedDataTable

__all__ = [
    "MetricsCard",
    "StatusIndicator",
    "EnhancedProgressBar",
    "EnhancedDataTable",
]
</file>

<file path="ingest_pipeline/cli/tui/widgets/cards.py">
"""Metrics card widget."""

from typing import Any

from textual.app import ComposeResult
from textual.widgets import Static
from typing_extensions import override


class MetricsCard(Static):
    """A modern metrics display card."""

    title: str
    value: str
    description: str

    def __init__(self, title: str, value: str, description: str = "", **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.title = title
        self.value = value
        self.description = description

    @override
    def compose(self) -> ComposeResult:
        yield Static(self.value, classes="metrics-value")
        yield Static(self.title, classes="metrics-label")
        if self.description:
            yield Static(self.description, classes="metrics-description")
</file>

<file path="ingest_pipeline/cli/tui/widgets/tables.py">
"""Enhanced DataTable with improved keyboard navigation."""

from typing import Any

from textual import events
from textual.binding import Binding
from textual.message import Message
from textual.widgets import DataTable


class EnhancedDataTable(DataTable[Any]):
    """DataTable with enhanced keyboard navigation and visual feedback."""

    BINDINGS = [
        Binding("up,k", "cursor_up", "Cursor Up", show=False),
        Binding("down,j", "cursor_down", "Cursor Down", show=False),
        Binding("left,h", "cursor_left", "Cursor Left", show=False),
        Binding("right,l", "cursor_right", "Cursor Right", show=False),
        Binding("home", "cursor_home", "First Row", show=False),
        Binding("end", "cursor_end", "Last Row", show=False),
        Binding("pageup", "page_up", "Page Up", show=False),
        Binding("pagedown", "page_down", "Page Down", show=False),
        Binding("enter", "select_cursor", "Select", show=False),
        Binding("space", "toggle_selection", "Toggle Selection", show=False),
        Binding("ctrl+a", "select_all", "Select All", show=False),
        Binding("ctrl+shift+a", "clear_selection", "Clear Selection", show=False),
    ]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.cursor_type = "row"  # Default to row selection
        self.zebra_stripes = True  # Enable zebra striping for better visibility
        self.show_cursor = True

    def on_key(self, event: events.Key) -> None:
        """Handle additional keyboard shortcuts."""
        if event.key == "ctrl+1":
            # Jump to first column
            self.move_cursor(column=0)
            event.prevent_default()
        elif event.key == "ctrl+9":
            # Jump to last column
            if self.columns:
                self.move_cursor(column=len(self.columns) - 1)
            event.prevent_default()
        elif event.key == "/":
            # Start quick search (to be implemented by parent)
            self.post_message(self.QuickSearch(self))
            event.prevent_default()
        elif event.key == "escape":
            # Clear selection or exit search
            # Clear selection by calling action
            self.action_clear_selection()
            event.prevent_default()
        # No else clause needed - just handle our events

    def action_cursor_home(self) -> None:
        """Move cursor to first row."""
        if self.row_count > 0:
            self.move_cursor(row=0)

    def action_cursor_end(self) -> None:
        """Move cursor to last row."""
        if self.row_count > 0:
            self.move_cursor(row=self.row_count - 1)

    def action_page_up(self) -> None:
        """Move cursor up by visible page size."""
        if self.row_count > 0:
            page_size = max(1, self.size.height // 2)  # Approximate visible rows
            new_row = max(0, self.cursor_coordinate.row - page_size)
            self.move_cursor(row=new_row)

    def action_page_down(self) -> None:
        """Move cursor down by visible page size."""
        if self.row_count > 0:
            page_size = max(1, self.size.height // 2)  # Approximate visible rows
            new_row = min(self.row_count - 1, self.cursor_coordinate.row + page_size)
            self.move_cursor(row=new_row)

    def action_toggle_selection(self) -> None:
        """Toggle selection of current row."""
        if self.row_count > 0:
            current_row = self.cursor_coordinate.row
            # This will be handled by the parent screen
            self.post_message(self.RowToggled(self, current_row))

    def action_select_all(self) -> None:
        """Select all rows."""
        # This will be handled by the parent screen
        self.post_message(self.SelectAll(self))

    def action_clear_selection(self) -> None:
        """Clear all selections."""
        # This will be handled by the parent screen
        self.post_message(self.ClearSelection(self))

    # Custom messages for enhanced functionality
    class QuickSearch(Message):
        """Posted when user wants to start a quick search."""

        def __init__(self, table: "EnhancedDataTable") -> None:
            super().__init__()
            self.table = table

    class RowToggled(Message):
        """Posted when a row selection is toggled."""

        def __init__(self, table: "EnhancedDataTable", row_index: int) -> None:
            super().__init__()
            self.table = table
            self.row_index = row_index

    class SelectAll(Message):
        """Posted when user wants to select all rows."""

        def __init__(self, table: "EnhancedDataTable") -> None:
            super().__init__()
            self.table = table

    class ClearSelection(Message):
        """Posted when user wants to clear selection."""

        def __init__(self, table: "EnhancedDataTable") -> None:
            super().__init__()
            self.table = table
</file>
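
A sketch of a parent screen consuming the custom messages posted by EnhancedDataTable. The handler names below follow Textual's default naming for messages defined on a widget class; the screen itself is hypothetical and only exists for illustration:

from textual.app import ComposeResult
from textual.screen import Screen

from ingest_pipeline.cli.tui.widgets.tables import EnhancedDataTable


class DocumentsTableScreen(Screen[None]):
    """Hypothetical screen showing how the widget's messages are handled."""

    def compose(self) -> ComposeResult:
        table = EnhancedDataTable(id="docs_table")
        table.add_columns("Title", "Words")
        yield table

    def on_enhanced_data_table_row_toggled(self, event: EnhancedDataTable.RowToggled) -> None:
        # Track selected rows here; the widget only announces the toggle.
        self.app.notify(f"Toggled row {event.row_index}")

    def on_enhanced_data_table_quick_search(self, event: EnhancedDataTable.QuickSearch) -> None:
        # Open a search input; the widget does not implement searching itself.
        self.app.notify("Quick search requested")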

<file path="ingest_pipeline/cli/tui/__init__.py">
"""Enhanced TUI package with keyboard navigation and modular architecture."""

from .app import CollectionManagementApp
from .models import CollectionInfo, DocumentInfo
from .utils import dashboard, run_textual_tui

__all__ = [
    "CollectionManagementApp",
    "CollectionInfo",
    "DocumentInfo",
    "dashboard",
    "run_textual_tui",
]
</file>

<file path="ingest_pipeline/cli/__init__.py">
"""CLI module for the ingestion pipeline."""

from .main import app

__all__ = ["app"]
</file>

<file path="ingest_pipeline/core/__init__.py">
"""Core module for ingestion pipeline."""

from .exceptions import (
    IngestionError,
    StorageError,
    VectorizationError,
)
from .models import (
    Document,
    IngestionJob,
    IngestionResult,
    IngestionSource,
    IngestionStatus,
    StorageBackend,
)

__all__ = [
    "Document",
    "IngestionJob",
    "IngestionResult",
    "IngestionSource",
    "IngestionStatus",
    "StorageBackend",
    "IngestionError",
    "StorageError",
    "VectorizationError",
]
</file>

<file path="ingest_pipeline/core/exceptions.py">
"""Custom exceptions for the ingestion pipeline."""


class IngestionError(Exception):
    """Base exception for ingestion errors."""

    pass


class StorageError(IngestionError):
    """Exception for storage-related errors."""

    pass


class VectorizationError(IngestionError):
    """Exception for vectorization errors."""

    pass


class ConfigurationError(IngestionError):
    """Exception for configuration errors."""

    pass


class SourceNotFoundError(IngestionError):
    """Exception when source cannot be found or accessed."""

    pass
</file>
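
A small sketch of the wrap-and-reraise pattern this hierarchy supports; `store_document` and the ConnectionError are illustrative stand-ins, not repository code:

from ingest_pipeline.core.exceptions import IngestionError, StorageError


def store_document(payload: dict[str, str]) -> None:
    """Hypothetical caller wrapping a low-level failure in a pipeline error."""
    try:
        raise ConnectionError("backend unreachable")  # stand-in for a real client call
    except ConnectionError as exc:
        raise StorageError(f"Failed to persist document: {exc}") from exc


try:
    store_document({"title": "demo"})
except IngestionError as exc:  # StorageError is caught via its base class
    print(type(exc).__name__, exc)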

<file path="ingest_pipeline/flows/__init__.py">
"""Prefect flows for orchestration."""

from .ingestion import create_ingestion_flow
from .scheduler import create_scheduled_deployment

__all__ = [
    "create_ingestion_flow",
    "create_scheduled_deployment",
]
</file>

<file path="ingest_pipeline/storage/r2r/__init__.py">
"""R2R storage package providing comprehensive R2R integration."""

from .storage import R2RStorage

__all__ = ["R2RStorage"]
</file>

<file path="ingest_pipeline/storage/types.py">
"""Shared types for storage adapters."""

from typing import TypedDict


class CollectionSummary(TypedDict):
    """Collection metadata for describe_collections."""

    name: str
    count: int
    size_mb: float


class DocumentInfo(TypedDict):
    """Document information for list_documents."""

    id: str
    title: str
    source_url: str
    description: str
    content_type: str
    content_preview: str
    word_count: int
    timestamp: str
</file>

<file path="ingest_pipeline/utils/__init__.py">
"""Utility modules."""

from .metadata_tagger import MetadataTagger
from .vectorizer import Vectorizer

__all__ = ["MetadataTagger", "Vectorizer"]
</file>

<file path="ingest_pipeline/__main__.py">
"""Main entry point for the ingestion pipeline."""

from .cli.main import app

if __name__ == "__main__":
    app()
</file>

<file path="ingest_pipeline/cli/tui/screens/__init__.py">
"""Screen components for the TUI application."""

from __future__ import annotations

from .dashboard import CollectionOverviewScreen
from .dialogs import ConfirmDeleteScreen, ConfirmDocumentDeleteScreen
from .documents import DocumentManagementScreen
from .help import HelpScreen
from .ingestion import IngestionScreen
from .search import SearchScreen

__all__ = [
    "CollectionOverviewScreen",
    "IngestionScreen",
    "SearchScreen",
    "DocumentManagementScreen",
    "ConfirmDeleteScreen",
    "ConfirmDocumentDeleteScreen",
    "HelpScreen",
]
</file>

<file path="ingest_pipeline/cli/tui/screens/base.py">
|
|
"""Base screen classes for common CRUD patterns."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Generic, TypeVar
|
|
|
|
from textual import work
|
|
from textual.app import ComposeResult
|
|
from textual.binding import Binding
|
|
from textual.containers import Container
|
|
from textual.screen import ModalScreen, Screen
|
|
from textual.widget import Widget
|
|
from textual.widgets import Button, DataTable, LoadingIndicator, Static
|
|
from typing_extensions import override
|
|
|
|
if TYPE_CHECKING:
|
|
from ..utils.storage_manager import StorageManager
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
class BaseScreen(Screen[object]):
|
|
"""Base screen with common functionality."""
|
|
|
|
def __init__(
|
|
self,
|
|
storage_manager: StorageManager,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
**kwargs: object,
|
|
) -> None:
|
|
"""Initialize base screen."""
|
|
super().__init__(name=name, id=id, classes=classes)
|
|
# Ignore any additional kwargs to avoid type issues
|
|
self.storage_manager = storage_manager
|
|
|
|
|
|
class CRUDScreen(BaseScreen, Generic[T]):
|
|
"""Base class for Create/Read/Update/Delete operations."""
|
|
|
|
BINDINGS = [
|
|
Binding("ctrl+n", "create_item", "New"),
|
|
Binding("ctrl+e", "edit_item", "Edit"),
|
|
Binding("ctrl+d", "delete_item", "Delete"),
|
|
Binding("f5", "refresh", "Refresh"),
|
|
Binding("escape", "app.pop_screen", "Back"),
|
|
]
|
|
|
|
def __init__(
|
|
self,
|
|
storage_manager: StorageManager,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
**kwargs: object,
|
|
) -> None:
|
|
"""Initialize CRUD screen."""
|
|
super().__init__(storage_manager, name=name, id=id, classes=classes)
|
|
self.items: list[T] = []
|
|
self.selected_item: T | None = None
|
|
self.loading = False
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose CRUD screen layout."""
|
|
yield Container(
|
|
Static(self.get_title(), classes="screen-title"),
|
|
self.create_toolbar(),
|
|
self.create_list_view(),
|
|
LoadingIndicator(id="loading"),
|
|
classes="crud-container",
|
|
)
|
|
|
|
def get_title(self) -> str:
|
|
"""Get screen title."""
|
|
return "CRUD Operations"
|
|
|
|
def create_toolbar(self) -> Container:
|
|
"""Create action toolbar."""
|
|
return Container(
|
|
Button("📝 New", id="new_btn", variant="primary"),
|
|
Button("✏️ Edit", id="edit_btn", variant="default"),
|
|
Button("🗑️ Delete", id="delete_btn", variant="error"),
|
|
Button("🔄 Refresh", id="refresh_btn", variant="default"),
|
|
classes="toolbar",
|
|
)
|
|
|
|
def create_list_view(self) -> DataTable[str]:
|
|
"""Create list view widget."""
|
|
table = DataTable[str](id="items_table")
|
|
table.add_columns(*self.get_table_columns())
|
|
return table
|
|
|
|
def get_table_columns(self) -> list[str]:
|
|
"""Get table column headers."""
|
|
raise NotImplementedError("Subclasses must implement get_table_columns")
|
|
|
|
async def load_items(self) -> list[T]:
|
|
"""Load items from storage."""
|
|
raise NotImplementedError("Subclasses must implement load_items")
|
|
|
|
def item_to_row(self, item: T) -> list[str]:
|
|
"""Convert item to table row."""
|
|
raise NotImplementedError("Subclasses must implement item_to_row")
|
|
|
|
async def create_item_dialog(self) -> T | None:
|
|
"""Show create item dialog."""
|
|
raise NotImplementedError("Subclasses must implement create_item_dialog")
|
|
|
|
async def edit_item_dialog(self, item: T) -> T | None:
|
|
"""Show edit item dialog."""
|
|
raise NotImplementedError("Subclasses must implement edit_item_dialog")
|
|
|
|
async def delete_item(self, item: T) -> bool:
|
|
"""Delete item."""
|
|
raise NotImplementedError("Subclasses must implement delete_item")
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize screen."""
|
|
self.query_one("#loading").display = False
|
|
self.refresh_items()
|
|
|
|
@work(exclusive=True)
|
|
async def refresh_items(self) -> None:
|
|
"""Refresh items list."""
|
|
self.set_loading(True)
|
|
try:
|
|
self.items = await self.load_items()
|
|
await self.update_table()
|
|
finally:
|
|
self.set_loading(False)
|
|
|
|
async def update_table(self) -> None:
|
|
"""Update table with current items."""
|
|
table = self.query_one("#items_table", DataTable)
|
|
table.clear()
|
|
|
|
for item in self.items:
|
|
row_data = self.item_to_row(item)
|
|
table.add_row(*row_data)
|
|
|
|
def set_loading(self, loading: bool) -> None:
|
|
"""Set loading state."""
|
|
self.loading = loading
|
|
loading_widget = self.query_one("#loading")
|
|
loading_widget.display = loading
|
|
|
|
def action_create_item(self) -> None:
|
|
"""Create new item."""
|
|
self.run_worker(self._create_item_worker())
|
|
|
|
def action_edit_item(self) -> None:
|
|
"""Edit selected item."""
|
|
if self.selected_item:
|
|
self.run_worker(self._edit_item_worker())
|
|
|
|
def action_delete_item(self) -> None:
|
|
"""Delete selected item."""
|
|
if self.selected_item:
|
|
self.run_worker(self._delete_item_worker())
|
|
|
|
def action_refresh(self) -> None:
|
|
"""Refresh items."""
|
|
self.refresh_items()
|
|
|
|
async def _create_item_worker(self) -> None:
|
|
"""Worker for creating items."""
|
|
item = await self.create_item_dialog()
|
|
if item:
|
|
self.refresh_items()
|
|
|
|
async def _edit_item_worker(self) -> None:
|
|
"""Worker for editing items."""
|
|
if self.selected_item:
|
|
item = await self.edit_item_dialog(self.selected_item)
|
|
if item:
|
|
self.refresh_items()
|
|
|
|
async def _delete_item_worker(self) -> None:
|
|
"""Worker for deleting items."""
|
|
if self.selected_item:
|
|
success = await self.delete_item(self.selected_item)
|
|
if success:
|
|
self.refresh_items()
|
|
|
|
|
|
class ListScreen(BaseScreen, Generic[T]):
|
|
"""Base for paginated list views."""
|
|
|
|
def __init__(
|
|
self,
|
|
storage_manager: StorageManager,
|
|
page_size: int = 20,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
**kwargs: object,
|
|
) -> None:
|
|
"""Initialize list screen."""
|
|
super().__init__(storage_manager, name=name, id=id, classes=classes)
|
|
self.page_size = page_size
|
|
self.current_page = 0
|
|
self.total_items = 0
|
|
self.items: list[T] = []
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose list screen layout."""
|
|
yield Container(
|
|
Static(self.get_title(), classes="screen-title"),
|
|
self.create_filters(),
|
|
self.create_list_view(),
|
|
self.create_pagination(),
|
|
LoadingIndicator(id="loading"),
|
|
classes="list-container",
|
|
)
|
|
|
|
def get_title(self) -> str:
|
|
"""Get screen title."""
|
|
raise NotImplementedError("Subclasses must implement get_title")
|
|
|
|
def create_filters(self) -> Container:
|
|
"""Create filter widgets."""
|
|
raise NotImplementedError("Subclasses must implement create_filters")
|
|
|
|
def create_list_view(self) -> Widget:
|
|
"""Create list view widget."""
|
|
raise NotImplementedError("Subclasses must implement create_list_view")
|
|
|
|
async def load_page(self, page: int, page_size: int) -> tuple[list[T], int]:
|
|
"""Load page of items."""
|
|
raise NotImplementedError("Subclasses must implement load_page")
|
|
|
|
def create_pagination(self) -> Container:
|
|
"""Create pagination controls."""
|
|
return Container(
|
|
Button("◀ Previous", id="prev_btn", variant="default"),
|
|
Static("Page 1 of 1", id="page_info"),
|
|
Button("Next ▶", id="next_btn", variant="default"),
|
|
classes="pagination",
|
|
)
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize screen."""
|
|
self.query_one("#loading").display = False
|
|
self.load_current_page()
|
|
|
|
@work(exclusive=True)
|
|
async def load_current_page(self) -> None:
|
|
"""Load current page."""
|
|
self.set_loading(True)
|
|
try:
|
|
self.items, self.total_items = await self.load_page(self.current_page, self.page_size)
|
|
await self.update_list_view()
|
|
self.update_pagination_info()
|
|
finally:
|
|
self.set_loading(False)
|
|
|
|
async def update_list_view(self) -> None:
|
|
"""Update list view with current items."""
|
|
raise NotImplementedError("Subclasses must implement update_list_view")
|
|
|
|
def update_pagination_info(self) -> None:
|
|
"""Update pagination information."""
|
|
total_pages = max(1, (self.total_items + self.page_size - 1) // self.page_size)
|
|
current_page_display = self.current_page + 1
|
|
|
|
page_info = self.query_one("#page_info", Static)
|
|
page_info.update(f"Page {current_page_display} of {total_pages}")
|
|
|
|
prev_btn = self.query_one("#prev_btn", Button)
|
|
next_btn = self.query_one("#next_btn", Button)
|
|
|
|
prev_btn.disabled = self.current_page == 0
|
|
next_btn.disabled = self.current_page >= total_pages - 1
|
|
|
|
def set_loading(self, loading: bool) -> None:
|
|
"""Set loading state."""
|
|
loading_widget = self.query_one("#loading")
|
|
loading_widget.display = loading
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id == "prev_btn" and self.current_page > 0:
|
|
self.current_page -= 1
|
|
self.load_current_page()
|
|
elif event.button.id == "next_btn":
|
|
total_pages = (self.total_items + self.page_size - 1) // self.page_size
|
|
if self.current_page < total_pages - 1:
|
|
self.current_page += 1
|
|
self.load_current_page()
|
|
|
|
|
|
class FormScreen(ModalScreen[T], Generic[T]):
|
|
"""Base for input forms with validation."""
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "app.pop_screen", "Cancel"),
|
|
Binding("ctrl+s", "save", "Save"),
|
|
Binding("enter", "save", "Save"),
|
|
]
|
|
|
|
def __init__(
|
|
self,
|
|
item: T | None = None,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
**kwargs: object,
|
|
) -> None:
|
|
"""Initialize form screen."""
|
|
super().__init__(name=name, id=id, classes=classes)
|
|
# Ignore any additional kwargs to avoid type issues
|
|
self.item = item
|
|
self.is_edit_mode = item is not None
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose form layout."""
|
|
title = "Edit" if self.is_edit_mode else "Create"
|
|
yield Container(
|
|
Static(f"{title} {self.get_item_type()}", classes="form-title"),
|
|
self.create_form_fields(),
|
|
Container(
|
|
Button("💾 Save", id="save_btn", variant="success"),
|
|
Button("❌ Cancel", id="cancel_btn", variant="default"),
|
|
classes="form-actions",
|
|
),
|
|
classes="form-container",
|
|
)
|
|
|
|
def get_item_type(self) -> str:
|
|
"""Get item type name for title."""
|
|
raise NotImplementedError("Subclasses must implement get_item_type")
|
|
|
|
def create_form_fields(self) -> Container:
|
|
"""Create form input fields."""
|
|
raise NotImplementedError("Subclasses must implement create_form_fields")
|
|
|
|
def validate_form(self) -> tuple[bool, list[str]]:
|
|
"""Validate form data."""
|
|
raise NotImplementedError("Subclasses must implement validate_form")
|
|
|
|
def get_form_data(self) -> T:
|
|
"""Get item from form data."""
|
|
raise NotImplementedError("Subclasses must implement get_form_data")
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize form."""
|
|
if self.is_edit_mode and self.item:
|
|
self.populate_form(self.item)
|
|
|
|
def populate_form(self, item: T) -> None:
|
|
"""Populate form with item data."""
|
|
raise NotImplementedError("Subclasses must implement populate_form")
|
|
|
|
def action_save(self) -> None:
|
|
"""Save form data."""
|
|
is_valid, errors = self.validate_form()
|
|
if is_valid:
|
|
try:
|
|
item = self.get_form_data()
|
|
self.dismiss(item)
|
|
except Exception as e:
|
|
self.show_validation_errors([str(e)])
|
|
else:
|
|
self.show_validation_errors(errors)
|
|
|
|
def show_validation_errors(self, errors: list[str]) -> None:
|
|
"""Show validation errors to user."""
|
|
# This would typically show a notification or update error display
|
|
pass
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id == "save_btn":
|
|
self.action_save()
|
|
elif event.button.id == "cancel_btn":
|
|
self.dismiss(None)
|
|
</file>
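
A sketch of a concrete CRUDScreen subclass showing which hooks get overridden. The screen and the `list_collections()` coroutine on StorageManager are assumptions for illustration; the dialog and delete hooks are omitted:

from ingest_pipeline.cli.tui.models import CollectionInfo
from ingest_pipeline.cli.tui.screens.base import CRUDScreen


class CollectionCRUDScreen(CRUDScreen[CollectionInfo]):
    """Hypothetical subclass; create/edit/delete dialogs omitted for brevity."""

    def get_title(self) -> str:
        return "Collections"

    def get_table_columns(self) -> list[str]:
        return ["Name", "Backend", "Documents"]

    async def load_items(self) -> list[CollectionInfo]:
        # Assumes a list_collections() coroutine on StorageManager; the real
        # StorageManager API may differ.
        return await self.storage_manager.list_collections()

    def item_to_row(self, item: CollectionInfo) -> list[str]:
        backend = item["backend"]
        backend_label = ", ".join(backend) if isinstance(backend, list) else backend
        return [item["name"], backend_label, str(item["count"])]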

<file path="ingest_pipeline/cli/tui/widgets/indicators.py">
"""Status indicators and progress bars with enhanced visual feedback."""

from typing import Any

from textual.app import ComposeResult
from textual.widgets import ProgressBar, Static
from typing_extensions import override


class StatusIndicator(Static):
    """Modern status indicator with color coding and animations."""

    status: str

    def __init__(self, status: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.status = status
        self.update_status(status)

    def update_status(self, status: str) -> None:
        """Update the status display with enhanced visual feedback."""
        self.status = status

        # Remove previous status classes
        self.remove_class("status-active", "status-error", "status-warning", "pulse", "glow")

        status_lower = status.lower()

        if (
            status_lower in {"active", "online", "connected", "✓ active"}
            or status_lower.endswith("active")
            or "✓" in status_lower
            and "active" in status_lower
        ):
            self.add_class("status-active")
            self.add_class("glow")
            self.update(f"🟢 {status}")
        elif status_lower in {"error", "failed", "offline", "disconnected"}:
            self.add_class("status-error")
            self.add_class("pulse")
            self.update(f"🔴 {status}")
        elif status_lower in {"warning", "pending", "in_progress"}:
            self.add_class("status-warning")
            self.add_class("pulse")
            self.update(f"🟡 {status}")
        elif status_lower in {"loading", "connecting"}:
            self.add_class("shimmer")
            self.update(f"🔄 {status}")
        else:
            self.update(f"⚪ {status}")


class EnhancedProgressBar(Static):
    """Enhanced progress bar with better visual feedback."""

    total: int
    progress: int
    status_text: str

    def __init__(self, total: int = 100, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.total = total
        self.progress = 0
        self.status_text = "Ready"

    @override
    def compose(self) -> ComposeResult:
        yield Static("", id="progress_status", classes="progress-label")
        yield ProgressBar(total=self.total, id="progress_bar", show_eta=True, classes="shimmer")

    def update_progress(self, progress: int, status: str = "") -> None:
        """Update progress with enhanced feedback."""
        self.progress = progress
        if status:
            self.status_text = status

        # Update the progress bar
        progress_bar = self.query_one("#progress_bar", ProgressBar)
        progress_bar.update(progress=progress)

        # Update status text with icons
        status_display = self.query_one("#progress_status", Static)
        if progress >= 100:
            status_display.update(f"✅ {self.status_text}")
            progress_bar.add_class("glow")
        elif progress >= 75:
            status_display.update(f"🔥 {self.status_text}")
        elif progress >= 50:
            status_display.update(f"⚡ {self.status_text}")
        elif progress >= 25:
            status_display.update(f"🔄 {self.status_text}")
        else:
            status_display.update(f"🚀 {self.status_text}")
</file>
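
A minimal demo wiring both widgets into a Textual app; the DemoApp below is hypothetical and only the two widgets come from the repository:

from textual.app import App, ComposeResult

from ingest_pipeline.cli.tui.widgets.indicators import EnhancedProgressBar, StatusIndicator


class IndicatorDemoApp(App[None]):
    """Hypothetical demo app exercising the status and progress widgets."""

    def compose(self) -> ComposeResult:
        yield StatusIndicator("Connecting", id="status")
        yield EnhancedProgressBar(total=100, id="progress")

    def on_mount(self) -> None:
        # Flip the indicator once startup work has finished; update_progress()
        # on the bar would be driven the same way from a background worker.
        self.query_one("#status", StatusIndicator).update_status("Active")


if __name__ == "__main__":
    IndicatorDemoApp().run()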

<file path="ingest_pipeline/cli/tui/models.py">
"""Data models and TypedDict definitions for the TUI."""

from enum import IntEnum
from typing import TypedDict


class StorageCapabilities(IntEnum):
    """Storage backend capabilities (ordered by feature completeness)."""

    NONE = 0
    BASIC = 1  # Basic CRUD operations
    VECTOR_SEARCH = 2  # Vector search capabilities
    KNOWLEDGE_BASE = 3  # Knowledge base features
    FULL_FEATURED = 4  # All features including chunks and entities


class CollectionInfo(TypedDict):
    """Information about a collection."""

    name: str
    type: str
    count: int
    backend: str | list[str]  # Support both single backend and multi-backend
    status: str
    last_updated: str
    size_mb: float


class DocumentInfo(TypedDict):
    """Information about a document."""

    id: str
    title: str
    source_url: str
    description: str
    content_type: str
    content_preview: str
    word_count: int
    timestamp: str


class ChunkInfo(TypedDict):
    """Information about a document chunk (R2R specific)."""

    id: str
    document_id: str
    content: str
    start_index: int
    end_index: int
    metadata: dict[str, object]


class EntityInfo(TypedDict):
    """Information about an extracted entity (R2R specific)."""

    id: str
    name: str
    type: str
    confidence: float
    metadata: dict[str, object]


class FirecrawlOptions(TypedDict, total=False):
    """Advanced Firecrawl scraping options."""

    # Scraping options
    formats: list[str]  # ["markdown", "html", "screenshot"]
    only_main_content: bool
    include_tags: list[str]
    exclude_tags: list[str]
    wait_for: int  # milliseconds

    # Mapping options
    search: str | None
    include_subdomains: bool
    limit: int
    max_depth: int

    # Extraction options
    extract_schema: dict[str, object] | None
    extract_prompt: str | None


class IngestionConfig(TypedDict):
    """Configuration for ingestion operations."""

    source_url: str
    source_type: str  # "web", "repository", "documentation"
    target_collection: str
    storage_backend: str
    firecrawl_options: FirecrawlOptions
    batch_size: int
    max_concurrent: int


class SearchFilter(TypedDict, total=False):
    """Search filtering options."""

    backends: list[str]
    collections: list[str]
    content_types: list[str]
    date_range: tuple[str, str] | None
    word_count_range: tuple[int, int] | None
    similarity_threshold: float


class IngestionProgress(TypedDict):
    """Real-time ingestion progress information."""

    total_urls: int
    processed_urls: int
    successful_ingestions: int
    failed_ingestions: int
    current_url: str
    elapsed_time: float
    estimated_remaining: float
    errors: list[str]
</file>
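
Construction sketch for the TypedDicts above: IngestionConfig requires every key, while FirecrawlOptions is declared with total=False, so only the options actually used need to be present. The values are illustrative:

from ingest_pipeline.cli.tui.models import FirecrawlOptions, IngestionConfig

firecrawl_options: FirecrawlOptions = {
    "formats": ["markdown"],
    "only_main_content": True,
    "limit": 50,
}

config: IngestionConfig = {
    "source_url": "https://docs.example.com",
    "source_type": "documentation",
    "target_collection": "docs",
    "storage_backend": "weaviate",
    "firecrawl_options": firecrawl_options,
    "batch_size": 25,
    "max_concurrent": 5,
}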

<file path="ingest_pipeline/ingestors/__init__.py">
"""Ingestors module for different data sources."""

from .base import BaseIngestor
from .firecrawl import FirecrawlIngestor, FirecrawlPage
from .repomix import RepomixIngestor

__all__ = [
    "BaseIngestor",
    "FirecrawlIngestor",
    "FirecrawlPage",
    "RepomixIngestor",
]
</file>

<file path="ingest_pipeline/storage/r2r/collections.py">
|
|
"""Comprehensive collection CRUD operations for R2R."""
|
|
|
|
from typing import TypedDict, cast
|
|
from uuid import UUID
|
|
|
|
from r2r import R2RAsyncClient
|
|
|
|
from ...core.exceptions import StorageError
|
|
|
|
# JSON serializable type for API responses
|
|
JsonData = dict[str, str | int | bool | None]
|
|
|
|
|
|
class DocumentAddResult(TypedDict, total=False):
|
|
"""Result of adding a document to a collection."""
|
|
|
|
document_id: str
|
|
added: bool
|
|
result: JsonData
|
|
error: str
|
|
|
|
|
|
class DocumentRemoveResult(TypedDict, total=False):
|
|
"""Result of removing a document from a collection."""
|
|
|
|
document_id: str
|
|
removed: bool
|
|
error: str
|
|
|
|
|
|
class ExportResult(TypedDict):
|
|
"""Result of a CSV export operation."""
|
|
|
|
exported: int
|
|
path: str
|
|
|
|
|
|
class R2RCollections:
|
|
"""Comprehensive collection management for R2R."""
|
|
|
|
client: R2RAsyncClient
|
|
|
|
def __init__(self, client: R2RAsyncClient) -> None:
|
|
"""Initialize collections manager with R2R client."""
|
|
self.client = client
|
|
|
|
async def create(self, name: str, description: str | None = None) -> JsonData:
|
|
"""Create a new collection in R2R.
|
|
|
|
Args:
|
|
name: Collection name
|
|
description: Optional collection description
|
|
|
|
Returns:
|
|
Created collection information
|
|
|
|
Raises:
|
|
StorageError: If collection creation fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.create(
|
|
name=name,
|
|
description=description,
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to create collection '{name}': {e}") from e
|
|
|
|
async def retrieve(self, collection_id: str | UUID) -> JsonData:
|
|
"""Retrieve a collection by ID.
|
|
|
|
Args:
|
|
collection_id: Collection ID to retrieve
|
|
|
|
Returns:
|
|
Collection information
|
|
|
|
Raises:
|
|
StorageError: If collection retrieval fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.retrieve(str(collection_id))
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to retrieve collection {collection_id}: {e}") from e
|
|
|
|
async def update(
|
|
self,
|
|
collection_id: str | UUID,
|
|
name: str | None = None,
|
|
description: str | None = None,
|
|
) -> JsonData:
|
|
"""Update collection metadata.
|
|
|
|
Args:
|
|
collection_id: Collection ID to update
|
|
name: New name (optional)
|
|
description: New description (optional)
|
|
|
|
Returns:
|
|
Updated collection information
|
|
|
|
Raises:
|
|
StorageError: If collection update fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.update(
|
|
id=str(collection_id),
|
|
name=name,
|
|
description=description,
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to update collection {collection_id}: {e}") from e
|
|
|
|
async def delete(self, collection_id: str | UUID) -> bool:
|
|
"""Delete a collection by ID.
|
|
|
|
Args:
|
|
collection_id: Collection ID to delete
|
|
|
|
Returns:
|
|
True if deletion was successful
|
|
|
|
Raises:
|
|
StorageError: If collection deletion fails
|
|
"""
|
|
try:
|
|
_ = await self.client.collections.delete(str(collection_id))
|
|
return True
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to delete collection {collection_id}: {e}") from e
|
|
|
|
async def list_all(
|
|
self, offset: int = 0, limit: int = 100, owner_only: bool = False
|
|
) -> JsonData:
|
|
"""List collections with pagination support.
|
|
|
|
Args:
|
|
offset: Starting offset for pagination
|
|
limit: Maximum number of collections to return
|
|
owner_only: Only return collections owned by current user
|
|
|
|
Returns:
|
|
Paginated list of collections
|
|
|
|
Raises:
|
|
StorageError: If collection listing fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.list(
|
|
offset=offset,
|
|
limit=limit,
|
|
owner_only=owner_only,
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to list collections: {e}") from e
|
|
|
|
async def get_by_name(
|
|
self, collection_name: str, owner_id: str | UUID | None = None
|
|
) -> JsonData:
|
|
"""Get collection by name with optional owner filter.
|
|
|
|
Args:
|
|
collection_name: Name of the collection
|
|
owner_id: Optional owner ID filter
|
|
|
|
Returns:
|
|
Collection information
|
|
|
|
Raises:
|
|
StorageError: If collection retrieval fails
|
|
"""
|
|
try:
|
|
# List all collections and find by name
|
|
collections_response = await self.client.collections.list()
|
|
for collection in collections_response.results:
|
|
if (
|
|
owner_id is None or str(collection.owner_id) == str(owner_id)
|
|
) and collection.name == collection_name:
|
|
return cast(JsonData, collection.model_dump())
|
|
raise StorageError(f"Collection '{collection_name}' not found")
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to get collection by name '{collection_name}': {e}") from e
|
|
|
|
async def add_document(self, collection_id: str | UUID, document_id: str | UUID) -> JsonData:
|
|
"""Associate a document with a collection.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
document_id: Document ID to add
|
|
|
|
Returns:
|
|
Association result
|
|
|
|
Raises:
|
|
StorageError: If document association fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.add_document(
|
|
id=str(collection_id),
|
|
document_id=str(document_id),
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(
|
|
f"Failed to add document {document_id} to collection {collection_id}: {e}"
|
|
) from e
|
|
|
|
async def remove_document(self, collection_id: str | UUID, document_id: str | UUID) -> bool:
|
|
"""Remove document association from collection.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
document_id: Document ID to remove
|
|
|
|
Returns:
|
|
True if removal was successful
|
|
|
|
Raises:
|
|
StorageError: If document removal fails
|
|
"""
|
|
try:
|
|
await self.client.collections.remove_document(
|
|
id=str(collection_id),
|
|
document_id=str(document_id),
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
raise StorageError(
|
|
f"Failed to remove document {document_id} from collection {collection_id}: {e}"
|
|
) from e
|
|
|
|
async def list_documents(
|
|
self, collection_id: str | UUID, offset: int = 0, limit: int = 100
|
|
) -> JsonData:
|
|
"""List all documents in a collection with pagination.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
offset: Starting offset for pagination
|
|
limit: Maximum number of documents to return
|
|
|
|
Returns:
|
|
Paginated list of documents in collection
|
|
|
|
Raises:
|
|
StorageError: If document listing fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.list_documents(
|
|
id=str(collection_id),
|
|
offset=offset,
|
|
limit=limit,
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(
|
|
f"Failed to list documents in collection {collection_id}: {e}"
|
|
) from e
|
|
|
|
async def add_user(self, collection_id: str | UUID, user_id: str | UUID) -> JsonData:
|
|
"""Grant user access to a collection.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
user_id: User ID to grant access
|
|
|
|
Returns:
|
|
Access grant result
|
|
|
|
Raises:
|
|
StorageError: If user access grant fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.add_user(
|
|
id=str(collection_id),
|
|
user_id=str(user_id),
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(
|
|
f"Failed to add user {user_id} to collection {collection_id}: {e}"
|
|
) from e
|
|
|
|
async def remove_user(self, collection_id: str | UUID, user_id: str | UUID) -> bool:
|
|
"""Revoke user access from a collection.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
user_id: User ID to revoke access
|
|
|
|
Returns:
|
|
True if revocation was successful
|
|
|
|
Raises:
|
|
StorageError: If user access revocation fails
|
|
"""
|
|
try:
|
|
await self.client.collections.remove_user(
|
|
id=str(collection_id),
|
|
user_id=str(user_id),
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
raise StorageError(
|
|
f"Failed to remove user {user_id} from collection {collection_id}: {e}"
|
|
) from e
|
|
|
|
async def list_users(
|
|
self, collection_id: str | UUID, offset: int = 0, limit: int = 100
|
|
) -> JsonData:
|
|
"""List all users with access to a collection.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
offset: Starting offset for pagination
|
|
limit: Maximum number of users to return
|
|
|
|
Returns:
|
|
Paginated list of users with collection access
|
|
|
|
Raises:
|
|
StorageError: If user listing fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.list_users(
|
|
id=str(collection_id),
|
|
offset=offset,
|
|
limit=limit,
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to list users for collection {collection_id}: {e}") from e
|
|
|
|
async def extract_entities(
|
|
self,
|
|
collection_id: str | UUID,
|
|
run_with_orchestration: bool = True,
|
|
settings: JsonData | None = None,
|
|
) -> JsonData:
|
|
"""Extract entities and relationships from collection documents.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
run_with_orchestration: Whether to run with orchestration
|
|
settings: Extraction configuration settings
|
|
|
|
Returns:
|
|
Extraction results
|
|
|
|
Raises:
|
|
StorageError: If entity extraction fails
|
|
"""
|
|
try:
|
|
response = await self.client.collections.extract(
|
|
id=str(collection_id),
|
|
run_with_orchestration=run_with_orchestration,
|
|
settings=cast(dict[str, object], settings or {}),
|
|
)
|
|
# response.results is a list, not a model with model_dump()
|
|
return cast(JsonData, response.results if isinstance(response.results, list) else [])
|
|
except Exception as e:
|
|
raise StorageError(
|
|
f"Failed to extract entities from collection {collection_id}: {e}"
|
|
) from e
|
|
|
|
async def export_to_csv(
|
|
self, output_path: str, columns: list[str] | None = None, include_header: bool = True
|
|
) -> ExportResult:
|
|
"""Export collections to CSV format.
|
|
|
|
Args:
|
|
output_path: Path for the exported CSV file
|
|
columns: Specific columns to export (optional)
|
|
include_header: Whether to include header row
|
|
|
|
Returns:
|
|
Export result information
|
|
|
|
Raises:
|
|
StorageError: If export fails
|
|
"""
|
|
# R2R SDK doesn't currently support collection export
|
|
# Implement a basic CSV export using list()
|
|
try:
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
collections_response = await self.client.collections.list()
|
|
collections_data = [
|
|
{
|
|
"id": str(c.id),
|
|
"name": c.name,
|
|
"description": c.description or "",
|
|
"owner_id": str(c.owner_id) if hasattr(c, "owner_id") else "",
|
|
}
|
|
for c in collections_response.results
|
|
]
|
|
|
|
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
|
|
if not collections_data:
|
|
return {"exported": 0, "path": output_path}
|
|
|
|
fieldnames = columns or list(collections_data[0].keys())
|
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
|
|
|
if include_header:
|
|
writer.writeheader()
|
|
|
|
for collection in collections_data:
|
|
filtered_collection = {k: v for k, v in collection.items() if k in fieldnames}
|
|
writer.writerow(filtered_collection)
|
|
|
|
return {"exported": len(collections_data), "path": output_path}
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to export collections: {e}") from e
|
|
|
|
async def batch_add_documents(
|
|
self, collection_id: str | UUID, document_ids: list[str | UUID]
|
|
) -> list[DocumentAddResult]:
|
|
"""Add multiple documents to a collection efficiently.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
document_ids: List of document IDs to add
|
|
|
|
Returns:
|
|
List of addition results
|
|
"""
|
|
results: list[DocumentAddResult] = []
|
|
for doc_id in document_ids:
|
|
try:
|
|
result = await self.add_document(collection_id, doc_id)
|
|
results.append({"document_id": str(doc_id), "added": True, "result": result})
|
|
except StorageError as e:
|
|
results.append({"document_id": str(doc_id), "added": False, "error": str(e)})
|
|
return results
|
|
|
|
async def batch_remove_documents(
|
|
self, collection_id: str | UUID, document_ids: list[str | UUID]
|
|
) -> list[DocumentRemoveResult]:
|
|
"""Remove multiple documents from a collection efficiently.
|
|
|
|
Args:
|
|
collection_id: Collection ID
|
|
document_ids: List of document IDs to remove
|
|
|
|
Returns:
|
|
List of removal results
|
|
"""
|
|
results: list[DocumentRemoveResult] = []
|
|
for doc_id in document_ids:
|
|
try:
|
|
success = await self.remove_document(collection_id, doc_id)
|
|
results.append({"document_id": str(doc_id), "removed": success})
|
|
except StorageError as e:
|
|
results.append({"document_id": str(doc_id), "removed": False, "error": str(e)})
|
|
return results
|
|
</file>
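
Usage sketch for R2RCollections; the R2RAsyncClient constructor arguments are an assumption for illustration (real connection settings come from config/settings.py):

import asyncio

from r2r import R2RAsyncClient

from ingest_pipeline.storage.r2r.collections import R2RCollections


async def main() -> None:
    # Base URL is illustrative; adjust to the deployed R2R instance.
    client = R2RAsyncClient("http://localhost:7272")
    collections = R2RCollections(client)

    created = await collections.create("docs", description="Scraped documentation")
    listing = await collections.list_all(limit=10)
    print(created, listing)


asyncio.run(main())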

<file path="ingest_pipeline/storage/__init__.py">
"""Storage adapters for different backends."""

from typing import TYPE_CHECKING

from .base import BaseStorage
from .openwebui import OpenWebUIStorage
from .weaviate import WeaviateStorage

if TYPE_CHECKING:
    from .r2r import R2RStorage as _R2RStorage

try:
    from .r2r.storage import R2RStorage as _RuntimeR2RStorage

    R2RStorage: type[BaseStorage] | None = _RuntimeR2RStorage
except ImportError:
    R2RStorage = None

__all__ = [
    "BaseStorage",
    "WeaviateStorage",
    "OpenWebUIStorage",
    "R2RStorage",
    "_R2RStorage",
]
</file>
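
Because the R2R integration is imported lazily above, R2RStorage is None when the optional r2r dependency is missing, so callers should guard before using it. A minimal sketch:

from ingest_pipeline.storage import R2RStorage, WeaviateStorage

# Fall back to Weaviate when the optional R2R backend is unavailable.
storage_cls = R2RStorage if R2RStorage is not None else WeaviateStorage
print(f"Selected storage backend: {storage_cls.__name__}")
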
<file path="ingest_pipeline/cli/tui/screens/ingestion.py">
|
|
"""Enhanced ingestion screen with multi-storage support."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, cast
|
|
|
|
from textual import work
|
|
from textual.app import ComposeResult
|
|
from textual.binding import Binding
|
|
from textual.containers import Container, Horizontal
|
|
from textual.screen import ModalScreen
|
|
from textual.widgets import Button, Checkbox, Input, Label, LoadingIndicator, Rule, Static
|
|
from typing_extensions import override
|
|
|
|
from ....core.models import IngestionResult, IngestionSource, StorageBackend
|
|
from ....flows.ingestion import create_ingestion_flow
|
|
from ..models import CollectionInfo
|
|
from ..utils.storage_manager import StorageManager
|
|
from ..widgets import EnhancedProgressBar
|
|
|
|
if TYPE_CHECKING:
|
|
from ..app import CollectionManagementApp
|
|
|
|
|
|
BACKEND_ORDER: tuple[StorageBackend, ...] = (
|
|
StorageBackend.WEAVIATE,
|
|
StorageBackend.OPEN_WEBUI,
|
|
StorageBackend.R2R,
|
|
)
|
|
|
|
BACKEND_LABELS: dict[StorageBackend, str] = {
|
|
StorageBackend.WEAVIATE: "🗄️ Weaviate",
|
|
StorageBackend.OPEN_WEBUI: "🌐 OpenWebUI",
|
|
StorageBackend.R2R: "🧠 R2R",
|
|
}
|
|
|
|
|
|
class IngestionScreen(ModalScreen[None]):
|
|
"""Modern ingestion screen with multi-backend fan-out."""
|
|
|
|
collection: CollectionInfo
|
|
storage_manager: StorageManager
|
|
selected_type: IngestionSource
|
|
progress_value: int
|
|
available_backends: list[StorageBackend]
|
|
selected_backends: list[StorageBackend]
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "app.pop_screen", "Cancel"),
|
|
Binding("ctrl+i", "start_ingestion", "Start"),
|
|
Binding("1", "select_web", "Web", show=False),
|
|
Binding("2", "select_repo", "Repository", show=False),
|
|
Binding("3", "select_docs", "Documentation", show=False),
|
|
Binding("enter", "start_ingestion", "Start Ingestion"),
|
|
Binding("tab", "focus_next", "Next Field"),
|
|
Binding("shift+tab", "focus_previous", "Previous Field"),
|
|
]
|
|
|
|
def __init__(self, collection: CollectionInfo, storage_manager: StorageManager) -> None:
|
|
super().__init__()
|
|
self.collection = collection
|
|
self.storage_manager = storage_manager
|
|
self.selected_type = IngestionSource.WEB
|
|
self.progress_value = 0
|
|
self.available_backends = list(storage_manager.get_available_backends())
|
|
if not self.available_backends:
|
|
raise ValueError("No storage backends are available for ingestion")
|
|
self.selected_backends = self._derive_initial_backends()
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
target_name = self.collection["name"]
|
|
backend_info = self.collection["backend"]
|
|
|
|
# Format backend label for display
|
|
if isinstance(backend_info, list):
|
|
# Ensure all elements are strings for safe joining
|
|
backend_strings = [str(b) for b in backend_info if b is not None]
|
|
target_backend_label = " + ".join(backend_strings) if backend_strings else "unknown"
|
|
else:
|
|
target_backend_label = str(backend_info) if backend_info is not None else "unknown"
|
|
|
|
with Container(classes="modal-container"):
|
|
yield Static("📥 Modern Ingestion Interface", classes="title")
|
|
yield Static(
|
|
f"Target: {target_name} ({target_backend_label})",
|
|
classes="subtitle",
|
|
)
|
|
yield Rule()
|
|
|
|
yield Container(
|
|
Label("🌐 Source URL:", classes="input-label"),
|
|
Input(
|
|
placeholder="https://docs.example.com or file:///path/to/repo",
|
|
id="url_input",
|
|
classes="modern-input",
|
|
),
|
|
Label("📝 Collection Name:", classes="input-label"),
|
|
Input(
|
|
placeholder="Enter collection name (or leave empty to auto-generate)",
|
|
id="collection_input",
|
|
classes="modern-input",
|
|
value=self.collection.get("name", ""),
|
|
),
|
|
Label("📋 Source Type (Press 1/2/3):", classes="input-label"),
|
|
Horizontal(
|
|
Button("🌐 Web (1)", id="web_btn", variant="primary", classes="type-button"),
|
|
Button(
|
|
"📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"
|
|
),
|
|
Button(
|
|
"📖 Documentation (3)",
|
|
id="docs_btn",
|
|
variant="default",
|
|
classes="type-button",
|
|
),
|
|
classes="type_buttons",
|
|
),
|
|
Rule(line_style="dashed"),
|
|
Label(
|
|
f"🗄️ Target Storages ({len(self.available_backends)} available):",
|
|
classes="input-label",
|
|
id="backend_label",
|
|
),
|
|
Container(
|
|
*self._create_backend_checkbox_widgets(),
|
|
classes="backend-selection",
|
|
),
|
|
Container(
|
|
Button("Select All Storages", id="select_all_backends", variant="default"),
|
|
Button("Clear Selection", id="clear_backends", variant="default"),
|
|
classes="backend-actions",
|
|
),
|
|
Static("📋 Selected: None", id="selection_status", classes="selection-status"),
|
|
classes="input-section card",
|
|
)
|
|
|
|
yield Container(
|
|
Label("🔄 Progress:", classes="progress-label"),
|
|
EnhancedProgressBar(id="enhanced_progress", total=100),
|
|
Static("Ready to start", id="progress_text", classes="status-text"),
|
|
classes="progress-section card",
|
|
)
|
|
|
|
yield Horizontal(
|
|
Button("🚀 Start Ingestion", id="start_btn", variant="success"),
|
|
Button("❌ Cancel", id="cancel_btn", variant="error"),
|
|
classes="action_buttons",
|
|
)
|
|
|
|
yield LoadingIndicator(id="loading", classes="pulse")
|
|
|
|
def _create_backend_checkbox_widgets(self) -> list[Checkbox]:
|
|
"""Create checkbox widgets for each available backend."""
|
|
checkboxes: list[Checkbox] = [
|
|
Checkbox(
|
|
BACKEND_LABELS.get(backend, backend.value),
|
|
value=backend in self.selected_backends,
|
|
id=f"backend_{backend.value}",
|
|
)
|
|
for backend in BACKEND_ORDER
|
|
if backend in self.available_backends
|
|
]
|
|
return checkboxes
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize the screen state once widgets exist."""
|
|
self.query_one("#loading").display = False
|
|
self.query_one("#url_input", Input).focus()
|
|
self._set_backend_selection(self.selected_backends)
|
|
self._update_selection_status()
|
|
|
|
def action_select_web(self) -> None:
|
|
self.selected_type = IngestionSource.WEB
|
|
self._update_type_buttons("web")
|
|
|
|
def action_select_repo(self) -> None:
|
|
self.selected_type = IngestionSource.REPOSITORY
|
|
self._update_type_buttons("repo")
|
|
|
|
def action_select_docs(self) -> None:
|
|
self.selected_type = IngestionSource.DOCUMENTATION
|
|
self._update_type_buttons("docs")
|
|
|
|
def _update_type_buttons(self, selected: str) -> None:
|
|
buttons = {
|
|
"web": self.query_one("#web_btn", Button),
|
|
"repo": self.query_one("#repo_btn", Button),
|
|
"docs": self.query_one("#docs_btn", Button),
|
|
}
|
|
for kind, button in buttons.items():
|
|
button.variant = "primary" if kind == selected else "default"
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
button_id = event.button.id
|
|
if button_id == "web_btn":
|
|
self.action_select_web()
|
|
elif button_id == "repo_btn":
|
|
self.action_select_repo()
|
|
elif button_id == "docs_btn":
|
|
self.action_select_docs()
|
|
elif button_id == "select_all_backends":
|
|
self._set_backend_selection(self.available_backends)
|
|
self._update_selection_status()
|
|
elif button_id == "clear_backends":
|
|
self._set_backend_selection([])
|
|
self._update_selection_status()
|
|
elif button_id == "start_btn":
|
|
self.action_start_ingestion()
|
|
elif button_id == "cancel_btn":
|
|
self.app.pop_screen()
|
|
|
|
def on_checkbox_changed(self, event: Checkbox.Changed) -> None:
|
|
"""Handle checkbox state changes for backend selection."""
|
|
if event.checkbox.id and event.checkbox.id.startswith("backend_"):
|
|
# Update the selected backends list based on current checkbox states
|
|
self.selected_backends = self._resolve_selected_backends()
|
|
self._update_selection_status()
|
|
|
|
def on_input_submitted(self, event: Input.Submitted) -> None:
|
|
if event.input.id in ("url_input", "collection_input"):
|
|
self.action_start_ingestion()
|
|
|
|
def action_start_ingestion(self) -> None:
|
|
url_input = self.query_one("#url_input", Input)
|
|
collection_input = self.query_one("#collection_input", Input)
|
|
|
|
source_url = url_input.value.strip()
|
|
collection_name = collection_input.value.strip()
|
|
|
|
if not source_url:
|
|
cast("CollectionManagementApp", self.app).safe_notify(
|
|
"🔍 Please enter a source URL", severity="error"
|
|
)
|
|
url_input.focus()
|
|
return
|
|
|
|
# Validate URL format
|
|
if not self._validate_url(source_url):
|
|
cast("CollectionManagementApp", self.app).safe_notify(
|
|
"❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path",
|
|
severity="error",
|
|
)
|
|
url_input.focus()
|
|
return
|
|
|
|
resolved_backends = self._resolve_selected_backends()
|
|
if not resolved_backends:
|
|
cast("CollectionManagementApp", self.app).safe_notify(
|
|
"⚠️ Select at least one storage backend", severity="warning"
|
|
)
|
|
return
|
|
|
|
self.selected_backends = resolved_backends
|
|
self.perform_ingestion(source_url, collection_name)
|
|
|
|
def _validate_url(self, url: str) -> bool:
|
|
"""Validate URL format for security."""
|
|
if not url:
|
|
return False
|
|
|
|
# Basic URL validation
|
|
url_lower = url.lower()
|
|
|
|
# Allow HTTP/HTTPS URLs
|
|
if url_lower.startswith(("http://", "https://")):
|
|
# Additional validation could be added here
|
|
return True
|
|
|
|
# Allow file:// URLs for repository paths
|
|
if url_lower.startswith("file://"):
|
|
return True
|
|
|
|
# Allow local file paths that look like repositories
|
|
return "/" in url and not url_lower.startswith(("javascript:", "data:", "vbscript:"))
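
    # Illustrative expectations for _validate_url (a sketch, not an exhaustive test):
    #   "https://docs.example.com"   -> True   (HTTP/HTTPS allowed)
    #   "file:///path/to/repo"       -> True   (local repository paths allowed)
    #   "javascript:alert(1)"        -> False  (unsafe scheme, and no path separator)
    #   ""                           -> False  (empty input rejected)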
|
|
|
|
def _resolve_selected_backends(self) -> list[StorageBackend]:
|
|
selected: list[StorageBackend] = []
|
|
for backend in BACKEND_ORDER:
|
|
if backend not in self.available_backends:
|
|
continue
|
|
checkbox_id = f"#backend_{backend.value}"
|
|
checkbox = self.query_one(checkbox_id, Checkbox)
|
|
if checkbox.value:
|
|
selected.append(backend)
|
|
return selected
|
|
|
|
def _set_backend_selection(self, backends: list[StorageBackend]) -> None:
|
|
normalized = [backend for backend in BACKEND_ORDER if backend in backends]
|
|
for backend in BACKEND_ORDER:
|
|
if backend not in self.available_backends:
|
|
continue
|
|
checkbox_id = f"#backend_{backend.value}"
|
|
checkbox = self.query_one(checkbox_id, Checkbox)
|
|
checkbox.value = backend in normalized
|
|
self.selected_backends = normalized
|
|
|
|
def _update_selection_status(self) -> None:
|
|
"""Update the visual indicator showing current storage selection."""
|
|
try:
|
|
status_widget = self.query_one("#selection_status", Static)
|
|
|
|
if not self.selected_backends:
|
|
status_widget.update("📋 Selected: None")
|
|
elif len(self.selected_backends) == 1:
|
|
backend_name = BACKEND_LABELS.get(
|
|
self.selected_backends[0], self.selected_backends[0].value
|
|
)
|
|
status_widget.update(f"📋 Selected: {backend_name}")
|
|
else:
|
|
# Multiple backends selected
|
|
backend_names = [
|
|
BACKEND_LABELS.get(backend, backend.value) for backend in self.selected_backends
|
|
]
|
|
if len(backend_names) <= 3:
|
|
# Show all names if 3 or fewer
|
|
names_str = ", ".join(backend_names)
|
|
status_widget.update(f"📋 Selected: {names_str}")
|
|
else:
|
|
# Show count if more than 3
|
|
status_widget.update(f"📋 Selected: {len(self.selected_backends)} backends")
|
|
except Exception:
|
|
# Widget might not exist yet during initialization
|
|
pass
|
|
|
|
def _derive_initial_backends(self) -> list[StorageBackend]:
|
|
backend_info = self.collection.get("backend", "")
|
|
|
|
# Handle both single backend (str) and multi-backend (list[str])
|
|
if isinstance(backend_info, list):
|
|
# Multi-backend: try to match all backends
|
|
matched_backends = []
|
|
for backend_name in backend_info:
|
|
backend_name_lower = backend_name.lower()
|
|
for backend in BACKEND_ORDER:
|
|
if backend not in self.available_backends:
|
|
continue
|
|
if (
|
|
backend.value.lower() == backend_name_lower
|
|
or backend.name.lower() == backend_name_lower
|
|
):
|
|
matched_backends.append(backend)
|
|
break
|
|
return matched_backends or [self.available_backends[0]]
|
|
else:
|
|
# Single backend: original logic
|
|
backend_label = str(backend_info).lower()
|
|
for backend in BACKEND_ORDER:
|
|
if backend not in self.available_backends:
|
|
continue
|
|
if backend.value in backend_label or backend.name.lower() in backend_label:
|
|
return [backend]
|
|
return [self.available_backends[0]]
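
    # Illustration: a collection reporting backend=["r2r", "weaviate"] resolves to
    # [StorageBackend.R2R, StorageBackend.WEAVIATE] when both are available, while an
    # unrecognised value falls back to the first available backend.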
|
|
|
|
@work(exclusive=True, thread=True)
|
|
def perform_ingestion(self, source_url: str, collection_name: str = "") -> None:
|
|
        import asyncio
|
|
|
|
backends = self._resolve_selected_backends()
|
|
self.selected_backends = backends
|
|
|
|
def update_ui(action: str) -> None:
|
|
def _update() -> None:
|
|
try:
|
|
loading = self.query_one("#loading")
|
|
if action == "show_loading":
|
|
loading.display = True
|
|
elif action == "hide_loading":
|
|
loading.display = False
|
|
except Exception:
|
|
pass
|
|
|
|
cast("CollectionManagementApp", self.app).call_from_thread(_update)
|
|
|
|
def progress_reporter(percent: int, message: str) -> None:
|
|
def _update_progress() -> None:
|
|
try:
|
|
progress = self.query_one("#enhanced_progress", EnhancedProgressBar)
|
|
progress_text = self.query_one("#progress_text", Static)
|
|
progress.update_progress(percent, message)
|
|
progress_text.update(message)
|
|
except Exception:
|
|
pass
|
|
|
|
cast("CollectionManagementApp", self.app).call_from_thread(_update_progress)
|
|
|
|
try:
|
|
update_ui("show_loading")
|
|
progress_reporter(5, "🚀 Starting Prefect flows...")
|
|
|
|
# Use user-provided collection name or fall back to default
|
|
final_collection_name = collection_name or self.collection.get("name")
|
|
|
|
total_successful = 0
|
|
total_failed = 0
|
|
flow_errors: list[str] = []
|
|
|
|
for i, backend in enumerate(backends):
|
|
progress_percent = 20 + (60 * i) // len(backends)
|
|
progress_reporter(
|
|
progress_percent,
|
|
f"🔗 Processing {backend.value} backend ({i + 1}/{len(backends)})...",
|
|
)
|
|
|
|
try:
|
|
                    # Run the Prefect flow for this backend with asyncio.run and a timeout.

async def run_flow_with_timeout(
|
|
current_backend: StorageBackend = backend,
|
|
) -> IngestionResult:
|
|
return await asyncio.wait_for(
|
|
create_ingestion_flow(
|
|
source_url=source_url,
|
|
source_type=self.selected_type,
|
|
storage_backend=current_backend,
|
|
collection_name=final_collection_name,
|
|
progress_callback=progress_reporter,
|
|
),
|
|
timeout=600.0, # 10 minute timeout
|
|
)
|
|
|
|
result = asyncio.run(run_flow_with_timeout())
|
|
|
|
total_successful += result.documents_processed
|
|
total_failed += result.documents_failed
|
|
|
|
if result.error_messages:
|
|
flow_errors.extend(
|
|
[f"{backend.value}: {err}" for err in result.error_messages]
|
|
)
|
|
|
|
except TimeoutError:
|
|
error_msg = f"{backend.value}: Timeout after 10 minutes"
|
|
flow_errors.append(error_msg)
|
|
progress_reporter(0, f"❌ {backend.value} timed out")
|
|
|
|
def notify_timeout(
|
|
msg: str = f"⏰ {backend.value} flow timed out after 10 minutes",
|
|
) -> None:
|
|
try:
|
|
self.notify(msg, severity="error", markup=False)
|
|
except Exception:
|
|
pass
|
|
|
|
cast("CollectionManagementApp", self.app).call_from_thread(notify_timeout)
|
|
except Exception as exc:
|
|
flow_errors.append(f"{backend.value}: {exc}")
|
|
|
|
def notify_error(msg: str = f"❌ {backend.value} flow failed: {exc}") -> None:
|
|
try:
|
|
self.notify(msg, severity="error", markup=False)
|
|
except Exception:
|
|
pass
|
|
|
|
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
|
|
|
|
successful = total_successful
|
|
failed = total_failed
|
|
|
|
progress_reporter(100, "🎉 Completed successfully!")
|
|
|
|
def notify_results() -> None:
|
|
try:
|
|
if successful > 0:
|
|
self.notify(
|
|
f"🎉 Successfully ingested {successful} documents across {len(backends)} backend(s) via Prefect!",
|
|
severity="information",
|
|
)
|
|
if failed > 0:
|
|
self.notify(f"⚠️ {failed} documents failed to process", severity="warning")
|
|
|
|
if flow_errors:
|
|
for error in flow_errors:
|
|
self.notify(f"⚠️ {error}", severity="warning", markup=False)
|
|
except Exception:
|
|
pass
|
|
|
|
cast("CollectionManagementApp", self.app).call_from_thread(notify_results)
|
|
|
|
def _pop() -> None:
|
|
try:
|
|
self.app.pop_screen()
|
|
except Exception:
|
|
pass
|
|
|
|
# Schedule screen pop via timer instead of blocking
|
|
cast("CollectionManagementApp", self.app).call_from_thread(
|
|
lambda: self.app.set_timer(2.0, _pop)
|
|
)
|
|
|
|
except Exception as exc: # pragma: no cover - defensive
|
|
progress_reporter(0, f"❌ Prefect flows error: {exc}")
|
|
|
|
def notify_error(msg: str = f"❌ Prefect flows failed: {exc}") -> None:
|
|
try:
|
|
self.notify(msg, severity="error")
|
|
except Exception:
|
|
pass
|
|
|
|
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
|
|
|
|
def _pop_on_error() -> None:
|
|
try:
|
|
self.app.pop_screen()
|
|
except Exception:
|
|
pass
|
|
|
|
# Schedule screen pop via timer for error case too
|
|
cast("CollectionManagementApp", self.app).call_from_thread(
|
|
lambda: self.app.set_timer(2.0, _pop_on_error)
|
|
)
|
|
finally:
|
|
update_ui("hide_loading")
|
|
</file>
|
|
|
|
<file path="ingest_pipeline/cli/tui/screens/search.py">
|
|
"""Search screen for finding documents within collections."""
|
|
|
|
from textual.app import ComposeResult
|
|
from textual.binding import Binding
|
|
from textual.containers import Container
|
|
from textual.screen import Screen
|
|
from textual.widgets import Button, Footer, Header, Input, LoadingIndicator, Static
|
|
from typing_extensions import override
|
|
|
|
from ....storage.openwebui import OpenWebUIStorage
|
|
from ....storage.weaviate import WeaviateStorage
|
|
from ..models import CollectionInfo
|
|
from ..widgets import EnhancedDataTable
|
|
|
|
|
|
class SearchScreen(Screen[None]):
|
|
"""Screen for searching within a collection with enhanced keyboard navigation."""
|
|
|
|
collection: CollectionInfo
|
|
weaviate: WeaviateStorage | None
|
|
openwebui: OpenWebUIStorage | None
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "app.pop_screen", "Back"),
|
|
Binding("enter", "perform_search", "Search"),
|
|
Binding("ctrl+f", "focus_search", "Focus Search"),
|
|
Binding("f3", "perform_search", "Search Again"),
|
|
Binding("ctrl+r", "clear_results", "Clear Results"),
|
|
Binding("/", "focus_search", "Quick Search"),
|
|
]
|
|
|
|
def __init__(
|
|
self,
|
|
collection: CollectionInfo,
|
|
weaviate: WeaviateStorage | None,
|
|
openwebui: OpenWebUIStorage | None,
|
|
):
|
|
super().__init__()
|
|
self.collection = collection
|
|
self.weaviate = weaviate
|
|
self.openwebui = openwebui
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
yield Header()
|
|
# Check if search is supported for this backend
|
|
backends = self.collection["backend"]
|
|
if isinstance(backends, str):
|
|
backends = [backends]
|
|
search_supported = "weaviate" in backends
|
|
search_indicator = "✅ Search supported" if search_supported else "❌ Search not supported"
|
|
|
|
yield Container(
|
|
Static(
|
|
f"🔍 Search in: {self.collection['name']} ({', '.join(backends)}) - {search_indicator}",
|
|
classes="title",
|
|
),
|
|
Static(
|
|
"Press / or Ctrl+F to focus search, Enter to search"
|
|
if search_supported
|
|
else "Search functionality not available for this backend",
|
|
classes="subtitle",
|
|
),
|
|
Input(placeholder="Enter search query... (press Enter to search)", id="search_input"),
|
|
Button("🔍 Search", id="search_btn", variant="primary"),
|
|
Button("🗑️ Clear Results", id="clear_btn", variant="default"),
|
|
EnhancedDataTable(id="results_table"),
|
|
Static(
|
|
"Enter your search query to find relevant documents.",
|
|
id="search_status",
|
|
classes="status-text",
|
|
),
|
|
LoadingIndicator(id="loading"),
|
|
classes="main_container",
|
|
)
|
|
yield Footer()
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize the screen."""
|
|
self.query_one("#loading").display = False
|
|
|
|
# Setup results table with enhanced metadata
|
|
table = self.query_one("#results_table", EnhancedDataTable)
|
|
table.add_columns("Title", "Source URL", "Type", "Content Preview", "Words", "Score")
|
|
|
|
# Focus search input
|
|
self.query_one("#search_input").focus()
|
|
|
|
def action_focus_search(self) -> None:
|
|
"""Focus the search input field."""
|
|
search_input = self.query_one("#search_input", Input)
|
|
search_input.focus()
|
|
|
|
def action_clear_results(self) -> None:
|
|
"""Clear search results."""
|
|
table = self.query_one("#results_table", EnhancedDataTable)
|
|
table.clear()
|
|
table.add_columns("Title", "Source URL", "Type", "Content Preview", "Words", "Score")
|
|
|
|
status = self.query_one("#search_status", Static)
|
|
status.update("Search results cleared. Enter a new query to search.")
|
|
|
|
def on_input_submitted(self, event: Input.Submitted) -> None:
|
|
"""Handle search input submission."""
|
|
if event.input.id == "search_input":
|
|
self.action_perform_search()
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id == "search_btn":
|
|
self.action_perform_search()
|
|
elif event.button.id == "clear_btn":
|
|
self.action_clear_results()
|
|
|
|
def action_perform_search(self) -> None:
|
|
"""Perform search."""
|
|
search_input = self.query_one("#search_input", Input)
|
|
if not search_input.value.strip():
|
|
self.notify("Please enter a search query", severity="warning")
|
|
search_input.focus()
|
|
return
|
|
|
|
self.run_worker(self.search_collection(search_input.value.strip()))
|
|
|
|
async def search_collection(self, query: str) -> None:
|
|
"""Search the collection."""
|
|
loading = self.query_one("#loading", LoadingIndicator)
|
|
table = self.query_one("#results_table", EnhancedDataTable)
|
|
status = self.query_one("#search_status", Static)
|
|
|
|
try:
|
|
self._setup_search_ui(loading, table, status, query)
|
|
results = await self._execute_search(query)
|
|
self._populate_results_table(table, results)
|
|
self._update_search_status(status, query, results, table)
|
|
except Exception as e:
|
|
status.update(f"Search error: {e}")
|
|
self.notify(f"Search error: {e}", severity="error", markup=False)
|
|
finally:
|
|
loading.display = False
|
|
|
|
def _setup_search_ui(
|
|
self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str
|
|
) -> None:
|
|
"""Setup the search UI elements."""
|
|
loading.display = True
|
|
status.update(f"🔍 Searching for '{query}'...")
|
|
table.clear()
|
|
table.add_columns("Title", "Source URL", "Type", "Content Preview", "Words", "Score")
|
|
|
|
async def _execute_search(self, query: str) -> list[dict[str, str | float]]:
|
|
"""Execute the search based on collection type."""
|
|
if self.collection["type"] == "weaviate" and self.weaviate:
|
|
return await self.search_weaviate(query)
|
|
elif self.collection["type"] == "openwebui" and self.openwebui:
|
|
# OpenWebUI search is not yet implemented
|
|
self.notify("Search not supported for OpenWebUI collections", severity="warning")
|
|
return []
|
|
elif self.collection["type"] == "r2r":
|
|
# R2R search would go here when implemented
|
|
self.notify("Search not supported for R2R collections", severity="warning")
|
|
return []
|
|
return []
|
|
|
|
def _populate_results_table(
|
|
self, table: EnhancedDataTable, results: list[dict[str, str | float]]
|
|
) -> None:
|
|
"""Populate the results table with search results."""
|
|
for result in results:
|
|
row_data = self._format_result_row(result)
|
|
table.add_row(*row_data)
|
|
|
|
def _format_result_row(self, result: dict[str, str | float]) -> tuple[str, ...]:
|
|
"""Format a single result row for the table."""
|
|
title = self._truncate_text(result.get("title", "Untitled"), 30)
|
|
source_url = self._truncate_text(result.get("source_url", ""), 40)
|
|
type_display = self._format_content_type(result.get("content_type", "text/plain"))
|
|
content_preview = self._format_content_preview(result.get("content", ""))
|
|
word_count = str(result.get("word_count", 0))
|
|
score_display = self._format_score(result.get("score"))
|
|
|
|
return (title, source_url, type_display, content_preview, word_count, score_display)
|
|
|
|
def _truncate_text(self, text: str | float | None, max_length: int) -> str:
|
|
"""Truncate text to specified length."""
|
|
if not isinstance(text, str):
|
|
text = str(text) if text is not None else ""
|
|
return text[:max_length]
|
|
|
|
def _format_content_type(self, content_type: str | float) -> str:
|
|
"""Format content type with appropriate icon."""
|
|
content_type = str(content_type).lower()
|
|
if "markdown" in content_type:
|
|
return "📝 md"
|
|
elif "html" in content_type:
|
|
return "🌐 html"
|
|
elif "text" in content_type:
|
|
return "📄 txt"
|
|
else:
|
|
return f"📄 {content_type.split('/')[-1][:5]}"
|
|
|
|
def _format_content_preview(self, content: str | float) -> str:
|
|
"""Format content preview with truncation."""
|
|
if not isinstance(content, str):
|
|
content = str(content) if content is not None else ""
|
|
return f"{content[:60]}..." if len(content) > 60 else content
|
|
|
|
def _format_score(self, score: object) -> str:
|
|
"""Format search score for display."""
|
|
if isinstance(score, (int, float)):
|
|
return f"{score:.3f}"
|
|
elif score is None:
|
|
return "-"
|
|
else:
|
|
return str(score)
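
    # Formatting sketch: _format_score(0.8127) -> "0.813", _format_score(None) -> "-",
    # and any other object is passed through str().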
|
|
|
|
def _update_search_status(
|
|
self,
|
|
status: Static,
|
|
query: str,
|
|
results: list[dict[str, str | float]],
|
|
table: EnhancedDataTable,
|
|
) -> None:
|
|
"""Update search status and notifications based on results."""
|
|
if not results:
|
|
status.update(f"No results found for '{query}'. Try different keywords.")
|
|
self.notify("No results found", severity="information")
|
|
else:
|
|
status.update(
|
|
f"Found {len(results)} results for '{query}'. Use arrow keys to navigate."
|
|
)
|
|
self.notify(f"Found {len(results)} results", severity="information")
|
|
table.focus()
|
|
|
|
async def search_weaviate(self, query: str) -> list[dict[str, str | float]]:
|
|
"""Search Weaviate collection."""
|
|
if not self.weaviate:
|
|
return []
|
|
|
|
try:
|
|
await self.weaviate.initialize()
|
|
# Use the search_documents method which returns more metadata
|
|
results = await self.weaviate.search_documents(
|
|
query,
|
|
limit=20,
|
|
collection_name=self.collection["name"],
|
|
)
|
|
|
|
# Convert Document objects to dict format expected by the UI
|
|
formatted_results = []
|
|
for doc in results:
|
|
metadata = getattr(doc, "metadata", {})
|
|
|
|
score_value: float | None = None
|
|
raw_score = getattr(doc, "score", None)
|
|
if raw_score is not None:
|
|
try:
|
|
score_value = float(raw_score)
|
|
except (TypeError, ValueError):
|
|
score_value = None
|
|
|
|
formatted_results.append(
|
|
{
|
|
"title": metadata.get("title", "Untitled"),
|
|
"source_url": metadata.get("source_url", ""),
|
|
"content_type": metadata.get("content_type", "text/plain"),
|
|
"content": getattr(doc, "content", ""),
|
|
"word_count": metadata.get("word_count", 0),
|
|
"score": score_value if score_value is not None else 0.0,
|
|
}
|
|
)
|
|
return formatted_results
|
|
except Exception as e:
|
|
self.notify(f"Weaviate search error: {e}", severity="error", markup=False)
|
|
return []
|
|
|
|
async def search_openwebui(self, query: str) -> list[dict[str, str | float]]:
|
|
"""Search OpenWebUI collection."""
|
|
if not self.openwebui:
|
|
return []
|
|
|
|
try:
|
|
# OpenWebUI does not have a direct search API, so return empty
|
|
# In a real implementation, you would need to implement search via their API
|
|
self.notify("OpenWebUI search not yet implemented", severity="warning")
|
|
return []
|
|
except Exception as e:
|
|
self.notify(f"OpenWebUI search error: {e}", severity="error", markup=False)
|
|
return []
|
|
</file>
|
|
|
|
<file path="ingest_pipeline/cli/tui/utils/storage_manager.py">
|
|
"""Storage management utilities for TUI applications."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
import logging
from collections.abc import AsyncGenerator, Coroutine, Sequence
from typing import TYPE_CHECKING, Protocol
|
|
|
|
from pydantic import SecretStr
|
|
|
|
from ....core.exceptions import StorageError
|
|
from ....core.models import Document, StorageBackend, StorageConfig
|
|
from ....storage.base import BaseStorage
|
|
from ....storage.openwebui import OpenWebUIStorage
|
|
from ....storage.r2r.storage import R2RStorage
|
|
from ....storage.weaviate import WeaviateStorage
|
|
from ..models import CollectionInfo, StorageCapabilities
|
|
|
|
if TYPE_CHECKING:
|
|
from ....config.settings import Settings
|
|
|
|
|
|
class StorageBackendProtocol(Protocol):
|
|
"""Protocol defining storage backend interface."""
|
|
|
|
async def initialize(self) -> None: ...
|
|
async def count(self, *, collection_name: str | None = None) -> int: ...
|
|
async def list_collections(self) -> list[str]: ...
|
|
async def search(
|
|
self,
|
|
query: str,
|
|
limit: int = 10,
|
|
threshold: float = 0.7,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> AsyncGenerator[Document, None]: ...
|
|
async def close(self) -> None: ...
|
|
|
|
|
|
class MultiStorageAdapter(BaseStorage):
|
|
"""Mirror writes to multiple storage backends."""
|
|
|
|
def __init__(self, storages: Sequence[BaseStorage]) -> None:
|
|
if not storages:
|
|
raise ValueError("MultiStorageAdapter requires at least one storage backend")
|
|
|
|
unique: list[BaseStorage] = []
|
|
seen_ids: set[int] = set()
|
|
for storage in storages:
|
|
storage_id = id(storage)
|
|
if storage_id in seen_ids:
|
|
continue
|
|
seen_ids.add(storage_id)
|
|
unique.append(storage)
|
|
|
|
self._storages: list[BaseStorage] = unique
|
|
self._primary: BaseStorage = unique[0]
|
|
super().__init__(self._primary.config)
|
|
|
|
async def initialize(self) -> None:
|
|
for storage in self._storages:
|
|
await storage.initialize()
|
|
|
|
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
|
|
# Store in primary backend first
|
|
primary_id: str = await self._primary.store(document, collection_name=collection_name)
|
|
|
|
# Replicate to secondary backends concurrently
|
|
if len(self._storages) > 1:
|
|
|
|
async def replicate_to_backend(
|
|
storage: BaseStorage,
|
|
) -> tuple[BaseStorage, bool, Exception | None]:
|
|
try:
|
|
await storage.store(document, collection_name=collection_name)
|
|
return storage, True, None
|
|
except Exception as exc:
|
|
return storage, False, exc
|
|
|
|
tasks = [replicate_to_backend(storage) for storage in self._storages[1:]]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
failures: list[str] = []
|
|
errors: list[Exception] = []
|
|
|
|
for result in results:
|
|
if isinstance(result, tuple):
|
|
storage, success, error = result
|
|
if not success and error is not None:
|
|
failures.append(self._format_backend_label(storage))
|
|
errors.append(error)
|
|
elif isinstance(result, Exception):
|
|
failures.append("unknown")
|
|
errors.append(result)
|
|
|
|
if failures:
|
|
backends = ", ".join(failures)
|
|
primary_error = errors[0] if errors else Exception("Unknown replication error")
|
|
raise StorageError(
|
|
f"Document stored in primary backend but replication failed for: {backends}"
|
|
) from primary_error
|
|
|
|
return primary_id
|
|
|
|
async def store_batch(
|
|
self, documents: list[Document], *, collection_name: str | None = None
|
|
) -> list[str]:
|
|
# Store in primary backend first
|
|
primary_ids: list[str] = await self._primary.store_batch(
|
|
documents, collection_name=collection_name
|
|
)
|
|
|
|
# Replicate to secondary backends concurrently
|
|
if len(self._storages) > 1:
|
|
|
|
async def replicate_batch_to_backend(
|
|
storage: BaseStorage,
|
|
) -> tuple[BaseStorage, bool, Exception | None]:
|
|
try:
|
|
await storage.store_batch(documents, collection_name=collection_name)
|
|
return storage, True, None
|
|
except Exception as exc:
|
|
return storage, False, exc
|
|
|
|
tasks = [replicate_batch_to_backend(storage) for storage in self._storages[1:]]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
failures: list[str] = []
|
|
errors: list[Exception] = []
|
|
|
|
for result in results:
|
|
if isinstance(result, tuple):
|
|
storage, success, error = result
|
|
if not success and error is not None:
|
|
failures.append(self._format_backend_label(storage))
|
|
errors.append(error)
|
|
elif isinstance(result, Exception):
|
|
failures.append("unknown")
|
|
errors.append(result)
|
|
|
|
if failures:
|
|
backends = ", ".join(failures)
|
|
primary_error = (
|
|
errors[0] if errors else Exception("Unknown batch replication error")
|
|
)
|
|
raise StorageError(
|
|
f"Batch stored in primary backend but replication failed for: {backends}"
|
|
) from primary_error
|
|
|
|
return primary_ids
|
|
|
|
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
|
|
# Delete from primary backend first
|
|
primary_deleted: bool = await self._primary.delete(
|
|
document_id, collection_name=collection_name
|
|
)
|
|
|
|
# Delete from secondary backends concurrently
|
|
if len(self._storages) > 1:
|
|
|
|
async def delete_from_backend(
|
|
storage: BaseStorage,
|
|
) -> tuple[BaseStorage, bool, Exception | None]:
|
|
try:
|
|
await storage.delete(document_id, collection_name=collection_name)
|
|
return storage, True, None
|
|
except Exception as exc:
|
|
return storage, False, exc
|
|
|
|
tasks = [delete_from_backend(storage) for storage in self._storages[1:]]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
failures: list[str] = []
|
|
errors: list[Exception] = []
|
|
|
|
for result in results:
|
|
if isinstance(result, tuple):
|
|
storage, success, error = result
|
|
if not success and error is not None:
|
|
failures.append(self._format_backend_label(storage))
|
|
errors.append(error)
|
|
elif isinstance(result, Exception):
|
|
failures.append("unknown")
|
|
errors.append(result)
|
|
|
|
if failures:
|
|
backends = ", ".join(failures)
|
|
primary_error = errors[0] if errors else Exception("Unknown deletion error")
|
|
raise StorageError(
|
|
f"Document deleted from primary backend but failed for: {backends}"
|
|
) from primary_error
|
|
|
|
return primary_deleted
|
|
|
|
async def count(self, *, collection_name: str | None = None) -> int:
|
|
count_result: int = await self._primary.count(collection_name=collection_name)
|
|
return count_result
|
|
|
|
async def list_collections(self) -> list[str]:
|
|
list_fn = getattr(self._primary, "list_collections", None)
|
|
if list_fn is None:
|
|
return []
|
|
collections_result: list[str] = await list_fn()
|
|
return collections_result
|
|
|
|
async def search(
|
|
self,
|
|
query: str,
|
|
limit: int = 10,
|
|
threshold: float = 0.7,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> AsyncGenerator[Document, None]:
|
|
async for item in self._primary.search(
|
|
query,
|
|
limit=limit,
|
|
threshold=threshold,
|
|
collection_name=collection_name,
|
|
):
|
|
yield item
|
|
|
|
async def close(self) -> None:
|
|
for storage in self._storages:
|
|
close_fn = getattr(storage, "close", None)
|
|
if close_fn is not None:
|
|
await close_fn()
|
|
|
|
def _format_backend_label(self, storage: BaseStorage) -> str:
|
|
backend = getattr(storage.config, "backend", None)
|
|
if isinstance(backend, StorageBackend):
|
|
backend_value: str = backend.value
|
|
return backend_value
|
|
class_name: str = storage.__class__.__name__
|
|
return class_name
|
|
|
|
|
|
class StorageManager:
|
|
"""Centralized manager for all storage backend operations."""
|
|
|
|
def __init__(self, settings: Settings) -> None:
|
|
"""Initialize storage manager with application settings."""
|
|
self.settings: Settings = settings
|
|
self.backends: dict[StorageBackend, BaseStorage] = {}
|
|
self.capabilities: dict[StorageBackend, StorageCapabilities] = {}
|
|
self._initialized: bool = False
|
|
|
|
async def initialize_all_backends(self) -> dict[StorageBackend, bool]:
|
|
"""Initialize all available storage backends with timeout protection."""
|
|
results: dict[StorageBackend, bool] = {}
|
|
|
|
async def init_backend(
|
|
backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]
|
|
) -> bool:
|
|
"""Initialize a single backend with timeout."""
|
|
try:
|
|
storage = storage_class(config)
|
|
await asyncio.wait_for(storage.initialize(), timeout=30.0)
|
|
self.backends[backend_type] = storage
|
|
if backend_type == StorageBackend.WEAVIATE:
|
|
self.capabilities[backend_type] = StorageCapabilities.VECTOR_SEARCH
|
|
elif backend_type == StorageBackend.OPEN_WEBUI:
|
|
self.capabilities[backend_type] = StorageCapabilities.KNOWLEDGE_BASE
|
|
elif backend_type == StorageBackend.R2R:
|
|
self.capabilities[backend_type] = StorageCapabilities.FULL_FEATURED
|
|
return True
|
|
            except Exception:
                # Catches timeouts from asyncio.wait_for as well as backend-specific failures.
                return False
|
|
|
|
# Initialize backends concurrently with timeout protection
|
|
tasks: list[tuple[StorageBackend, Coroutine[None, None, bool]]] = []
|
|
|
|
# Try Weaviate
|
|
if self.settings.weaviate_endpoint:
|
|
config = StorageConfig(
|
|
backend=StorageBackend.WEAVIATE,
|
|
endpoint=self.settings.weaviate_endpoint,
|
|
api_key=SecretStr(self.settings.weaviate_api_key)
|
|
if self.settings.weaviate_api_key
|
|
else None,
|
|
collection_name="default",
|
|
)
|
|
tasks.append(
|
|
(
|
|
StorageBackend.WEAVIATE,
|
|
init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage),
|
|
)
|
|
)
|
|
else:
|
|
results[StorageBackend.WEAVIATE] = False
|
|
|
|
# Try OpenWebUI
|
|
if self.settings.openwebui_endpoint and self.settings.openwebui_api_key:
|
|
config = StorageConfig(
|
|
backend=StorageBackend.OPEN_WEBUI,
|
|
endpoint=self.settings.openwebui_endpoint,
|
|
api_key=SecretStr(self.settings.openwebui_api_key)
|
|
if self.settings.openwebui_api_key
|
|
else None,
|
|
collection_name="default",
|
|
)
|
|
tasks.append(
|
|
(
|
|
StorageBackend.OPEN_WEBUI,
|
|
init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage),
|
|
)
|
|
)
|
|
else:
|
|
results[StorageBackend.OPEN_WEBUI] = False
|
|
|
|
# Try R2R
|
|
if self.settings.r2r_endpoint:
|
|
config = StorageConfig(
|
|
backend=StorageBackend.R2R,
|
|
endpoint=self.settings.r2r_endpoint,
|
|
api_key=SecretStr(self.settings.r2r_api_key) if self.settings.r2r_api_key else None,
|
|
collection_name="default",
|
|
)
|
|
tasks.append((StorageBackend.R2R, init_backend(StorageBackend.R2R, config, R2RStorage)))
|
|
else:
|
|
results[StorageBackend.R2R] = False
|
|
|
|
# Execute initialization tasks concurrently
|
|
if tasks:
|
|
backend_types, task_coroutines = zip(*tasks, strict=False)
|
|
task_results: Sequence[bool | BaseException] = await asyncio.gather(
|
|
*task_coroutines, return_exceptions=True
|
|
)
|
|
|
|
for backend_type, task_result in zip(backend_types, task_results, strict=False):
|
|
results[backend_type] = task_result if isinstance(task_result, bool) else False
|
|
self._initialized = True
|
|
return results
|
|
|
|
def get_backend(self, backend_type: StorageBackend) -> BaseStorage | None:
|
|
"""Get storage backend by type."""
|
|
return self.backends.get(backend_type)
|
|
|
|
def build_multi_storage_adapter(
|
|
self, backends: Sequence[StorageBackend]
|
|
) -> MultiStorageAdapter:
|
|
storages: list[BaseStorage] = []
|
|
seen: set[StorageBackend] = set()
|
|
for backend in backends:
|
|
backend_enum = (
|
|
backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
|
|
)
|
|
if backend_enum in seen:
|
|
continue
|
|
seen.add(backend_enum)
|
|
storage = self.backends.get(backend_enum)
|
|
if storage is None:
|
|
raise ValueError(f"Storage backend {backend_enum.value} is not initialized")
|
|
storages.append(storage)
|
|
return MultiStorageAdapter(storages)
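
    # Usage sketch (assumes initialize_all_backends() has already succeeded for both
    # backends); writes go to the first backend and are mirrored to the rest:
    #
    #     adapter = manager.build_multi_storage_adapter(
    #         [StorageBackend.WEAVIATE, StorageBackend.OPEN_WEBUI]
    #     )
    #     await adapter.initialize()
    #     doc_id = await adapter.store(document, collection_name="docs")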
|
|
|
|
def get_available_backends(self) -> list[StorageBackend]:
|
|
"""Get list of successfully initialized backends."""
|
|
return list(self.backends.keys())
|
|
|
|
def has_capability(self, backend: StorageBackend, capability: StorageCapabilities) -> bool:
|
|
"""Check if backend has specific capability."""
|
|
backend_caps = self.capabilities.get(backend, StorageCapabilities.BASIC)
|
|
return capability.value <= backend_caps.value
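
    # Note: this check assumes StorageCapabilities is a numeric enum whose values grow
    # with feature richness (e.g. BASIC < KNOWLEDGE_BASE < VECTOR_SEARCH < FULL_FEATURED),
    # so a backend satisfies every capability at or below its own level.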
|
|
|
|
async def get_all_collections(self) -> list[CollectionInfo]:
|
|
"""Get collections from all available backends, merging collections with same name."""
|
|
collection_map: dict[str, CollectionInfo] = {}
|
|
|
|
for backend_type, storage in self.backends.items():
|
|
try:
|
|
backend_collections = await storage.list_collections()
|
|
for collection_name in backend_collections:
|
|
# Validate collection name
|
|
if not collection_name or not isinstance(collection_name, str):
|
|
continue
|
|
|
|
try:
|
|
count = await storage.count(collection_name=collection_name)
|
|
# Validate count is non-negative
|
|
count = max(count, 0)
|
|
                    except StorageError as e:
                        # Storage-specific errors: log and fall back to a zero count.
                        logging.warning(
                            f"Failed to get count for {collection_name} on {backend_type.value}: {e}"
                        )
                        count = 0
                    except Exception as e:
                        # Unexpected errors: log and skip this collection for this backend.
                        logging.warning(
                            f"Unexpected error counting {collection_name} on {backend_type.value}: {e}"
                        )
                        continue
|
|
|
|
size_mb = count * 0.01 # Rough estimate: 10KB per document
|
|
|
|
if collection_name in collection_map:
|
|
# Merge with existing collection
|
|
existing = collection_map[collection_name]
|
|
existing_backends = existing["backend"]
|
|
backend_value = backend_type.value
|
|
|
|
if isinstance(existing_backends, str):
|
|
existing["backend"] = [existing_backends, backend_value]
|
|
elif isinstance(existing_backends, list):
|
|
# Prevent duplicates
|
|
if backend_value not in existing_backends:
|
|
existing_backends.append(backend_value)
|
|
|
|
# Aggregate counts and sizes
|
|
existing["count"] += count
|
|
existing["size_mb"] += size_mb
|
|
else:
|
|
# Create new collection entry
|
|
collection_info: CollectionInfo = {
|
|
"name": collection_name,
|
|
"type": self._get_collection_type(collection_name, backend_type),
|
|
"count": count,
|
|
"backend": backend_type.value,
|
|
"status": "active",
|
|
"last_updated": "2024-01-01T00:00:00Z",
|
|
"size_mb": size_mb,
|
|
}
|
|
collection_map[collection_name] = collection_info
|
|
except Exception:
|
|
continue
|
|
|
|
return list(collection_map.values())
|
|
|
|
def _get_collection_type(self, collection_name: str, backend: StorageBackend) -> str:
|
|
"""Determine collection type based on name and backend."""
|
|
# Prioritize definitive backend type first
|
|
if backend == StorageBackend.R2R:
|
|
return "r2r"
|
|
elif backend == StorageBackend.WEAVIATE:
|
|
return "weaviate"
|
|
elif backend == StorageBackend.OPEN_WEBUI:
|
|
return "openwebui"
|
|
|
|
# Fallback to name-based guessing if backend is not specific
|
|
name_lower = collection_name.lower()
|
|
if "web" in name_lower or "doc" in name_lower:
|
|
return "documentation"
|
|
elif "repo" in name_lower or "code" in name_lower:
|
|
return "repository"
|
|
else:
|
|
return "general"
|
|
|
|
async def search_across_backends(
|
|
self,
|
|
query: str,
|
|
limit: int = 10,
|
|
backends: list[StorageBackend] | None = None,
|
|
) -> dict[StorageBackend, list[Document]]:
|
|
"""Search across multiple backends and return grouped results."""
|
|
if backends is None:
|
|
backends = self.get_available_backends()
|
|
|
|
results: dict[StorageBackend, list[Document]] = {}
|
|
|
|
async def search_backend(backend_type: StorageBackend) -> None:
|
|
storage = self.backends.get(backend_type)
|
|
if storage:
|
|
try:
|
|
documents: list[Document] = []
|
|
async for doc in storage.search(query, limit=limit):
|
|
documents.append(doc)
|
|
results[backend_type] = documents
|
|
except Exception:
|
|
results[backend_type] = []
|
|
|
|
# Run searches in parallel
|
|
tasks = [search_backend(backend) for backend in backends]
|
|
await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
return results
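
    # Usage sketch: fan a query out to every initialized backend and inspect the grouped
    # results (names below are illustrative):
    #
    #     grouped = await manager.search_across_backends("vector databases", limit=5)
    #     for backend, docs in grouped.items():
    #         print(backend.value, len(docs))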
|
|
|
|
def get_r2r_storage(self) -> R2RStorage | None:
|
|
"""Get R2R storage instance if available."""
|
|
storage = self.backends.get(StorageBackend.R2R)
|
|
return storage if isinstance(storage, R2RStorage) else None
|
|
|
|
async def get_backend_status(
|
|
self,
|
|
) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
|
|
"""Get comprehensive status for all backends."""
|
|
status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}
|
|
|
|
for backend_type, storage in self.backends.items():
|
|
try:
|
|
collections = await storage.list_collections()
|
|
total_docs = 0
|
|
for collection in collections:
|
|
total_docs += await storage.count(collection_name=collection)
|
|
|
|
backend_status: dict[str, str | int | bool | StorageCapabilities] = {
|
|
"available": True,
|
|
"collections": len(collections),
|
|
"total_documents": total_docs,
|
|
"capabilities": self.capabilities.get(backend_type, StorageCapabilities.BASIC),
|
|
"endpoint": getattr(storage.config, "endpoint", "unknown"),
|
|
}
|
|
status[backend_type] = backend_status
|
|
except Exception as e:
|
|
status[backend_type] = {
|
|
"available": False,
|
|
"error": str(e),
|
|
"capabilities": StorageCapabilities.NONE,
|
|
}
|
|
|
|
return status
|
|
|
|
async def close_all(self) -> None:
|
|
"""Close all storage connections."""
|
|
for storage in self.backends.values():
|
|
try:
|
|
await storage.close()
|
|
except Exception:
|
|
pass
|
|
|
|
self.backends.clear()
|
|
self.capabilities.clear()
|
|
self._initialized = False
|
|
|
|
@property
|
|
def is_initialized(self) -> bool:
|
|
"""Check if storage manager is initialized."""
|
|
return self._initialized
|
|
|
|
def supports_advanced_features(self, backend: StorageBackend) -> bool:
|
|
"""Check if backend supports advanced features like chunks and entities."""
|
|
return self.has_capability(backend, StorageCapabilities.FULL_FEATURED)
|
|
</file>
|
|
|
|
<file path="ingest_pipeline/cli/tui/widgets/firecrawl_config.py">
|
|
"""Firecrawl configuration widgets for advanced scraping options."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import cast
|
|
|
|
from textual.app import ComposeResult
|
|
from textual.containers import Container, Horizontal
|
|
from textual.validation import Integer
|
|
from textual.widget import Widget
|
|
from textual.widgets import Button, Checkbox, Input, Label, Switch, TextArea
|
|
from typing_extensions import override
|
|
|
|
from ..models import FirecrawlOptions
|
|
|
|
|
|
class ScrapeOptionsForm(Widget):
|
|
"""Form for configuring Firecrawl scraping options."""
|
|
|
|
DEFAULT_CSS = """
|
|
ScrapeOptionsForm {
|
|
border: solid $border;
|
|
background: $surface;
|
|
padding: 1;
|
|
height: auto;
|
|
}
|
|
|
|
ScrapeOptionsForm .form-section {
|
|
margin-bottom: 2;
|
|
padding: 1;
|
|
border: solid $border-lighten-1;
|
|
background: $surface-lighten-1;
|
|
}
|
|
|
|
ScrapeOptionsForm .form-row {
|
|
layout: horizontal;
|
|
align-items: center;
|
|
height: auto;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
ScrapeOptionsForm .form-label {
|
|
width: 30%;
|
|
min-width: 15;
|
|
text-align: right;
|
|
padding-right: 2;
|
|
}
|
|
|
|
ScrapeOptionsForm .form-input {
|
|
width: 70%;
|
|
}
|
|
|
|
ScrapeOptionsForm .checkbox-row {
|
|
layout: horizontal;
|
|
align-items: center;
|
|
height: 3;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
ScrapeOptionsForm .checkbox-label {
|
|
margin-left: 2;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize scrape options form."""
|
|
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose scrape options form."""
|
|
yield Label("🔧 Scraping Configuration", classes="form-title")
|
|
|
|
# Output formats section
|
|
yield Container(
|
|
Label("Output Formats", classes="section-title"),
|
|
Horizontal(
|
|
Checkbox("Markdown", id="format_markdown", value=True, classes="checkbox"),
|
|
Label("Markdown", classes="checkbox-label"),
|
|
classes="checkbox-row",
|
|
),
|
|
Horizontal(
|
|
Checkbox("HTML", id="format_html", value=False, classes="checkbox"),
|
|
Label("HTML", classes="checkbox-label"),
|
|
classes="checkbox-row",
|
|
),
|
|
Horizontal(
|
|
Checkbox("Screenshot", id="format_screenshot", value=False, classes="checkbox"),
|
|
Label("Screenshot", classes="checkbox-label"),
|
|
classes="checkbox-row",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
# Content filtering section
|
|
yield Container(
|
|
Label("Content Filtering", classes="section-title"),
|
|
Horizontal(
|
|
Label("Only Main Content:", classes="form-label"),
|
|
Switch(id="only_main_content", value=True, classes="form-input"),
|
|
classes="form-row",
|
|
),
|
|
Horizontal(
|
|
Label("Include Tags:", classes="form-label"),
|
|
Input(
|
|
placeholder="p, div, article (comma-separated)",
|
|
id="include_tags",
|
|
classes="form-input",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
Horizontal(
|
|
Label("Exclude Tags:", classes="form-label"),
|
|
Input(
|
|
placeholder="nav, footer, script (comma-separated)",
|
|
id="exclude_tags",
|
|
classes="form-input",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
# Performance settings section
|
|
yield Container(
|
|
Label("Performance Settings", classes="section-title"),
|
|
Horizontal(
|
|
Label("Wait Time (ms):", classes="form-label"),
|
|
Input(
|
|
placeholder="0",
|
|
id="wait_for",
|
|
validators=[Integer(minimum=0, maximum=30000)],
|
|
classes="form-input",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
def get_scrape_options(self) -> dict[str, object]:
|
|
"""Get scraping options from form."""
|
|
# Collect formats
|
|
formats = []
|
|
if self.query_one("#format_markdown", Checkbox).value:
|
|
formats.append("markdown")
|
|
if self.query_one("#format_html", Checkbox).value:
|
|
formats.append("html")
|
|
if self.query_one("#format_screenshot", Checkbox).value:
|
|
formats.append("screenshot")
|
|
options: dict[str, object] = {
|
|
"formats": formats,
|
|
"only_main_content": self.query_one("#only_main_content", Switch).value,
|
|
}
|
|
include_tags_input = self.query_one("#include_tags", Input).value
|
|
if include_tags_input.strip():
|
|
options["include_tags"] = [tag.strip() for tag in include_tags_input.split(",")]
|
|
|
|
exclude_tags_input = self.query_one("#exclude_tags", Input).value
|
|
if exclude_tags_input.strip():
|
|
options["exclude_tags"] = [tag.strip() for tag in exclude_tags_input.split(",")]
|
|
|
|
# Performance
|
|
wait_for_input = self.query_one("#wait_for", Input).value
|
|
if wait_for_input.strip():
|
|
try:
|
|
options["wait_for"] = int(wait_for_input)
|
|
except ValueError:
|
|
pass
|
|
|
|
return options
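
    # Example of the options dict this form can produce (illustrative values):
    #
    #     {
    #         "formats": ["markdown", "html"],
    #         "only_main_content": True,
    #         "include_tags": ["article", "main"],
    #         "exclude_tags": ["nav", "footer"],
    #         "wait_for": 1500,
    #     }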
|
|
|
|
def set_scrape_options(self, options: dict[str, object]) -> None:
|
|
"""Set form values from options."""
|
|
# Set formats
|
|
formats = options.get("formats", ["markdown"])
|
|
formats_list = formats if isinstance(formats, list) else []
|
|
self.query_one("#format_markdown", Checkbox).value = "markdown" in formats_list
|
|
self.query_one("#format_html", Checkbox).value = "html" in formats_list
|
|
self.query_one("#format_screenshot", Checkbox).value = "screenshot" in formats_list
|
|
|
|
# Set content filtering
|
|
main_content_val = options.get("only_main_content", True)
|
|
self.query_one("#only_main_content", Switch).value = bool(main_content_val)
|
|
|
|
if include_tags := options.get("include_tags", []):
|
|
include_list = include_tags if isinstance(include_tags, list) else []
|
|
self.query_one("#include_tags", Input).value = ", ".join(
|
|
str(tag) for tag in include_list
|
|
)
|
|
|
|
if exclude_tags := options.get("exclude_tags", []):
|
|
exclude_list = exclude_tags if isinstance(exclude_tags, list) else []
|
|
self.query_one("#exclude_tags", Input).value = ", ".join(
|
|
str(tag) for tag in exclude_list
|
|
)
|
|
|
|
# Set performance
|
|
wait_for = options.get("wait_for")
|
|
if wait_for is not None:
|
|
self.query_one("#wait_for", Input).value = str(wait_for)
|
|
|
|
|
|
class MapOptionsForm(Widget):
|
|
"""Form for configuring site mapping options."""
|
|
|
|
DEFAULT_CSS = """
|
|
MapOptionsForm {
|
|
border: solid $border;
|
|
background: $surface;
|
|
padding: 1;
|
|
height: auto;
|
|
}
|
|
|
|
MapOptionsForm .form-section {
|
|
margin-bottom: 2;
|
|
padding: 1;
|
|
border: solid $border-lighten-1;
|
|
background: $surface-lighten-1;
|
|
}
|
|
|
|
MapOptionsForm .form-row {
|
|
layout: horizontal;
|
|
align-items: center;
|
|
height: auto;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
MapOptionsForm .form-label {
|
|
width: 30%;
|
|
min-width: 15;
|
|
text-align: right;
|
|
padding-right: 2;
|
|
}
|
|
|
|
MapOptionsForm .form-input {
|
|
width: 70%;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize map options form."""
|
|
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose map options form."""
|
|
yield Label("🗺️ Site Mapping Configuration", classes="form-title")
|
|
|
|
# Discovery settings section
|
|
yield Container(
|
|
Label("Discovery Settings", classes="section-title"),
|
|
Horizontal(
|
|
Label("Search Pattern:", classes="form-label"),
|
|
Input(
|
|
placeholder="docs, api, guide (optional)",
|
|
id="search_pattern",
|
|
classes="form-input",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
Horizontal(
|
|
Label("Include Subdomains:", classes="form-label"),
|
|
Switch(id="include_subdomains", value=False, classes="form-input"),
|
|
classes="form-row",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
# Limits section
|
|
yield Container(
|
|
Label("Crawling Limits", classes="section-title"),
|
|
Horizontal(
|
|
Label("Max Pages:", classes="form-label"),
|
|
Input(
|
|
placeholder="100",
|
|
id="max_pages",
|
|
validators=[Integer(minimum=1, maximum=1000)],
|
|
classes="form-input",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
Horizontal(
|
|
Label("Max Depth:", classes="form-label"),
|
|
Input(
|
|
placeholder="5",
|
|
id="max_depth",
|
|
validators=[Integer(minimum=1, maximum=20)],
|
|
classes="form-input",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
def get_map_options(self) -> dict[str, object]:
|
|
"""Get mapping options from form."""
|
|
options: dict[str, object] = {}
|
|
|
|
# Discovery settings
|
|
search_pattern = self.query_one("#search_pattern", Input).value
|
|
if search_pattern.strip():
|
|
options["search"] = search_pattern.strip()
|
|
|
|
options["include_subdomains"] = self.query_one("#include_subdomains", Switch).value
|
|
|
|
# Limits
|
|
max_pages_input = self.query_one("#max_pages", Input).value
|
|
if max_pages_input.strip():
|
|
try:
|
|
options["limit"] = int(max_pages_input)
|
|
except ValueError:
|
|
pass
|
|
|
|
max_depth_input = self.query_one("#max_depth", Input).value
|
|
if max_depth_input.strip():
|
|
try:
|
|
options["max_depth"] = int(max_depth_input)
|
|
except ValueError:
|
|
pass
|
|
|
|
return options
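
    # Example result (illustrative): {"search": "docs", "include_subdomains": False,
    # "limit": 100, "max_depth": 5}; limits are omitted when their inputs are left blank.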
|
|
|
|
def set_map_options(self, options: dict[str, object]) -> None:
|
|
"""Set form values from options."""
|
|
if search := options.get("search"):
|
|
self.query_one("#search_pattern", Input).value = str(search)
|
|
|
|
subdomains_val = options.get("include_subdomains", False)
|
|
self.query_one("#include_subdomains", Switch).value = bool(subdomains_val)
|
|
|
|
# Set limits
|
|
limit = options.get("limit")
|
|
if limit is not None:
|
|
self.query_one("#max_pages", Input).value = str(limit)
|
|
|
|
max_depth = options.get("max_depth")
|
|
if max_depth is not None:
|
|
self.query_one("#max_depth", Input).value = str(max_depth)
|
|
|
|
|
|
class ExtractOptionsForm(Widget):
|
|
"""Form for configuring data extraction options."""
|
|
|
|
DEFAULT_CSS = """
|
|
ExtractOptionsForm {
|
|
border: solid $border;
|
|
background: $surface;
|
|
padding: 1;
|
|
height: auto;
|
|
}
|
|
|
|
ExtractOptionsForm .form-section {
|
|
margin-bottom: 2;
|
|
padding: 1;
|
|
border: solid $border-lighten-1;
|
|
background: $surface-lighten-1;
|
|
}
|
|
|
|
ExtractOptionsForm .form-row {
|
|
layout: horizontal;
|
|
align-items: start;
|
|
height: auto;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
ExtractOptionsForm .form-label {
|
|
width: 30%;
|
|
min-width: 15;
|
|
text-align: right;
|
|
padding-right: 2;
|
|
padding-top: 1;
|
|
}
|
|
|
|
ExtractOptionsForm .form-input {
|
|
width: 70%;
|
|
}
|
|
|
|
ExtractOptionsForm .text-area {
|
|
height: 6;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize extract options form."""
|
|
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose extract options form."""
|
|
yield Label("🎯 Data Extraction Configuration", classes="form-title")
|
|
|
|
# Extraction prompt section
|
|
yield Container(
|
|
Label("AI-Powered Extraction", classes="section-title"),
|
|
Horizontal(
|
|
Label("Custom Prompt:", classes="form-label"),
|
|
TextArea(
|
|
placeholder="Extract product names, prices, and descriptions...",
|
|
id="extract_prompt",
|
|
classes="form-input text-area",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
# Schema definition section
|
|
yield Container(
|
|
Label("Structured Schema (JSON)", classes="section-title"),
|
|
Horizontal(
|
|
Label("Schema Definition:", classes="form-label"),
|
|
TextArea(
|
|
placeholder='{"product_name": "string", "price": "number", "description": "string"}',
|
|
id="extract_schema",
|
|
classes="form-input text-area",
|
|
),
|
|
classes="form-row",
|
|
),
|
|
Container(
|
|
Label("💡 Tip: Define the structure of data you want to extract"),
|
|
classes="help-text",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
# Schema presets
|
|
yield Container(
|
|
Label("Quick Presets", classes="section-title"),
|
|
Horizontal(
|
|
Button("📄 Article", id="preset_article", variant="default"),
|
|
Button("🛍️ Product", id="preset_product", variant="default"),
|
|
Button("👤 Contact", id="preset_contact", variant="default"),
|
|
Button("📊 Data", id="preset_data", variant="default"),
|
|
classes="preset-buttons",
|
|
),
|
|
classes="form-section",
|
|
)
|
|
|
|
def get_extract_options(self) -> dict[str, object]:
|
|
"""Get extraction options from form."""
|
|
options: dict[str, object] = {}
|
|
|
|
# Extract prompt
|
|
prompt = self.query_one("#extract_prompt", TextArea).text
|
|
if prompt.strip():
|
|
options["extract_prompt"] = prompt.strip()
|
|
|
|
# Extract schema
|
|
schema_text = self.query_one("#extract_schema", TextArea).text
|
|
if schema_text.strip():
|
|
try:
|
|
schema = json.loads(schema_text)
|
|
options["extract_schema"] = schema
|
|
except json.JSONDecodeError:
|
|
# Invalid JSON, skip schema
|
|
pass
|
|
|
|
return options
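
    # Example (illustrative): with a prompt and a valid JSON schema entered, this returns
    # {"extract_prompt": "...", "extract_schema": {...}}; a schema that fails to parse
    # is silently dropped rather than raising.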
|
|
|
|
def set_extract_options(self, options: dict[str, object]) -> None:
|
|
"""Set form values from options."""
|
|
if prompt := options.get("extract_prompt"):
|
|
self.query_one("#extract_prompt", TextArea).text = str(prompt)
|
|
|
|
        if schema := options.get("extract_schema"):
            self.query_one("#extract_schema", TextArea).text = json.dumps(schema, indent=2)
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle preset button presses."""
|
|
schema_widget = self.query_one("#extract_schema", TextArea)
|
|
prompt_widget = self.query_one("#extract_prompt", TextArea)
|
|
|
|
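        # Each preset fills both the JSON schema and a matching extraction prompt.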
if event.button.id == "preset_article":
|
|
schema_widget.text = """{
|
|
"title": "string",
|
|
"author": "string",
|
|
"date": "string",
|
|
"content": "string",
|
|
"tags": ["string"]
|
|
}"""
|
|
prompt_widget.text = (
|
|
"Extract article title, author, publication date, main content, and associated tags"
|
|
)
|
|
|
|
elif event.button.id == "preset_product":
|
|
schema_widget.text = """{
|
|
"name": "string",
|
|
"price": "number",
|
|
"description": "string",
|
|
"category": "string",
|
|
"availability": "string"
|
|
}"""
|
|
prompt_widget.text = (
|
|
"Extract product name, price, description, category, and availability status"
|
|
)
|
|
|
|
elif event.button.id == "preset_contact":
|
|
schema_widget.text = """{
|
|
"name": "string",
|
|
"email": "string",
|
|
"phone": "string",
|
|
"company": "string",
|
|
"position": "string"
|
|
}"""
|
|
prompt_widget.text = (
|
|
"Extract contact information including name, email, phone, company, and position"
|
|
)
|
|
|
|
elif event.button.id == "preset_data":
|
|
schema_widget.text = """{
|
|
"metrics": [{"name": "string", "value": "number", "unit": "string"}],
|
|
"tables": [{"headers": ["string"], "rows": [["string"]]}]
|
|
}"""
|
|
prompt_widget.text = "Extract numerical data, metrics, and tabular information"
|
|
|
|
|
|
class FirecrawlConfigWidget(Widget):
|
|
"""Complete Firecrawl configuration widget with tabbed interface."""
|
|
|
|
DEFAULT_CSS = """
|
|
FirecrawlConfigWidget {
|
|
border: solid $border;
|
|
background: $surface;
|
|
height: 100%;
|
|
padding: 1;
|
|
}
|
|
|
|
FirecrawlConfigWidget .config-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
margin: -1 -1 1 -1;
|
|
}
|
|
|
|
FirecrawlConfigWidget .tab-buttons {
|
|
dock: top;
|
|
height: 3;
|
|
layout: horizontal;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
FirecrawlConfigWidget .tab-button {
|
|
width: 1fr;
|
|
margin-right: 1;
|
|
}
|
|
|
|
FirecrawlConfigWidget .tab-content {
|
|
height: 1fr;
|
|
overflow: auto;
|
|
}
|
|
|
|
FirecrawlConfigWidget .actions {
|
|
dock: bottom;
|
|
height: 3;
|
|
layout: horizontal;
|
|
align: center;
|
|
margin-top: 1;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize Firecrawl config widget."""
|
|
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
|
self.current_tab = "scrape"
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose config widget layout."""
|
|
yield Container(
|
|
Label("🔥 Firecrawl Configuration", classes="config-header"),
|
|
Horizontal(
|
|
Button("🔧 Scraping", id="tab_scrape", variant="primary", classes="tab-button"),
|
|
Button("🗺️ Mapping", id="tab_map", variant="default", classes="tab-button"),
|
|
Button("🎯 Extraction", id="tab_extract", variant="default", classes="tab-button"),
|
|
classes="tab-buttons",
|
|
),
|
|
Container(
|
|
ScrapeOptionsForm(id="scrape_form"),
|
|
classes="tab-content",
|
|
),
|
|
Horizontal(
|
|
Button("📋 Load Preset", id="load_preset", variant="default"),
|
|
Button("💾 Save Preset", id="save_preset", variant="default"),
|
|
Button("🔄 Reset", id="reset_config", variant="default"),
|
|
classes="actions",
|
|
),
|
|
)
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize widget."""
|
|
self.show_tab("scrape")
|
|
|
|
def show_tab(self, tab_name: str) -> None:
|
|
"""Show specific configuration tab."""
|
|
self.current_tab = tab_name
|
|
|
|
# Update button states
|
|
for tab in ["scrape", "map", "extract"]:
|
|
button = self.query_one(f"#tab_{tab}", Button)
|
|
button.variant = "primary" if tab == tab_name else "default"
|
|
# Update tab content
|
|
content_container = self.query_one(".tab-content", Container)
|
|
content_container.remove_children()
|
|
|
|
if tab_name == "extract":
|
|
content_container.mount(ExtractOptionsForm(id="extract_form"))
|
|
elif tab_name == "map":
|
|
content_container.mount(MapOptionsForm(id="map_form"))
|
|
elif tab_name == "scrape":
|
|
content_container.mount(ScrapeOptionsForm(id="scrape_form"))
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id and event.button.id.startswith("tab_"):
|
|
tab_name = event.button.id[4:] # Remove "tab_" prefix
|
|
self.show_tab(tab_name)
|
|
|
|
def get_all_options(self) -> FirecrawlOptions:
|
|
"""Get all configuration options."""
|
|
options: FirecrawlOptions = {}
|
|
|
|
# Try to get options from currently mounted form
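        # query_one raises if a tab's form is not currently mounted, so the broad
        # except blocks below keep option collection best-effort.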
|
|
if self.current_tab == "scrape":
|
|
try:
|
|
form = self.query_one("#scrape_form", ScrapeOptionsForm)
|
|
scrape_opts = form.get_scrape_options()
|
|
options.update(cast(FirecrawlOptions, scrape_opts))
|
|
except Exception:
|
|
pass
|
|
elif self.current_tab == "map":
|
|
try:
|
|
map_form = self.query_one("#map_form", MapOptionsForm)
|
|
map_opts = map_form.get_map_options()
|
|
options.update(cast(FirecrawlOptions, map_opts))
|
|
except Exception:
|
|
pass
|
|
elif self.current_tab == "extract":
|
|
try:
|
|
extract_form = self.query_one("#extract_form", ExtractOptionsForm)
|
|
extract_opts = extract_form.get_extract_options()
|
|
options.update(cast(FirecrawlOptions, extract_opts))
|
|
except Exception:
|
|
pass
|
|
|
|
return options
</file>

<file path="ingest_pipeline/cli/tui/widgets/r2r_widgets.py">
"""R2R-specific widgets for chunk viewing and entity visualization."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from textual import work
|
|
from textual.app import ComposeResult
|
|
from textual.containers import Container, Horizontal, Vertical, VerticalScroll
|
|
from textual.widget import Widget
|
|
from textual.widgets import Button, DataTable, Label, Markdown, ProgressBar, Static, Tree
|
|
from typing_extensions import override
|
|
|
|
from ....storage.r2r.storage import R2RStorage
|
|
from ..models import ChunkInfo, EntityInfo
|
|
|
|
|
|
class ChunkViewer(Widget):
|
|
"""Widget for viewing document chunks with navigation."""
|
|
|
|
DEFAULT_CSS = """
|
|
ChunkViewer {
|
|
border: solid $border;
|
|
background: $surface;
|
|
height: 100%;
|
|
}
|
|
|
|
ChunkViewer .chunk-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
}
|
|
|
|
ChunkViewer .chunk-navigation {
|
|
dock: top;
|
|
height: 3;
|
|
background: $surface-lighten-1;
|
|
padding: 1;
|
|
}
|
|
|
|
ChunkViewer .chunk-content {
|
|
height: 1fr;
|
|
padding: 1;
|
|
overflow: auto;
|
|
}
|
|
|
|
ChunkViewer .chunk-footer {
|
|
dock: bottom;
|
|
height: 3;
|
|
background: $surface-darken-1;
|
|
padding: 1;
|
|
}
|
|
"""
|
|
|
|
def __init__(self, r2r_storage: R2RStorage, document_id: str, **kwargs: Any) -> None:
|
|
"""Initialize chunk viewer."""
|
|
super().__init__(**kwargs)
|
|
self.r2r_storage: R2RStorage = r2r_storage
|
|
self.document_id: str = document_id
|
|
self.chunks: list[ChunkInfo] = []
|
|
self.current_chunk_index: int = 0
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose chunk viewer layout."""
|
|
yield Container(
|
|
Static("📄 Document Chunks", classes="chunk-header"),
|
|
Horizontal(
|
|
Button("◀ Previous", id="prev_chunk", variant="default"),
|
|
Static("Chunk 1 of 1", id="chunk_info"),
|
|
Button("Next ▶", id="next_chunk", variant="default"),
|
|
classes="chunk-navigation",
|
|
),
|
|
VerticalScroll(
|
|
Markdown("", id="chunk_content"),
|
|
classes="chunk-content",
|
|
),
|
|
Container(
|
|
Static("Loading chunks...", id="chunk_status"),
|
|
classes="chunk-footer",
|
|
),
|
|
)
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize chunk viewer."""
|
|
self.load_chunks()
|
|
|
|
@work(exclusive=True)
|
|
async def load_chunks(self) -> None:
|
|
"""Load document chunks."""
|
|
try:
|
|
chunks_data = await self.r2r_storage.get_document_chunks(self.document_id)
|
|
self.chunks = []
|
|
|
|
            for chunk_data in chunks_data:
                start_index_raw = chunk_data.get("start_index", 0)
                end_index_raw = chunk_data.get("end_index", 0)
                metadata_raw = chunk_data.get("metadata")
                chunk_info: ChunkInfo = {
                    "id": str(chunk_data.get("id", "")),
                    "document_id": self.document_id,
                    "content": str(chunk_data.get("text", "")),
                    "start_index": int(start_index_raw) if isinstance(start_index_raw, (int, str)) else 0,
                    "end_index": int(end_index_raw) if isinstance(end_index_raw, (int, str)) else 0,
                    "metadata": dict(metadata_raw) if isinstance(metadata_raw, dict) else {},
                }
                self.chunks.append(chunk_info)
|
|
|
|
if self.chunks:
|
|
self.current_chunk_index = 0
|
|
self.update_chunk_display()
|
|
else:
|
|
self.query_one("#chunk_status", Static).update("No chunks found")
|
|
|
|
except Exception as e:
|
|
self.query_one("#chunk_status", Static).update(f"Error loading chunks: {e}")
|
|
|
|
def update_chunk_display(self) -> None:
|
|
"""Update chunk display with current chunk."""
|
|
if not self.chunks:
|
|
return
|
|
|
|
chunk = self.chunks[self.current_chunk_index]
|
|
|
|
# Update content
|
|
content_widget = self.query_one("#chunk_content", Markdown)
|
|
content_widget.update(chunk["content"])
|
|
|
|
# Update navigation info
|
|
chunk_info = self.query_one("#chunk_info", Static)
|
|
chunk_info.update(f"Chunk {self.current_chunk_index + 1} of {len(self.chunks)}")
|
|
|
|
# Update status
|
|
status_widget = self.query_one("#chunk_status", Static)
|
|
status_widget.update(
|
|
f"Chunk {chunk['id']} | "
|
|
f"Range: {chunk['start_index']}-{chunk['end_index']} | "
|
|
f"Length: {len(chunk['content'])} chars"
|
|
)
|
|
|
|
# Update button states
|
|
prev_btn = self.query_one("#prev_chunk", Button)
|
|
next_btn = self.query_one("#next_chunk", Button)
|
|
prev_btn.disabled = self.current_chunk_index == 0
|
|
next_btn.disabled = self.current_chunk_index >= len(self.chunks) - 1
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id == "prev_chunk" and self.current_chunk_index > 0:
|
|
self.current_chunk_index -= 1
|
|
self.update_chunk_display()
|
|
elif event.button.id == "next_chunk" and self.current_chunk_index < len(self.chunks) - 1:
|
|
self.current_chunk_index += 1
|
|
self.update_chunk_display()
|
|
|
|
|
|
class EntityGraph(Widget):
|
|
"""Widget for visualizing extracted entities and relationships."""
|
|
|
|
DEFAULT_CSS = """
|
|
EntityGraph {
|
|
border: solid $border;
|
|
background: $surface;
|
|
height: 100%;
|
|
}
|
|
|
|
EntityGraph .entity-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
}
|
|
|
|
EntityGraph .entity-tree {
|
|
height: 1fr;
|
|
overflow: auto;
|
|
}
|
|
|
|
EntityGraph .entity-details {
|
|
dock: bottom;
|
|
height: 8;
|
|
background: $surface-lighten-1;
|
|
padding: 1;
|
|
border-top: solid $border;
|
|
}
|
|
"""
|
|
|
|
def __init__(self, r2r_storage: R2RStorage, document_id: str, **kwargs: Any) -> None:
|
|
"""Initialize entity graph."""
|
|
super().__init__(**kwargs)
|
|
self.r2r_storage: R2RStorage = r2r_storage
|
|
self.document_id: str = document_id
|
|
self.entities: list[EntityInfo] = []
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose entity graph layout."""
|
|
yield Container(
|
|
Static("🕸️ Entity Graph", classes="entity-header"),
|
|
Tree("Entities", id="entity_tree", classes="entity-tree"),
|
|
VerticalScroll(
|
|
Label("Entity Details"),
|
|
Static("Select an entity to view details", id="entity_details"),
|
|
classes="entity-details",
|
|
),
|
|
)
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize entity graph."""
|
|
self.load_entities()
|
|
|
|
@work(exclusive=True)
|
|
async def load_entities(self) -> None:
|
|
"""Load entities from document."""
|
|
try:
|
|
entities_data = await self.r2r_storage.extract_entities(self.document_id)
|
|
self.entities = []
|
|
|
|
# Parse entities from R2R response
|
|
entities_list = entities_data.get("entities", [])
|
|
if not isinstance(entities_list, list):
|
|
entities_list = []
|
|
for entity_data in entities_list:
|
|
entity_info: EntityInfo = {
|
|
"id": str(entity_data.get("id", "")),
|
|
"name": str(entity_data.get("name", "")),
|
|
"type": str(entity_data.get("type", "unknown")),
|
|
"confidence": float(entity_data.get("confidence", 0.0)),
|
|
"metadata": dict(entity_data.get("metadata", {})),
|
|
}
|
|
self.entities.append(entity_info)
|
|
|
|
self.populate_entity_tree()
|
|
|
|
except Exception as e:
|
|
details_widget = self.query_one("#entity_details", Static)
|
|
details_widget.update(f"Error loading entities: {e}")
|
|
|
|
def populate_entity_tree(self) -> None:
|
|
"""Populate the entity tree."""
|
|
tree = self.query_one("#entity_tree", Tree)
|
|
tree.clear()
|
|
|
|
if not self.entities:
|
|
tree.root.add_leaf("No entities found")
|
|
return
|
|
|
|
# Group entities by type
|
|
entities_by_type: dict[str, list[EntityInfo]] = {}
|
|
for entity in self.entities:
|
|
entity_type = entity["type"]
|
|
if entity_type not in entities_by_type:
|
|
entities_by_type[entity_type] = []
|
|
entities_by_type[entity_type].append(entity)
|
|
|
|
# Add entities to tree grouped by type
|
|
for entity_type, type_entities in entities_by_type.items():
|
|
type_node = tree.root.add(f"{entity_type.title()} ({len(type_entities)})")
|
|
for entity in type_entities:
|
|
confidence_pct = int(entity["confidence"] * 100)
|
|
entity_node = type_node.add_leaf(f"{entity['name']} ({confidence_pct}%)")
|
|
entity_node.data = entity
|
|
|
|
tree.root.expand()
|
|
|
|
def on_tree_node_selected(self, event: Tree.NodeSelected[EntityInfo]) -> None:
|
|
"""Handle entity selection."""
|
|
if hasattr(event.node, "data") and event.node.data:
|
|
entity = event.node.data
|
|
self.show_entity_details(entity)
|
|
|
|
def show_entity_details(self, entity: EntityInfo) -> None:
|
|
"""Show detailed information about an entity."""
|
|
details_widget = self.query_one("#entity_details", Static)
|
|
|
|
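        # Note: Static renders Rich console markup rather than Markdown, so the **bold**
        # markers below are shown literally; a Markdown widget or [bold] tags would style them.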
details_text = f"""**Entity:** {entity["name"]}
|
|
**Type:** {entity["type"]}
|
|
**Confidence:** {entity["confidence"]:.2%}
|
|
**ID:** {entity["id"]}
|
|
|
|
**Metadata:**
|
|
"""
|
|
for key, value in entity["metadata"].items():
|
|
details_text += f"- **{key}:** {value}\n"
|
|
|
|
details_widget.update(details_text)
|
|
|
|
|
|
class CollectionStats(Widget):
|
|
"""Widget for showing R2R-specific collection statistics."""
|
|
|
|
DEFAULT_CSS = """
|
|
CollectionStats {
|
|
border: solid $border;
|
|
background: $surface;
|
|
height: 100%;
|
|
padding: 1;
|
|
}
|
|
|
|
CollectionStats .stats-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
margin: -1 -1 1 -1;
|
|
}
|
|
|
|
CollectionStats .stats-grid {
|
|
layout: grid;
|
|
grid-size: 2;
|
|
grid-columns: 1fr 1fr;
|
|
grid-gutter: 1;
|
|
height: auto;
|
|
}
|
|
|
|
CollectionStats .stat-card {
|
|
background: $surface-lighten-1;
|
|
border: solid $border;
|
|
padding: 1;
|
|
height: auto;
|
|
}
|
|
|
|
CollectionStats .stat-value {
|
|
color: $primary;
|
|
text-style: bold;
|
|
text-align: center;
|
|
}
|
|
|
|
CollectionStats .stat-label {
|
|
color: $text-muted;
|
|
text-align: center;
|
|
margin-top: 1;
|
|
}
|
|
|
|
CollectionStats .progress-section {
|
|
margin-top: 2;
|
|
}
|
|
"""
|
|
|
|
def __init__(self, r2r_storage: R2RStorage, collection_name: str, **kwargs: Any) -> None:
|
|
"""Initialize collection stats."""
|
|
super().__init__(**kwargs)
|
|
self.r2r_storage: R2RStorage = r2r_storage
|
|
self.collection_name: str = collection_name
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose stats layout."""
|
|
yield Container(
|
|
Static(f"📊 {self.collection_name} Statistics", classes="stats-header"),
|
|
Container(
|
|
Container(
|
|
Static("0", id="document_count", classes="stat-value"),
|
|
Static("Documents", classes="stat-label"),
|
|
classes="stat-card",
|
|
),
|
|
Container(
|
|
Static("0", id="chunk_count", classes="stat-value"),
|
|
Static("Chunks", classes="stat-label"),
|
|
classes="stat-card",
|
|
),
|
|
Container(
|
|
Static("0", id="entity_count", classes="stat-value"),
|
|
Static("Entities", classes="stat-label"),
|
|
classes="stat-card",
|
|
),
|
|
Container(
|
|
Static("0 MB", id="storage_size", classes="stat-value"),
|
|
Static("Storage Used", classes="stat-label"),
|
|
classes="stat-card",
|
|
),
|
|
classes="stats-grid",
|
|
),
|
|
Container(
|
|
Label("Processing Progress"),
|
|
ProgressBar(id="processing_progress", total=100, show_eta=False),
|
|
Static("Idle", id="processing_status"),
|
|
classes="progress-section",
|
|
),
|
|
)
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize stats display."""
|
|
self.refresh_stats()
|
|
|
|
@work(exclusive=True)
|
|
async def refresh_stats(self) -> None:
|
|
"""Refresh collection statistics."""
|
|
try:
|
|
# Get basic document count
|
|
doc_count = await self.r2r_storage.count(collection_name=self.collection_name)
|
|
self.query_one("#document_count", Static).update(str(doc_count))
|
|
|
|
# Estimate other stats (these would need real implementation)
|
|
estimated_chunks = doc_count * 5 # Rough estimate
|
|
estimated_entities = doc_count * 10 # Rough estimate
|
|
estimated_size_mb = doc_count * 0.05 # Rough estimate
|
|
|
|
self.query_one("#chunk_count", Static).update(str(estimated_chunks))
|
|
self.query_one("#entity_count", Static).update(str(estimated_entities))
|
|
self.query_one("#storage_size", Static).update(f"{estimated_size_mb:.1f} MB")
|
|
|
|
# Update progress (would be real-time in actual implementation)
|
|
progress_bar = self.query_one("#processing_progress", ProgressBar)
|
|
progress_bar.progress = 100 # Assume complete for now
|
|
|
|
status_widget = self.query_one("#processing_status", Static)
|
|
status_widget.update("All documents processed")
|
|
|
|
except Exception as e:
|
|
self.query_one("#processing_status", Static).update(f"Error: {e}")
|
|
|
|
|
|
class DocumentOverview(Widget):
|
|
"""Widget for comprehensive document overview and statistics."""
|
|
|
|
DEFAULT_CSS = """
|
|
DocumentOverview {
|
|
layout: vertical;
|
|
height: 100%;
|
|
}
|
|
|
|
DocumentOverview .overview-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
}
|
|
|
|
DocumentOverview .overview-content {
|
|
height: 1fr;
|
|
layout: horizontal;
|
|
}
|
|
|
|
DocumentOverview .overview-left {
|
|
width: 50%;
|
|
padding: 1;
|
|
}
|
|
|
|
DocumentOverview .overview-right {
|
|
width: 50%;
|
|
padding: 1;
|
|
}
|
|
|
|
DocumentOverview .info-table {
|
|
height: auto;
|
|
margin-bottom: 2;
|
|
}
|
|
"""
|
|
|
|
def __init__(self, r2r_storage: R2RStorage, document_id: str, **kwargs: Any) -> None:
|
|
"""Initialize document overview."""
|
|
super().__init__(**kwargs)
|
|
self.r2r_storage: R2RStorage = r2r_storage
|
|
self.document_id: str = document_id
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose overview layout."""
|
|
yield Container(
|
|
Static("📋 Document Overview", classes="overview-header"),
|
|
Horizontal(
|
|
Vertical(
|
|
Label("Document Information"),
|
|
DataTable[str](id="doc_info_table", classes="info-table"),
|
|
Label("Processing Statistics"),
|
|
DataTable[str](id="stats_table", classes="info-table"),
|
|
classes="overview-left",
|
|
),
|
|
Vertical(
|
|
ChunkViewer(self.r2r_storage, self.document_id),
|
|
classes="overview-right",
|
|
),
|
|
classes="overview-content",
|
|
),
|
|
)
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize overview."""
|
|
self.load_overview()
|
|
|
|
@work(exclusive=True)
|
|
async def load_overview(self) -> None:
|
|
"""Load comprehensive document overview."""
|
|
try:
|
|
overview_data = await self.r2r_storage.get_document_overview(self.document_id)
|
|
|
|
# Populate document info table
|
|
doc_table = self.query_one("#doc_info_table", DataTable)
|
|
doc_table.add_columns("Property", "Value")
|
|
|
|
document_info_raw = overview_data.get("document", {})
|
|
document_info = document_info_raw if isinstance(document_info_raw, dict) else {}
|
|
doc_table.add_row("ID", str(document_info.get("id", "N/A")))
|
|
doc_table.add_row("Title", str(document_info.get("title", "N/A")))
|
|
doc_table.add_row("Created", str(document_info.get("created_at", "N/A")))
|
|
doc_table.add_row("Modified", str(document_info.get("updated_at", "N/A")))
|
|
|
|
# Populate stats table
|
|
stats_table = self.query_one("#stats_table", DataTable)
|
|
stats_table.add_columns("Metric", "Count")
|
|
|
|
chunk_count = overview_data.get("chunk_count", 0)
|
|
stats_table.add_row("Chunks", str(chunk_count))
|
|
stats_table.add_row("Characters", str(len(str(document_info.get("content", "")))))
|
|
|
|
except Exception as e:
|
|
# Handle error by showing minimal info
|
|
doc_table = self.query_one("#doc_info_table", DataTable)
|
|
doc_table.add_columns("Property", "Value")
|
|
doc_table.add_row("Error", str(e))
</file>

<file path="ingest_pipeline/cli/tui/app.py">
"""Main TUI application with enhanced keyboard navigation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
from collections import deque
|
|
from pathlib import Path
|
|
from queue import Empty, Queue
|
|
from typing import TYPE_CHECKING, ClassVar, Literal
|
|
|
|
from textual import events
|
|
from textual.app import App
|
|
from textual.binding import Binding, BindingType
|
|
from textual.timer import Timer
|
|
|
|
from ...storage.base import BaseStorage
|
|
from ...storage.openwebui import OpenWebUIStorage
|
|
from ...storage.weaviate import WeaviateStorage
|
|
from .screens.dashboard import CollectionOverviewScreen
|
|
from .screens.help import HelpScreen
|
|
from .styles import TUI_CSS
|
|
from .utils.storage_manager import StorageManager
|
|
|
|
if TYPE_CHECKING:
|
|
from logging import Formatter, LogRecord
|
|
|
|
from ...storage.r2r.storage import R2RStorage
|
|
from .screens.dialogs import LogViewerScreen
|
|
else: # pragma: no cover - optional dependency fallback
|
|
R2RStorage = BaseStorage
|
|
|
|
|
|
class CollectionManagementApp(App[None]):
|
|
"""Enhanced modern Textual application with comprehensive keyboard navigation."""
|
|
|
|
CSS: ClassVar[str] = TUI_CSS
|
|
TITLE = "Collection Management"
|
|
SUB_TITLE = "Document Ingestion Pipeline"
|
|
|
|
def safe_notify(
|
|
self,
|
|
message: str,
|
|
*,
|
|
severity: Literal["information", "warning", "error"] = "information",
|
|
) -> None:
|
|
"""Safely notify with markup disabled to prevent parsing errors."""
|
|
self.notify(message, severity=severity, markup=False)
|
|
|
|
BINDINGS: ClassVar[list[BindingType]] = [
|
|
Binding("q", "quit", "Quit"),
|
|
Binding("ctrl+c", "quit", "Quit"),
|
|
Binding("ctrl+q", "quit", "Quit"),
|
|
Binding("f1", "help", "Help"),
|
|
Binding("ctrl+h", "help", "Help"),
|
|
Binding("?", "help", "Quick Help"),
|
|
# Global navigation shortcuts
|
|
Binding("ctrl+r", "refresh_current", "Refresh Current Screen"),
|
|
Binding("ctrl+w", "close_current", "Close Current Screen"),
|
|
Binding("ctrl+l", "toggle_logs", "Logs"),
|
|
# Tab navigation shortcuts
|
|
Binding("ctrl+1", "dashboard_tab", "Dashboard", show=False),
|
|
Binding("ctrl+2", "collections_tab", "Collections", show=False),
|
|
Binding("ctrl+3", "analytics_tab", "Analytics", show=False),
|
|
]
|
|
|
|
storage_manager: StorageManager
|
|
weaviate: WeaviateStorage | None
|
|
openwebui: OpenWebUIStorage | None
|
|
r2r: R2RStorage | BaseStorage | None
|
|
log_queue: Queue[LogRecord] | None
|
|
_log_formatter: Formatter
|
|
_log_buffer: deque[str]
|
|
_log_viewer: LogViewerScreen | None
|
|
_log_file: Path | None
|
|
_log_timer: Timer | None
|
|
|
|
def __init__(
|
|
self,
|
|
storage_manager: StorageManager,
|
|
weaviate: WeaviateStorage | None = None,
|
|
openwebui: OpenWebUIStorage | None = None,
|
|
r2r: R2RStorage | BaseStorage | None = None,
|
|
*,
|
|
log_queue: Queue[LogRecord] | None = None,
|
|
log_formatter: Formatter | None = None,
|
|
log_file: Path | None = None,
|
|
) -> None:
|
|
super().__init__()
|
|
self.storage_manager = storage_manager
|
|
self.weaviate = weaviate
|
|
self.openwebui = openwebui
|
|
self.r2r = r2r
|
|
        # Title and sub-title are assigned in on_mount via the App reactive attributes.
|
|
self.log_queue = log_queue
|
|
self._log_formatter = log_formatter or logging.Formatter(
|
|
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
|
datefmt="%H:%M:%S",
|
|
)
|
|
self._log_buffer = deque(maxlen=500)
|
|
self._log_viewer = None
|
|
self._log_file = log_file
|
|
self._log_timer = None
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize the enhanced app with better branding."""
|
|
self.title = "🚀 Enhanced Collection Management System"
|
|
self.sub_title = (
|
|
"Advanced Document Ingestion & Management Platform with Keyboard Navigation"
|
|
)
|
|
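        # Honor reduced-motion preferences from the environment by toggling the
        # "reduced-motion" CSS class on the app.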
reduced_motion_env = os.getenv("TEXTUAL_REDUCED_MOTION") or os.getenv(
|
|
"PREFER_REDUCED_MOTION"
|
|
)
|
|
if reduced_motion_env is not None:
|
|
normalized = reduced_motion_env.strip().lower()
|
|
reduced_motion_enabled = normalized in {"1", "true", "yes", "on"}
|
|
else:
|
|
reduced_motion_enabled = False
|
|
_ = self.set_class(reduced_motion_enabled, "reduced-motion")
|
|
_ = self.push_screen(
|
|
CollectionOverviewScreen(
|
|
self.storage_manager,
|
|
self.weaviate,
|
|
self.openwebui,
|
|
self.r2r,
|
|
)
|
|
)
|
|
if self.log_queue is not None and self._log_timer is None:
|
|
# Poll the queue so log output is captured without blocking the UI loop
|
|
self._log_timer = self.set_interval(0.25, self._drain_log_queue)
|
|
|
|
def _drain_log_queue(self) -> None:
|
|
"""Drain queued log records and route them to the active log viewer."""
|
|
if self.log_queue is None:
|
|
return
|
|
|
|
drained: list[str] = []
|
|
while True:
|
|
try:
|
|
record = self.log_queue.get_nowait()
|
|
except Empty:
|
|
break
|
|
message = self._log_formatter.format(record)
|
|
self._log_buffer.append(message)
|
|
drained.append(message)
|
|
|
|
if drained and self._log_viewer is not None:
|
|
self._log_viewer.append_logs(drained)
|
|
|
|
def attach_log_viewer(self, viewer: LogViewerScreen) -> None:
|
|
"""Register an active log viewer and hydrate it with existing entries."""
|
|
self._log_viewer = viewer
|
|
viewer.replace_logs(list(self._log_buffer))
|
|
viewer.update_log_file(self._log_file)
|
|
# Drain once more to deliver any entries gathered between instantiation and mount
|
|
self._drain_log_queue()
|
|
|
|
def detach_log_viewer(self, viewer: LogViewerScreen) -> None:
|
|
"""Remove the current log viewer when it is dismissed."""
|
|
if self._log_viewer is viewer:
|
|
self._log_viewer = None
|
|
|
|
def get_log_file_path(self) -> Path | None:
|
|
"""Return the active log file path if configured."""
|
|
return self._log_file
|
|
|
|
def action_toggle_logs(self) -> None:
|
|
"""Toggle the log viewer modal screen."""
|
|
if self._log_viewer is not None:
|
|
_ = self.pop_screen()
|
|
return
|
|
|
|
from .screens.dialogs import LogViewerScreen # Local import to avoid cycle
|
|
|
|
_ = self.push_screen(LogViewerScreen())
|
|
|
|
def action_help(self) -> None:
|
|
"""Show comprehensive help information with all keyboard shortcuts."""
|
|
help_md = """
|
|
# 🚀 Enhanced Collection Management System
|
|
|
|
## 🎯 Global Navigation
|
|
- **F1** / **Ctrl+H** / **?**: Show this help
|
|
- **Q** / **Ctrl+C** / **Ctrl+Q**: Quit application
|
|
- **Ctrl+R**: Refresh current screen
|
|
- **Ctrl+W**: Close current screen/dialog
|
|
- **Escape**: Go back/cancel current action
|
|
|
|
## 📑 Tab Navigation
|
|
- **Tab** / **Shift+Tab**: Switch between tabs
|
|
- **Ctrl+1**: Jump to Dashboard tab
|
|
- **Ctrl+2**: Jump to Collections tab
|
|
- **Ctrl+3**: Jump to Analytics tab
|
|
|
|
## 📚 Collections Management
|
|
- **R**: Refresh collections list
|
|
- **I**: Start new ingestion
|
|
- **M**: Manage documents in selected collection
|
|
- **S**: Search within selected collection
|
|
- **Ctrl+D**: Delete selected collection
|
|
|
|
## 🗂️ Table Navigation
|
|
- **Arrow Keys** / **J/K/H/L**: Navigate table cells (Vi-style)
|
|
- **Home** / **End**: Jump to first/last row
|
|
- **Page Up** / **Page Down**: Scroll by page
|
|
- **Enter**: Select/activate current row
|
|
- **Space**: Toggle row selection
|
|
- **Ctrl+A**: Select all items
|
|
- **Ctrl+Shift+A**: Clear all selections
|
|
|
|
## 📄 Document Management
|
|
- **Space**: Toggle document selection
|
|
- **Delete** / **Ctrl+D**: Delete selected documents
|
|
- **A**: Select all documents on page
|
|
- **N**: Clear selection
|
|
- **Page Up/Down**: Navigate between pages
|
|
- **Home/End**: Go to first/last page
|
|
|
|
## 🔍 Search Features
|
|
- **/** : Quick search (focus search field)
|
|
- **Ctrl+F**: Focus search input
|
|
- **Enter**: Perform search
|
|
- **F3**: Repeat last search
|
|
- **Ctrl+R**: Clear search results
|
|
- **Escape**: Clear search/exit search mode
|
|
|
|
## 📥 Ingestion Interface
|
|
- **1/2/3**: Select ingestion type (Web/Repository/Documentation)
|
|
- **Tab/Shift+Tab**: Navigate between fields
|
|
- **Enter**: Start ingestion process
|
|
- **Ctrl+I**: Quick start ingestion
|
|
- **Escape**: Cancel ingestion
|
|
|
|
## 🎨 Visual Features
|
|
- Enhanced focus indicators with colored borders
|
|
- Smooth keyboard navigation with visual feedback
|
|
- Status indicators with real-time updates
|
|
- Progress bars with detailed status messages
|
|
- Responsive design with accessibility features
|
|
|
|
## 💡 Pro Tips
|
|
- Use **Vi-style** navigation (J/K/H/L) for efficient movement
|
|
- **Tab** through interactive elements for keyboard-only operation
|
|
- Hold **Shift** with arrow keys for range selection (where supported)
|
|
- Use **Ctrl+** shortcuts for power user efficiency
|
|
- **Escape** is your friend - it cancels most operations safely
|
|
|
|
## 🚀 Performance Features
|
|
- Lazy loading for large collections
|
|
- Paginated document views
|
|
- Background refresh operations
|
|
- Efficient memory management
|
|
- Responsive UI updates
|
|
|
|
---
|
|
|
|
**Enjoy the enhanced keyboard-driven interface!** 🎉
|
|
|
|
*Press Escape, Enter, or Q to close this help.*
|
|
"""
|
|
_ = self.push_screen(HelpScreen(help_md))
|
|
|
|
def action_refresh_current(self) -> None:
|
|
"""Refresh the current screen if it supports it."""
|
|
current_screen = self.screen
|
|
handler = getattr(current_screen, "action_refresh", None)
|
|
if callable(handler):
|
|
_ = handler()
|
|
return
|
|
self.notify("Current screen doesn't support refresh", severity="information")
|
|
|
|
def action_close_current(self) -> None:
|
|
"""Close current screen/dialog."""
|
|
if len(self.screen_stack) > 1: # Don't close the main screen
|
|
_ = self.pop_screen()
|
|
else:
|
|
self.notify("Cannot close main screen. Use Q to quit.", severity="warning")
|
|
|
|
def action_dashboard_tab(self) -> None:
|
|
"""Switch to dashboard tab in current screen."""
|
|
current_screen = self.screen
|
|
handler = getattr(current_screen, "action_tab_dashboard", None)
|
|
if callable(handler):
|
|
_ = handler()
|
|
|
|
def action_collections_tab(self) -> None:
|
|
"""Switch to collections tab in current screen."""
|
|
current_screen = self.screen
|
|
handler = getattr(current_screen, "action_tab_collections", None)
|
|
if callable(handler):
|
|
_ = handler()
|
|
|
|
def action_analytics_tab(self) -> None:
|
|
"""Switch to analytics tab in current screen."""
|
|
current_screen = self.screen
|
|
handler = getattr(current_screen, "action_tab_analytics", None)
|
|
if callable(handler):
|
|
_ = handler()
|
|
|
|
def on_key(self, event: events.Key) -> None:
|
|
"""Handle global keyboard shortcuts."""
|
|
# Handle global shortcuts that might not be bound to specific actions
|
|
if event.key == "ctrl+shift+?":
|
|
# Alternative help shortcut
|
|
self.action_help()
|
|
_ = event.prevent_default()
|
|
elif event.key == "ctrl+alt+r":
|
|
# Force refresh all connections
|
|
self.notify("🔄 Refreshing all connections...", severity="information")
|
|
# This could trigger a full reinit if needed
|
|
_ = event.prevent_default()
|
|
# No else clause needed - just handle our events
</file>

<file path="ingest_pipeline/cli/tui/layouts.py">
"""Responsive layout system for TUI applications."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import cast
|
|
|
|
from textual.app import ComposeResult
|
|
from textual.containers import Container, VerticalScroll
|
|
from textual.widget import Widget
|
|
from textual.widgets import Static
|
|
from typing_extensions import override
|
|
|
|
|
|
class ResponsiveGrid(Container):
|
|
"""Grid that auto-adjusts based on terminal size."""
|
|
|
|
DEFAULT_CSS: str = """
|
|
ResponsiveGrid {
|
|
layout: grid;
|
|
grid-size: 1;
|
|
grid-columns: 1fr;
|
|
grid-rows: auto;
|
|
grid-gutter: 1;
|
|
padding: 1;
|
|
}
|
|
|
|
ResponsiveGrid.two-column {
|
|
grid-size: 2;
|
|
grid-columns: 1fr 1fr;
|
|
}
|
|
|
|
ResponsiveGrid.three-column {
|
|
grid-size: 3;
|
|
grid-columns: 1fr 1fr 1fr;
|
|
}
|
|
|
|
ResponsiveGrid.auto-fit {
|
|
grid-columns: repeat(auto-fit, minmax(20, 1fr));
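        /* NOTE: repeat()/minmax() auto-fit is web-CSS syntax; Textual's grid-columns may not accept it. */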
|
|
}
|
|
|
|
ResponsiveGrid.compact {
|
|
grid-gutter: 0;
|
|
padding: 0;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*children: Widget,
|
|
columns: int = 1,
|
|
auto_fit: bool = False,
|
|
compact: bool = False,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize responsive grid."""
|
|
super().__init__(
|
|
*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup
|
|
)
|
|
self._columns: int = columns
|
|
self._auto_fit: bool = auto_fit
|
|
self._compact: bool = compact
|
|
|
|
def on_mount(self) -> None:
|
|
"""Apply responsive classes based on configuration."""
|
|
widget = cast(Widget, self)
|
|
if self._auto_fit:
|
|
widget.add_class("auto-fit")
|
|
elif self._columns == 2:
|
|
widget.add_class("two-column")
|
|
elif self._columns == 3:
|
|
widget.add_class("three-column")
|
|
|
|
if self._compact:
|
|
widget.add_class("compact")
|
|
|
|
def on_resize(self) -> None:
|
|
"""Adjust layout based on terminal size."""
|
|
if self._auto_fit:
|
|
# Let CSS handle auto-fit
|
|
return
|
|
|
|
widget = cast(Widget, self)
|
|
terminal_width = widget.size.width
|
|
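        # Breakpoints are measured in terminal cells: under 60 forces one column,
        # under 100 caps the grid at two columns.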
if terminal_width < 60:
|
|
# Force single column on narrow terminals
|
|
widget.remove_class("two-column", "three-column")
|
|
widget.styles.grid_size_columns = 1
|
|
widget.styles.grid_columns = "1fr"
|
|
elif terminal_width < 100 and self._columns > 2:
|
|
# Force two columns on medium terminals
|
|
widget.remove_class("three-column")
|
|
widget.add_class("two-column")
|
|
widget.styles.grid_size_columns = 2
|
|
widget.styles.grid_columns = "1fr 1fr"
|
|
elif self._columns == 2:
|
|
widget.add_class("two-column")
|
|
elif self._columns == 3:
|
|
widget.add_class("three-column")
|
|
|
|
|
|
class CollapsibleSidebar(Container):
|
|
"""Sidebar that can be collapsed to save space."""
|
|
|
|
DEFAULT_CSS: str = """
|
|
CollapsibleSidebar {
|
|
dock: left;
|
|
width: 25%;
|
|
min-width: 20;
|
|
max-width: 40;
|
|
background: $surface;
|
|
border-right: solid $border;
|
|
padding: 1;
|
|
transition: width 300ms;
|
|
}
|
|
|
|
CollapsibleSidebar.collapsed {
|
|
width: 3;
|
|
min-width: 3;
|
|
overflow: hidden;
|
|
}
|
|
|
|
CollapsibleSidebar.collapsed > * {
|
|
display: none;
|
|
}
|
|
|
|
CollapsibleSidebar .sidebar-toggle {
|
|
dock: top;
|
|
height: 1;
|
|
background: $primary;
|
|
color: $text;
|
|
text-align: center;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
CollapsibleSidebar .sidebar-content {
|
|
height: 1fr;
|
|
overflow-y: auto;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*children: Widget,
|
|
collapsed: bool = False,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize collapsible sidebar."""
|
|
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
|
self._collapsed: bool = collapsed
|
|
self._children: tuple[Widget, ...] = children
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose sidebar with toggle and content."""
|
|
yield Static("☰", classes="sidebar-toggle")
|
|
with VerticalScroll(classes="sidebar-content"):
|
|
yield from self._children
|
|
|
|
def on_mount(self) -> None:
|
|
"""Apply initial collapsed state."""
|
|
if self._collapsed:
|
|
cast(Widget, self).add_class("collapsed")
|
|
|
|
def on_click(self) -> None:
|
|
"""Toggle sidebar when clicked."""
|
|
self.toggle()
|
|
|
|
def toggle(self) -> None:
|
|
"""Toggle sidebar collapsed state."""
|
|
self._collapsed = not self._collapsed
|
|
widget = cast(Widget, self)
|
|
if self._collapsed:
|
|
widget.add_class("collapsed")
|
|
else:
|
|
widget.remove_class("collapsed")
|
|
|
|
def expand_sidebar(self) -> None:
|
|
"""Expand sidebar."""
|
|
if self._collapsed:
|
|
self.toggle()
|
|
|
|
def collapse_sidebar(self) -> None:
|
|
"""Collapse sidebar."""
|
|
if not self._collapsed:
|
|
self.toggle()
|
|
|
|
|
|
class TabularLayout(Container):
|
|
"""Optimized layout for data tables with optional sidebar."""
|
|
|
|
DEFAULT_CSS: str = """
|
|
TabularLayout {
|
|
layout: horizontal;
|
|
height: 100%;
|
|
}
|
|
|
|
TabularLayout .main-content {
|
|
width: 1fr;
|
|
height: 100%;
|
|
layout: vertical;
|
|
}
|
|
|
|
TabularLayout .table-container {
|
|
height: 1fr;
|
|
overflow: auto;
|
|
border: solid $border;
|
|
background: $surface;
|
|
}
|
|
|
|
TabularLayout .table-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
}
|
|
|
|
TabularLayout .table-footer {
|
|
dock: bottom;
|
|
height: 3;
|
|
background: $surface-lighten-1;
|
|
padding: 1;
|
|
border-top: solid $border;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
table_widget: Widget,
|
|
header_content: Widget | None = None,
|
|
footer_content: Widget | None = None,
|
|
sidebar_content: Widget | None = None,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize tabular layout."""
|
|
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
|
self.table_widget: Widget = table_widget
|
|
self.header_content: Widget | None = header_content
|
|
self.footer_content: Widget | None = footer_content
|
|
self.sidebar_content: Widget | None = sidebar_content
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose layout with optional sidebar."""
|
|
if self.sidebar_content:
|
|
yield CollapsibleSidebar(self.sidebar_content)
|
|
|
|
with Container(classes="main-content"):
|
|
if self.header_content:
|
|
yield Container(self.header_content, classes="table-header")
|
|
|
|
yield Container(self.table_widget, classes="table-container")
|
|
|
|
if self.footer_content:
|
|
yield Container(self.footer_content, classes="table-footer")
|
|
|
|
|
|
class CardLayout(ResponsiveGrid):
|
|
"""Grid layout optimized for card-based content."""
|
|
|
|
DEFAULT_CSS: str = """
|
|
CardLayout {
|
|
grid-gutter: 2;
|
|
padding: 2;
|
|
}
|
|
|
|
CardLayout .card {
|
|
background: $surface;
|
|
border: solid $border;
|
|
border-radius: 1;
|
|
padding: 2;
|
|
height: auto;
|
|
min-height: 10;
|
|
}
|
|
|
|
CardLayout .card:hover {
|
|
border: solid $accent;
|
|
background: $surface-lighten-1;
|
|
}
|
|
|
|
CardLayout .card:focus {
|
|
border: solid $primary;
|
|
}
|
|
|
|
CardLayout .card-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary-lighten-1;
|
|
color: $text;
|
|
padding: 1;
|
|
margin: -2 -2 1 -2;
|
|
border-radius: 1 1 0 0;
|
|
}
|
|
|
|
CardLayout .card-content {
|
|
height: 1fr;
|
|
overflow: auto;
|
|
}
|
|
|
|
CardLayout .card-footer {
|
|
dock: bottom;
|
|
height: 3;
|
|
background: $surface-darken-1;
|
|
padding: 1;
|
|
margin: 1 -2 -2 -2;
|
|
border-radius: 0 0 1 1;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize card layout with default settings for cards."""
|
|
# Default to auto-fit cards with minimum width
|
|
super().__init__(
|
|
auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup
|
|
)
|
|
|
|
|
|
class SplitPane(Container):
|
|
"""Resizable split pane layout."""
|
|
|
|
DEFAULT_CSS: str = """
|
|
SplitPane {
|
|
layout: horizontal;
|
|
height: 100%;
|
|
}
|
|
|
|
SplitPane.vertical {
|
|
layout: vertical;
|
|
}
|
|
|
|
SplitPane .left-pane,
|
|
SplitPane .top-pane {
|
|
width: 50%;
|
|
height: 50%;
|
|
background: $surface;
|
|
border-right: solid $border;
|
|
border-bottom: solid $border;
|
|
}
|
|
|
|
SplitPane .right-pane,
|
|
SplitPane .bottom-pane {
|
|
width: 50%;
|
|
height: 50%;
|
|
background: $surface;
|
|
}
|
|
|
|
SplitPane .splitter {
|
|
width: 1;
|
|
height: 1;
|
|
background: $border;
|
|
}
|
|
|
|
SplitPane.vertical .splitter {
|
|
width: 100%;
|
|
height: 1;
|
|
}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
left_content: Widget,
|
|
right_content: Widget,
|
|
vertical: bool = False,
|
|
split_ratio: float = 0.5,
|
|
name: str | None = None,
|
|
id: str | None = None,
|
|
classes: str | None = None,
|
|
disabled: bool = False,
|
|
markup: bool = True,
|
|
) -> None:
|
|
"""Initialize split pane."""
|
|
super().__init__(name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
|
self._left_content: Widget = left_content
|
|
self._right_content: Widget = right_content
|
|
self._vertical: bool = vertical
|
|
self._split_ratio: float = split_ratio
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
"""Compose split pane layout."""
|
|
if self._vertical:
|
|
cast(Widget, self).add_class("vertical")
|
|
|
|
pane_classes = (
|
|
("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
|
|
)
|
|
|
|
yield Container(self._left_content, classes=pane_classes[0])
|
|
yield Static("", classes="splitter")
|
|
yield Container(self._right_content, classes=pane_classes[1])
|
|
|
|
def on_mount(self) -> None:
|
|
"""Apply split ratio."""
|
|
widget = cast(Widget, self)
|
|
if self._vertical:
|
|
widget.query_one(".top-pane").styles.height = f"{self._split_ratio * 100}%"
|
|
widget.query_one(".bottom-pane").styles.height = f"{(1 - self._split_ratio) * 100}%"
|
|
else:
|
|
widget.query_one(".left-pane").styles.width = f"{self._split_ratio * 100}%"
|
|
widget.query_one(".right-pane").styles.width = f"{(1 - self._split_ratio) * 100}%"
</file>

<file path="ingest_pipeline/cli/main.py">
"""CLI interface for ingestion pipeline."""
|
|
|
|
import asyncio
|
|
from typing import Annotated
|
|
|
|
import typer
|
|
from pydantic import SecretStr
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
|
from rich.table import Table
|
|
|
|
from ..config import configure_prefect, get_settings
|
|
from ..core.models import (
|
|
IngestionResult,
|
|
IngestionSource,
|
|
StorageBackend,
|
|
StorageConfig,
|
|
)
|
|
from ..flows.ingestion import create_ingestion_flow
|
|
from ..flows.scheduler import create_scheduled_deployment, serve_deployments
|
|
|
|
app = typer.Typer(
|
|
name="ingest",
|
|
help="🚀 Modern Document Ingestion Pipeline - Advanced web and repository processing",
|
|
rich_markup_mode="rich",
|
|
add_completion=False,
|
|
)
|
|
console = Console()
|
|
|
|
|
|
@app.callback()
|
|
def main(
|
|
version: Annotated[
|
|
bool, typer.Option("--version", "-v", help="Show version information")
|
|
] = False,
|
|
) -> None:
|
|
"""
|
|
🚀 Modern Document Ingestion Pipeline
|
|
|
|
[bold cyan]Advanced document processing and management platform[/bold cyan]
|
|
|
|
Features:
|
|
• 🌐 Web scraping and crawling with Firecrawl
|
|
• 📦 Repository ingestion with Repomix
|
|
• 🗄️ Multiple storage backends (Weaviate, OpenWebUI, R2R)
|
|
• 📊 Modern TUI for collection management
|
|
• ⚡ Async processing with Prefect orchestration
|
|
• 🎨 Rich CLI with enhanced visuals
|
|
"""
|
|
settings = get_settings()
|
|
configure_prefect(settings)
|
|
|
|
if version:
|
|
console.print(
|
|
Panel(
|
|
(
|
|
"[bold magenta]Ingest Pipeline v0.1.0[/bold magenta]\n"
|
|
"[dim]Modern Document Ingestion & Management System[/dim]"
|
|
),
|
|
title="🚀 Version Info",
|
|
border_style="magenta",
|
|
)
|
|
)
|
|
raise typer.Exit()
|
|
|
|
|
|
@app.command()
|
|
def ingest(
|
|
source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")],
|
|
source_type: Annotated[
|
|
IngestionSource, typer.Option("--type", "-t", help="Type of source")
|
|
] = IngestionSource.WEB,
|
|
storage: Annotated[
|
|
StorageBackend, typer.Option("--storage", "-s", help="Storage backend")
|
|
] = StorageBackend.WEAVIATE,
|
|
collection: Annotated[
|
|
str | None,
|
|
typer.Option(
|
|
"--collection", "-c", help="Target collection name (auto-generated if not specified)"
|
|
),
|
|
] = None,
|
|
validate: Annotated[
|
|
bool, typer.Option("--validate/--no-validate", help="Validate source before ingesting")
|
|
] = True,
|
|
) -> None:
|
|
"""
|
|
🚀 Run a one-time ingestion job with enhanced progress tracking.
|
|
|
|
This command processes documents from various sources and stores them in
|
|
your chosen backend with full progress visualization.
|
|
"""
|
|
# Enhanced startup message
|
|
console.print(
|
|
Panel(
|
|
(
|
|
f"[bold cyan]🚀 Starting Modern Ingestion[/bold cyan]\n\n"
|
|
f"[yellow]Source:[/yellow] {source_url}\n"
|
|
f"[yellow]Type:[/yellow] {source_type.value.title()}\n"
|
|
f"[yellow]Storage:[/yellow] {storage.value.replace('_', ' ').title()}\n"
|
|
f"[yellow]Collection:[/yellow] {collection or '[dim]Auto-generated[/dim]'}"
|
|
),
|
|
title="🔥 Ingestion Configuration",
|
|
border_style="cyan",
|
|
)
|
|
)
|
|
|
|
async def run_with_progress() -> IngestionResult:
|
|
with Progress(
|
|
SpinnerColumn(),
|
|
TextColumn("[progress.description]{task.description}"),
|
|
BarColumn(),
|
|
TaskProgressColumn(),
|
|
console=console,
|
|
) as progress:
|
|
task = progress.add_task("🔄 Processing documents...", total=100)
|
|
|
|
# Simulate progress updates during ingestion
|
|
progress.update(task, advance=20, description="🔗 Connecting to services...")
|
|
await asyncio.sleep(0.5)
|
|
|
|
progress.update(task, advance=30, description="📄 Fetching documents...")
|
|
result = await run_ingestion(
|
|
url=source_url,
|
|
source_type=source_type,
|
|
storage_backend=storage,
|
|
collection_name=collection,
|
|
validate_first=validate,
|
|
)
|
|
|
|
progress.update(task, advance=50, description="✅ Ingestion complete!")
|
|
return result
|
|
|
|
# Use asyncio.run() with proper event loop handling
|
|
try:
|
|
result = asyncio.run(run_with_progress())
|
|
except RuntimeError as e:
|
|
if "asyncio.run() cannot be called from a running event loop" in str(e):
|
|
# If we're already in an event loop (e.g., in Jupyter), use nest_asyncio
|
|
try:
|
|
import nest_asyncio
|
|
|
|
nest_asyncio.apply()
|
|
result = asyncio.run(run_with_progress())
|
|
except ImportError:
|
|
# Fallback: get the current loop and run the coroutine
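                # run_until_complete also raises if the loop is already running, so this
                # path only helps when a loop object exists but is not currently running.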
|
|
loop = asyncio.get_event_loop()
|
|
result = loop.run_until_complete(run_with_progress())
|
|
else:
|
|
raise
|
|
|
|
# Enhanced results display
|
|
status_color = "green" if result.status.value == "completed" else "red"
|
|
|
|
# Create results table with enhanced styling
|
|
table = Table(
|
|
title="📊 Ingestion Results",
|
|
title_style="bold magenta",
|
|
border_style="cyan",
|
|
header_style="bold blue",
|
|
)
|
|
table.add_column("📋 Metric", style="cyan", no_wrap=True)
|
|
table.add_column("📈 Value", style=status_color, justify="right")
|
|
|
|
# Add enhanced status icon
|
|
status_icon = "✅" if result.status.value == "completed" else "❌"
|
|
table.add_row("Status", f"{status_icon} {result.status.value.title()}")
|
|
|
|
table.add_row("Documents Processed", f"📄 {result.documents_processed:,}")
|
|
table.add_row("Documents Failed", f"⚠️ {result.documents_failed:,}")
|
|
table.add_row("Duration", f"⏱️ {result.duration_seconds:.2f}s")
|
|
|
|
if result.error_messages:
|
|
error_text = "\n".join(f"❌ {error}" for error in result.error_messages[:3])
|
|
if len(result.error_messages) > 3:
|
|
error_text += f"\n... and {len(result.error_messages) - 3} more errors"
|
|
table.add_row("Errors", error_text)
|
|
|
|
console.print(table)
|
|
|
|
# Success celebration or error guidance
|
|
if result.status.value == "completed" and result.documents_processed > 0:
|
|
console.print(
|
|
Panel(
|
|
(
|
|
f"🎉 [bold green]Success![/bold green] {result.documents_processed} documents ingested\n\n"
|
|
f"💡 [dim]Try '[bold cyan]ingest modern[/bold cyan]' to explore your collections![/dim]"
|
|
),
|
|
title="✨ Ingestion Complete",
|
|
border_style="green",
|
|
)
|
|
)
|
|
elif result.error_messages:
|
|
console.print(
|
|
Panel(
|
|
(
|
|
"❌ [bold red]Ingestion encountered errors[/bold red]\n\n"
|
|
"💡 [dim]Check your configuration and try again[/dim]"
|
|
),
|
|
title="⚠️ Issues Detected",
|
|
border_style="red",
|
|
)
|
|
)
|
|
|
|
|
|
@app.command()
|
|
def schedule(
|
|
name: Annotated[str, typer.Argument(help="Deployment name")],
|
|
source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")],
|
|
source_type: Annotated[
|
|
IngestionSource, typer.Option("--type", "-t", help="Type of source")
|
|
] = IngestionSource.WEB,
|
|
storage: Annotated[
|
|
StorageBackend, typer.Option("--storage", "-s", help="Storage backend")
|
|
] = StorageBackend.WEAVIATE,
|
|
cron: Annotated[
|
|
str | None, typer.Option("--cron", "-c", help="Cron expression for scheduling")
|
|
] = None,
|
|
interval: Annotated[int, typer.Option("--interval", "-i", help="Interval in minutes")] = 60,
|
|
serve_now: Annotated[
|
|
bool, typer.Option("--serve/--no-serve", help="Start serving immediately")
|
|
] = False,
|
|
) -> None:
|
|
"""
|
|
Create a scheduled deployment for recurring ingestion.
|
|
"""
|
|
console.print(f"[bold blue]Creating deployment: {name}[/bold blue]")
|
|
|
|
deployment = create_scheduled_deployment(
|
|
name=name,
|
|
source_url=source_url,
|
|
source_type=source_type,
|
|
storage_backend=storage,
|
|
schedule_type="cron" if cron else "interval",
|
|
cron_expression=cron,
|
|
interval_minutes=interval,
|
|
)
|
|
|
|
console.print(f"[green]✓ Deployment '{name}' created[/green]")
|
|
|
|
if serve_now:
|
|
console.print("[yellow]Starting deployment server...[/yellow]")
|
|
serve_deployments([deployment])
|
|
|
|
|
|
@app.command()
|
|
def serve(
|
|
config_file: Annotated[
|
|
str | None, typer.Option("--config", "-c", help="Path to deployments config file")
|
|
] = None,
|
|
ui: Annotated[
|
|
str | None, typer.Option("--ui", help="Launch user interface (options: tui, web)")
|
|
] = None,
|
|
) -> None:
|
|
"""
|
|
🚀 Serve configured deployments with optional UI interface.
|
|
|
|
Launch the deployment server to run scheduled ingestion jobs,
|
|
optionally with a modern Terminal User Interface (TUI) or web interface.
|
|
"""
|
|
# Handle UI mode first
|
|
if ui == "tui":
|
|
console.print(
|
|
Panel(
|
|
(
|
|
"[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n"
|
|
"[yellow]Features:[/yellow]\n"
|
|
"• 📊 Interactive collection management\n"
|
|
"• ⌨️ Enhanced keyboard navigation\n"
|
|
"• 🎨 Modern design with focus indicators\n"
|
|
"• 📄 Document browsing and search\n"
|
|
"• 🔄 Real-time status updates"
|
|
),
|
|
title="🎉 TUI Mode",
|
|
border_style="cyan",
|
|
)
|
|
)
|
|
from .tui import dashboard
|
|
|
|
dashboard()
|
|
return
|
|
elif ui == "web":
|
|
console.print("[red]Web UI not yet implemented. Use --ui tui for Terminal UI.[/red]")
|
|
return
|
|
elif ui:
|
|
console.print(f"[red]Unknown UI option: {ui}[/red]")
|
|
console.print("[yellow]Available options: tui, web[/yellow]")
|
|
return
|
|
|
|
# Normal deployment server mode
|
|
if config_file:
|
|
# Load deployments from config
|
|
console.print(f"[yellow]Loading deployments from {config_file}[/yellow]")
|
|
# Implementation would load YAML/JSON config
|
|
else:
|
|
# Create example deployments
|
|
deployments = [
|
|
create_scheduled_deployment(
|
|
name="docs-daily",
|
|
source_url="https://docs.example.com",
|
|
source_type="documentation",
|
|
storage_backend="weaviate",
|
|
schedule_type="cron",
|
|
cron_expression="0 2 * * *", # Daily at 2 AM
|
|
),
|
|
create_scheduled_deployment(
|
|
name="repo-hourly",
|
|
source_url="https://github.com/example/repo",
|
|
source_type="repository",
|
|
storage_backend="open_webui",
|
|
schedule_type="interval",
|
|
interval_minutes=60,
|
|
),
|
|
]
|
|
|
|
console.print(
|
|
"[bold green]Starting deployment server with example deployments[/bold green]"
|
|
)
|
|
serve_deployments(deployments)
|
|
|
|
|
|
@app.command()
|
|
def tui() -> None:
|
|
"""
|
|
🚀 Launch the enhanced Terminal User Interface.
|
|
|
|
Quick shortcut for 'serve --ui tui' with modern keyboard navigation,
|
|
interactive collection management, and real-time status updates.
|
|
"""
|
|
console.print(
|
|
Panel(
|
|
(
|
|
"[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n"
|
|
"[yellow]Features:[/yellow]\n"
|
|
"• 📊 Interactive collection management\n"
|
|
"• ⌨️ Enhanced keyboard navigation\n"
|
|
"• 🎨 Modern design with focus indicators\n"
|
|
"• 📄 Document browsing and search\n"
|
|
"• 🔄 Real-time status updates"
|
|
),
|
|
title="🎉 TUI Mode",
|
|
border_style="cyan",
|
|
)
|
|
)
|
|
from .tui import dashboard
|
|
|
|
dashboard()
|
|
|
|
|
|
@app.command()
|
|
def config() -> None:
|
|
"""
|
|
📋 Display current configuration with enhanced formatting.
|
|
|
|
Shows all configured endpoints, models, and settings in a beautiful
|
|
table format with status indicators.
|
|
"""
|
|
settings = get_settings()
|
|
|
|
console.print(
|
|
Panel(
|
|
(
|
|
"[bold cyan]⚙️ System Configuration[/bold cyan]\n"
|
|
"[dim]Current pipeline settings and endpoints[/dim]"
|
|
),
|
|
title="🔧 Configuration",
|
|
border_style="cyan",
|
|
)
|
|
)
|
|
|
|
# Enhanced configuration table
|
|
table = Table(
|
|
title="📊 Configuration Details",
|
|
title_style="bold magenta",
|
|
border_style="blue",
|
|
header_style="bold cyan",
|
|
show_lines=True,
|
|
)
|
|
table.add_column("🏷️ Setting", style="cyan", no_wrap=True, width=25)
|
|
table.add_column("🎯 Value", style="yellow", overflow="fold")
|
|
table.add_column("📊 Status", style="green", width=12, justify="center")
|
|
|
|
# Add configuration rows with status indicators
|
|
def get_status_indicator(value: str | None) -> str:
|
|
return "✅ Set" if value else "❌ Missing"
|
|
|
|
table.add_row("🤖 LLM Endpoint", str(settings.llm_endpoint), "✅ Active")
|
|
table.add_row("🔥 Firecrawl Endpoint", str(settings.firecrawl_endpoint), "✅ Active")
|
|
table.add_row(
|
|
"🗄️ Weaviate Endpoint",
|
|
str(settings.weaviate_endpoint),
|
|
get_status_indicator(str(settings.weaviate_api_key) if settings.weaviate_api_key else None),
|
|
)
|
|
table.add_row(
|
|
"🌐 OpenWebUI Endpoint",
|
|
str(settings.openwebui_endpoint),
|
|
get_status_indicator(settings.openwebui_api_key),
|
|
)
|
|
table.add_row("🧠 Embedding Model", settings.embedding_model, "✅ Set")
|
|
table.add_row("💾 Default Storage", settings.default_storage_backend.title(), "✅ Set")
|
|
table.add_row("📦 Default Batch Size", f"{settings.default_batch_size:,}", "✅ Set")
|
|
table.add_row("⚡ Max Concurrent Tasks", f"{settings.max_concurrent_tasks}", "✅ Set")
|
|
|
|
console.print(table)
|
|
|
|
# Additional helpful information
|
|
console.print(
|
|
Panel(
|
|
(
|
|
"💡 [bold cyan]Quick Tips[/bold cyan]\n\n"
|
|
"• Use '[bold]ingest list-collections[/bold]' to view all collections\n"
|
|
"• Use '[bold]ingest search[/bold]' to search content\n"
|
|
"• Configure API keys in your [yellow].env[/yellow] file\n"
|
|
"• Default collection names are auto-generated from URLs"
|
|
),
|
|
title="🚀 Usage Tips",
|
|
border_style="green",
|
|
)
|
|
)
|
|
|
|
|
|
@app.command()
|
|
def list_collections() -> None:
|
|
"""
|
|
📋 List all collections across storage backends.
|
|
"""
|
|
console.print("[bold cyan]📚 Collection Overview[/bold cyan]")
|
|
asyncio.run(run_list_collections())
|
|
|
|
|
|
@app.command()
|
|
def search(
|
|
query: Annotated[str, typer.Argument(help="Search query")],
|
|
collection: Annotated[
|
|
str | None, typer.Option("--collection", "-c", help="Target collection")
|
|
] = None,
|
|
backend: Annotated[
|
|
StorageBackend, typer.Option("--backend", "-b", help="Storage backend")
|
|
] = StorageBackend.WEAVIATE,
|
|
limit: Annotated[int, typer.Option("--limit", "-l", help="Result limit")] = 10,
|
|
) -> None:
|
|
"""
|
|
🔍 Search across collections.
|
|
"""
|
|
console.print(f"[bold cyan]🔍 Searching for: {query}[/bold cyan]")
|
|
asyncio.run(run_search(query, collection, backend.value, limit))
|
|
|
|
|
|
@app.command(name="blocks")
|
|
def blocks_command() -> None:
|
|
"""🧩 List and manage Prefect Blocks."""
|
|
console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
|
|
console.print(
|
|
"Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks"
|
|
)
|
|
console.print("Use 'prefect block ls' to list available blocks")
|
|
|
|
|
|
@app.command(name="variables")
|
|
def variables_command() -> None:
|
|
"""📊 Manage Prefect Variables."""
|
|
console.print("[bold cyan]📊 Prefect Variables Management[/bold cyan]")
|
|
console.print("Use 'prefect variable set VARIABLE_NAME value' to set variables")
|
|
console.print("Use 'prefect variable ls' to list variables")
|
|
|
|
|
|
async def run_ingestion(
|
|
url: str,
|
|
source_type: IngestionSource,
|
|
storage_backend: StorageBackend,
|
|
collection_name: str | None = None,
|
|
validate_first: bool = True,
|
|
) -> IngestionResult:
|
|
"""
|
|
Run ingestion with support for targeted collections.
|
|
"""
|
|
# Auto-generate collection name if not provided
|
|
if not collection_name:
|
|
from urllib.parse import urlparse
|
|
|
|
parsed = urlparse(url)
|
|
domain = parsed.netloc.replace(".", "_").replace("-", "_")
|
|
collection_name = f"{domain}_{source_type.value}"
|
|
|
|
return await create_ingestion_flow(
|
|
source_url=url,
|
|
source_type=source_type,
|
|
storage_backend=storage_backend,
|
|
collection_name=collection_name,
|
|
validate_first=validate_first,
|
|
)
|
|
|
|
|
|
async def run_list_collections() -> None:
|
|
"""
|
|
List collections across storage backends.
|
|
"""
|
|
from ..config import get_settings
|
|
from ..core.models import StorageBackend
|
|
from ..storage.openwebui import OpenWebUIStorage
|
|
from ..storage.weaviate import WeaviateStorage
|
|
|
|
settings = get_settings()
|
|
|
|
console.print("🔍 [bold cyan]Scanning storage backends...[/bold cyan]")
|
|
|
|
# Try to connect to Weaviate
|
|
weaviate_collections: list[tuple[str, int]] = []
|
|
try:
|
|
weaviate_config = StorageConfig(
|
|
backend=StorageBackend.WEAVIATE,
|
|
endpoint=settings.weaviate_endpoint,
|
|
api_key=SecretStr(settings.weaviate_api_key)
|
|
if settings.weaviate_api_key is not None
|
|
else None,
|
|
collection_name="default",
|
|
)
|
|
weaviate = WeaviateStorage(weaviate_config)
|
|
await weaviate.initialize()
|
|
|
|
overview = await weaviate.describe_collections()
|
|
for item in overview:
|
|
name = str(item.get("name", "Unknown"))
|
|
count_val = item.get("count", 0)
|
|
count = int(count_val) if str(count_val).isdigit() else 0  # non-numeric counts fall back to 0
|
|
weaviate_collections.append((name, count))
|
|
except Exception as e:
|
|
console.print(f"❌ [red]Weaviate connection failed: {e}[/red]")
|
|
|
|
# Try to connect to OpenWebUI
|
|
openwebui_collections: list[tuple[str, int]] = []
|
|
try:
|
|
openwebui_config = StorageConfig(
|
|
backend=StorageBackend.OPEN_WEBUI,
|
|
endpoint=settings.openwebui_endpoint,
|
|
api_key=SecretStr(settings.openwebui_api_key)
|
|
if settings.openwebui_api_key is not None
|
|
else None,
|
|
collection_name="default",
|
|
)
|
|
openwebui = OpenWebUIStorage(openwebui_config)
|
|
await openwebui.initialize()
|
|
|
|
overview = await openwebui.describe_collections()
|
|
for item in overview:
|
|
name = str(item.get("name", "Unknown"))
|
|
count_val = item.get("count", 0)
|
|
count = int(count_val) if str(count_val).isdigit() else 0  # non-numeric counts fall back to 0
|
|
openwebui_collections.append((name, count))
|
|
except Exception as e:
|
|
console.print(f"❌ [red]OpenWebUI connection failed: {e}[/red]")
|
|
|
|
# Display results
|
|
if weaviate_collections or openwebui_collections:
|
|
# Create results table
|
|
from rich.table import Table
|
|
|
|
table = Table(
|
|
title="📚 Collection Overview",
|
|
title_style="bold magenta",
|
|
border_style="cyan",
|
|
header_style="bold blue",
|
|
)
|
|
table.add_column("🏷️ Collection", style="cyan", no_wrap=True)
|
|
table.add_column("📊 Backend", style="yellow")
|
|
table.add_column("📄 Documents", style="green", justify="right")
|
|
|
|
# Add Weaviate collections
|
|
for name, count in weaviate_collections:
|
|
table.add_row(name, "🗄️ Weaviate", f"{count:,}")
|
|
|
|
# Add OpenWebUI collections
|
|
for name, count in openwebui_collections:
|
|
table.add_row(name, "🌐 OpenWebUI", f"{count:,}")
|
|
|
|
console.print(table)
|
|
else:
|
|
console.print("❌ [yellow]No collections found in any backend[/yellow]")
|
|
|
|
|
|
async def run_search(query: str, collection: str | None, backend: str, limit: int) -> None:
|
|
"""
|
|
Search across collections.
|
|
"""
|
|
from ..config import get_settings
|
|
from ..core.models import StorageBackend
|
|
from ..storage.weaviate import WeaviateStorage
|
|
|
|
settings = get_settings()
|
|
|
|
console.print(f"🔍 Searching for: '[bold cyan]{query}[/bold cyan]'")
|
|
if collection:
|
|
console.print(f"📚 Target collection: [yellow]{collection}[/yellow]")
|
|
console.print(f"💾 Backend: [blue]{backend}[/blue]")
|
|
|
|
results = []
|
|
|
|
try:
|
|
if backend == "weaviate":
|
|
weaviate_config = StorageConfig(
|
|
backend=StorageBackend.WEAVIATE,
|
|
endpoint=settings.weaviate_endpoint,
|
|
api_key=SecretStr(settings.weaviate_api_key)
|
|
if settings.weaviate_api_key is not None
|
|
else None,
|
|
collection_name=collection or "default",
|
|
)
|
|
weaviate = WeaviateStorage(weaviate_config)
|
|
await weaviate.initialize()
|
|
|
|
results_generator = weaviate.search(query, limit=limit)
|
|
async for doc in results_generator:
|
|
results.append(
|
|
{
|
|
"title": getattr(doc, "title", "Untitled"),
|
|
"content": getattr(doc, "content", ""),
|
|
"score": getattr(doc, "score", 0.0),
|
|
"backend": "🗄️ Weaviate",
|
|
}
|
|
)
|
|
|
|
elif backend == "open_webui":
|
|
console.print("❌ [red]OpenWebUI search not yet implemented[/red]")
|
|
return
|
|
|
|
except Exception as e:
|
|
console.print(f"❌ [red]Search failed: {e}[/red]")
|
|
return
|
|
|
|
# Display results
|
|
if results:
|
|
from rich.table import Table
|
|
|
|
table = Table(
|
|
title=f"🔍 Search Results for '{query}'",
|
|
title_style="bold magenta",
|
|
border_style="green",
|
|
header_style="bold blue",
|
|
)
|
|
table.add_column("📄 Title", style="cyan", max_width=40)
|
|
table.add_column("📝 Preview", style="white", max_width=60)
|
|
table.add_column("📊 Score", style="yellow", justify="right")
|
|
|
|
for result in results[:limit]:
|
|
title = str(result["title"])
|
|
title_display = title[:40] + "..." if len(title) > 40 else title
|
|
|
|
content = str(result["content"])
|
|
content_display = content[:60] + "..." if len(content) > 60 else content
|
|
|
|
score = f"{result['score']:.3f}"
|
|
|
|
table.add_row(title_display, content_display, score)
|
|
|
|
console.print(table)
|
|
console.print(f"\n✅ [green]Found {len(results)} results[/green]")
|
|
else:
|
|
console.print("❌ [yellow]No results found[/yellow]")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
app()
|
|
</file>
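
A minimal sketch of the collection auto-naming performed by run_ingestion() above, pulled out as a standalone helper for reference. The helper name and the example URL are illustrative only; the real CLI derives the name inline before calling create_ingestion_flow.

# Illustrative sketch only: mirrors the auto-naming logic in run_ingestion().
from urllib.parse import urlparse


def derive_collection_name(url: str, source_type: str) -> str:
    """Return the default collection name used when none is supplied."""
    parsed = urlparse(url)
    domain = parsed.netloc.replace(".", "_").replace("-", "_")
    return f"{domain}_{source_type}"


# derive_collection_name("https://docs.example.com/guide", "web")
# -> "docs_example_com_web"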
|
|
|
|
<file path="ingest_pipeline/config/__init__.py">
|
|
"""Configuration management utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import ExitStack
|
|
|
|
from prefect.settings import Setting, temporary_settings
|
|
|
|
|
|
# Import Prefect settings with version compatibility - avoid static analysis issues
|
|
def _setup_prefect_settings() -> tuple[object, object, object]:
|
|
"""Setup Prefect settings with proper fallbacks."""
|
|
try:
|
|
import prefect.settings as ps
|
|
|
|
# Try to get the settings directly
|
|
api_key = getattr(ps, "PREFECT_API_KEY", None)
|
|
api_url = getattr(ps, "PREFECT_API_URL", None)
|
|
work_pool = getattr(ps, "PREFECT_DEFAULT_WORK_POOL_NAME", None)
|
|
|
|
if api_key is not None:
|
|
return api_key, api_url, work_pool
|
|
|
|
# Fallback to registry-based approach
|
|
registry = getattr(ps, "PREFECT_SETTING_REGISTRY", None)
|
|
if registry is not None:
|
|
Setting = getattr(ps, "Setting", None)
|
|
if Setting is not None:
|
|
api_key = registry.get("PREFECT_API_KEY") or Setting(
|
|
"PREFECT_API_KEY", type_=str, default=None
|
|
)
|
|
api_url = registry.get("PREFECT_API_URL") or Setting(
|
|
"PREFECT_API_URL", type_=str, default=None
|
|
)
|
|
work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting(
|
|
"PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None
|
|
)
|
|
return api_key, api_url, work_pool
|
|
|
|
except ImportError:
|
|
pass
|
|
|
|
# Ultimate fallback
|
|
return None, None, None
|
|
|
|
|
|
PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()
|
|
|
|
# Import after Prefect settings setup to avoid circular dependencies
|
|
from .settings import Settings, get_settings # noqa: E402
|
|
|
|
__all__ = ["Settings", "get_settings", "configure_prefect"]
|
|
|
|
_prefect_settings_stack: ExitStack | None = None
|
|
|
|
|
|
def configure_prefect(settings: Settings) -> None:
|
|
"""Apply Prefect settings from the application configuration."""
|
|
global _prefect_settings_stack
|
|
|
|
overrides: dict[Setting, str] = {}
|
|
|
|
if (
|
|
settings.prefect_api_url is not None
|
|
and PREFECT_API_URL is not None
|
|
and isinstance(PREFECT_API_URL, Setting)
|
|
):
|
|
overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
|
|
if (
|
|
settings.prefect_api_key
|
|
and PREFECT_API_KEY is not None
|
|
and isinstance(PREFECT_API_KEY, Setting)
|
|
):
|
|
overrides[PREFECT_API_KEY] = settings.prefect_api_key
|
|
if (
|
|
settings.prefect_work_pool
|
|
and PREFECT_DEFAULT_WORK_POOL_NAME is not None
|
|
and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting)
|
|
):
|
|
overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool
|
|
|
|
if not overrides:
|
|
return
|
|
|
|
filtered_overrides = {
|
|
setting: value for setting, value in overrides.items() if setting.value() != value
|
|
}
|
|
|
|
if not filtered_overrides:
|
|
return
|
|
|
|
new_stack = ExitStack()
|
|
new_stack.enter_context(temporary_settings(updates=filtered_overrides))
|
|
|
|
if _prefect_settings_stack is not None:
|
|
_prefect_settings_stack.close()
|
|
|
|
_prefect_settings_stack = new_stack
|
|
</file>
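
The configure_prefect() helper above only applies overrides whose values differ from the currently active Prefect settings, and it swaps out the previous temporary-settings context when it does apply something. A minimal usage sketch, assuming the Prefect values are supplied through the normal Settings sources (environment variables or a .env file):

# Illustrative sketch: push application settings into Prefect at startup.
from ingest_pipeline.config import configure_prefect, get_settings

settings = get_settings()
configure_prefect(settings)  # applies only the values that differ from the active Prefect settings
configure_prefect(settings)  # an immediate identical call is a no-op (nothing differs any more)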
|
|
|
|
<file path="ingest_pipeline/ingestors/base.py">
|
|
"""Base ingestor interface."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncGenerator
|
|
from typing import TYPE_CHECKING
|
|
|
|
from ..core.models import Document, IngestionJob
|
|
|
|
if TYPE_CHECKING:
|
|
from ..storage.base import BaseStorage
|
|
|
|
|
|
class BaseIngestor(ABC):
|
|
"""Abstract base class for all ingestors."""
|
|
|
|
@abstractmethod
|
|
def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]:
|
|
"""
|
|
Ingest data from a source.
|
|
|
|
Args:
|
|
job: The ingestion job configuration
|
|
|
|
Yields:
|
|
Documents from the source
|
|
"""
|
|
... # pragma: no cover
|
|
|
|
@abstractmethod
|
|
async def validate_source(self, source_url: str) -> bool:
|
|
"""
|
|
Validate if the source is accessible.
|
|
|
|
Args:
|
|
source_url: URL or path to the source
|
|
|
|
Returns:
|
|
True if source is valid and accessible
|
|
"""
|
|
pass # pragma: no cover
|
|
|
|
@abstractmethod
|
|
async def estimate_size(self, source_url: str) -> int:
|
|
"""
|
|
Estimate the number of documents in the source.
|
|
|
|
Args:
|
|
source_url: URL or path to the source
|
|
|
|
Returns:
|
|
Estimated number of documents
|
|
"""
|
|
pass # pragma: no cover
|
|
|
|
async def ingest_with_dedup(
|
|
self,
|
|
job: IngestionJob,
|
|
storage_client: BaseStorage,
|
|
*,
|
|
collection_name: str | None = None,
|
|
stale_after_days: int = 30,
|
|
) -> AsyncGenerator[Document, None]:
|
|
"""
|
|
Ingest documents with duplicate detection (optional optimization).
|
|
|
|
Default implementation falls back to regular ingestion.
|
|
Subclasses can override to provide optimized deduplication.
|
|
|
|
Args:
|
|
job: The ingestion job configuration
|
|
storage_client: Storage client to check for existing documents
|
|
collection_name: Collection to check for duplicates
|
|
stale_after_days: Consider documents stale after this many days
|
|
|
|
Yields:
|
|
Documents from the source (with deduplication if implemented)
|
|
"""
|
|
# Default implementation: fall back to regular ingestion
|
|
async for document in self.ingest(job):
|
|
yield document
|
|
</file>
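
BaseIngestor above leaves ingest() as an async-generator hook and treats ingest_with_dedup() as an optional optimization. A minimal concrete subclass is sketched below; the Document fields used here (content, metadata) and the job.source_url attribute are assumptions made for illustration, so check core/models.py for the real shapes before reusing this.

# Hypothetical single-document ingestor, for illustration only.
from collections.abc import AsyncGenerator

from ingest_pipeline.core.models import Document, IngestionJob
from ingest_pipeline.ingestors.base import BaseIngestor


class SingleDocumentIngestor(BaseIngestor):
    async def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]:
        # Yield exactly one document; the field names here are assumed, not verified.
        yield Document(content="example body", metadata={"source": str(job.source_url)})

    async def validate_source(self, source_url: str) -> bool:
        # Accept plain HTTP(S) URLs without performing a network round-trip.
        return source_url.startswith(("http://", "https://"))

    async def estimate_size(self, source_url: str) -> int:
        # This toy ingestor always produces a single document.
        return 1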
|
|
|
|
<file path="ingest_pipeline/cli/tui/screens/documents.py">
|
|
"""Document management screen with enhanced navigation."""
|
|
|
|
from datetime import datetime
|
|
|
|
from textual.app import ComposeResult
|
|
from textual.binding import Binding
|
|
from textual.containers import Container, Horizontal, ScrollableContainer
|
|
from textual.screen import ModalScreen, Screen
|
|
from textual.widgets import Button, Footer, Header, Label, LoadingIndicator, Markdown, Static
|
|
from typing_extensions import override
|
|
|
|
from ....storage.base import BaseStorage
|
|
from ..models import CollectionInfo, DocumentInfo
|
|
from ..widgets import EnhancedDataTable
|
|
|
|
|
|
class DocumentManagementScreen(Screen[None]):
|
|
"""Screen for managing documents within a collection with enhanced keyboard navigation."""
|
|
|
|
collection: CollectionInfo
|
|
storage: BaseStorage | None
|
|
documents: list[DocumentInfo]
|
|
selected_docs: set[str]
|
|
current_offset: int
|
|
page_size: int
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "app.pop_screen", "Back"),
|
|
Binding("r", "refresh", "Refresh"),
|
|
Binding("v", "view_document", "View"),
|
|
Binding("delete", "delete_selected", "Delete Selected"),
|
|
Binding("a", "select_all", "Select All"),
|
|
Binding("ctrl+a", "select_all", "Select All"),
|
|
Binding("n", "select_none", "Clear Selection"),
|
|
Binding("ctrl+shift+a", "select_none", "Clear Selection"),
|
|
Binding("space", "toggle_selection", "Toggle Selection"),
|
|
Binding("ctrl+d", "delete_selected", "Delete Selected"),
|
|
Binding("pageup", "prev_page", "Previous Page"),
|
|
Binding("pagedown", "next_page", "Next Page"),
|
|
Binding("home", "first_page", "First Page"),
|
|
Binding("end", "last_page", "Last Page"),
|
|
]
|
|
|
|
def __init__(self, collection: CollectionInfo, storage: BaseStorage | None):
|
|
super().__init__()
|
|
self.collection = collection
|
|
self.storage = storage
|
|
self.documents: list[DocumentInfo] = []
|
|
self.selected_docs: set[str] = set()
|
|
self.current_offset = 0
|
|
self.page_size = 50
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
yield Header()
|
|
yield Container(
|
|
Static(f"📄 Document Management: {self.collection['name']}", classes="title"),
|
|
Static(
|
|
f"Total Documents: {self.collection['count']:,} | Use Space to select, Delete to remove",
|
|
classes="subtitle",
|
|
),
|
|
Label(f"Page size: {self.page_size} documents"),
|
|
EnhancedDataTable(id="documents_table", classes="enhanced-table"),
|
|
Horizontal(
|
|
Button("🔄 Refresh", id="refresh_docs_btn", variant="primary"),
|
|
Button("🗑️ Delete Selected", id="delete_selected_btn", variant="error"),
|
|
Button("✅ Select All", id="select_all_btn", variant="default"),
|
|
Button("❌ Clear Selection", id="clear_selection_btn", variant="default"),
|
|
Button("⬅️ Previous Page", id="prev_page_btn", variant="default"),
|
|
Button("➡️ Next Page", id="next_page_btn", variant="default"),
|
|
classes="button_bar",
|
|
),
|
|
Label("", id="selection_status"),
|
|
Static("", id="page_info", classes="status-text"),
|
|
LoadingIndicator(id="loading"),
|
|
classes="main_container",
|
|
)
|
|
yield Footer()
|
|
|
|
async def on_mount(self) -> None:
|
|
"""Initialize the screen."""
|
|
self.query_one("#loading").display = False
|
|
|
|
# Setup documents table with enhanced columns
|
|
table = self.query_one("#documents_table", EnhancedDataTable)
|
|
table.add_columns(
|
|
"✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID"
|
|
)
|
|
|
|
# Set up message handling for table events
|
|
table.can_focus = True
|
|
|
|
await self.load_documents()
|
|
|
|
async def load_documents(self) -> None:
|
|
"""Load documents from the collection."""
|
|
loading = self.query_one("#loading")
|
|
loading.display = True
|
|
|
|
try:
|
|
if self.storage:
|
|
# Try to load documents using the storage backend
|
|
try:
|
|
raw_docs = await self.storage.list_documents(
|
|
limit=self.page_size,
|
|
offset=self.current_offset,
|
|
collection_name=self.collection["name"],
|
|
)
|
|
# Cast to proper type with type checking
|
|
self.documents = [
|
|
DocumentInfo(
|
|
id=str(doc.get("id", f"doc_{i}")),
|
|
title=str(doc.get("title", "Untitled Document")),
|
|
source_url=str(doc.get("source_url", "")),
|
|
description=str(doc.get("description", "")),
|
|
content_type=str(doc.get("content_type", "text/plain")),
|
|
content_preview=str(doc.get("content_preview", "")),
|
|
word_count=(
|
|
lambda wc_val: int(wc_val)
|
|
if isinstance(wc_val, (int, str)) and str(wc_val).isdigit()
|
|
else 0
|
|
)(doc.get("word_count", 0)),
|
|
timestamp=str(doc.get("timestamp", "")),
|
|
)
|
|
for i, doc in enumerate(raw_docs)
|
|
]
|
|
except NotImplementedError:
|
|
# For storage backends that don't support document listing, show a message
|
|
self.notify(
|
|
f"Document listing not supported for {self.storage.__class__.__name__}",
|
|
severity="information",
|
|
)
|
|
self.documents = []
|
|
|
|
await self.update_table()
|
|
self.update_selection_status()
|
|
self.update_page_info()
|
|
|
|
except Exception as e:
|
|
self.notify(f"Error loading documents: {e}", severity="error", markup=False)
|
|
finally:
|
|
loading.display = False
|
|
|
|
async def update_table(self) -> None:
|
|
"""Update the documents table with enhanced metadata display."""
|
|
table = self.query_one("#documents_table", EnhancedDataTable)
|
|
table.clear(columns=True)
|
|
|
|
# Add enhanced columns with more metadata
|
|
table.add_columns(
|
|
"✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID"
|
|
)
|
|
|
|
# Add rows with enhanced metadata
|
|
for doc in self.documents:
|
|
selected = "✓" if doc["id"] in self.selected_docs else ""
|
|
|
|
# Get additional metadata from the raw docs
|
|
description = str(doc.get("description") or "").strip()[:40]
|
|
if not description:
|
|
description = "[dim]No description[/dim]"
|
|
elif len(str(doc.get("description") or "")) > 40:
|
|
description += "..."
|
|
|
|
# Format content type with appropriate icon
|
|
content_type = doc.get("content_type", "text/plain")
|
|
if "markdown" in content_type.lower():
|
|
type_display = "📝 md"
|
|
elif "html" in content_type.lower():
|
|
type_display = "🌐 html"
|
|
elif "text" in content_type.lower():
|
|
type_display = "📄 txt"
|
|
else:
|
|
type_display = f"📄 {content_type.split('/')[-1][:5]}"
|
|
|
|
# Format timestamp to be more readable
|
|
timestamp = doc.get("timestamp", "")
|
|
if timestamp:
|
|
try:
|
|
# Parse ISO format timestamp
|
|
dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
|
timestamp = dt.strftime("%m/%d %H:%M")
|
|
except Exception:
|
|
timestamp = str(timestamp)[:16] # Fallback
|
|
table.add_row(
|
|
selected,
|
|
doc.get("title", "Untitled")[:40],
|
|
doc.get("source_url", "")[:35],
|
|
description,
|
|
type_display,
|
|
str(doc.get("word_count", 0)),
|
|
timestamp,
|
|
doc["id"][:8] + "...", # Show truncated ID
|
|
)
|
|
|
|
def update_selection_status(self) -> None:
|
|
"""Update the selection status label."""
|
|
status_label = self.query_one("#selection_status", Label)
|
|
total_selected = len(self.selected_docs)
|
|
status_label.update(f"Selected: {total_selected} documents")
|
|
|
|
def update_page_info(self) -> None:
|
|
"""Update the page information."""
|
|
page_info = self.query_one("#page_info", Static)
|
|
total_docs = self.collection["count"]
|
|
start = self.current_offset + 1
|
|
end = min(self.current_offset + len(self.documents), total_docs)
|
|
page_num = (self.current_offset // self.page_size) + 1
|
|
total_pages = (total_docs + self.page_size - 1) // self.page_size
|
|
|
|
page_info.update(
|
|
f"Showing {start:,}-{end:,} of {total_docs:,} documents (Page {page_num} of {total_pages})"
|
|
)
|
|
|
|
def get_current_document(self) -> DocumentInfo | None:
|
|
"""Get the currently selected document."""
|
|
table = self.query_one("#documents_table", EnhancedDataTable)
|
|
try:
|
|
if 0 <= table.cursor_coordinate.row < len(self.documents):
|
|
return self.documents[table.cursor_coordinate.row]
|
|
except (AttributeError, IndexError):
|
|
pass
|
|
return None
|
|
|
|
# Action methods
|
|
def action_refresh(self) -> None:
|
|
"""Refresh the document list."""
|
|
self.run_worker(self.load_documents())
|
|
|
|
def action_toggle_selection(self) -> None:
|
|
"""Toggle selection of current row."""
|
|
if doc := self.get_current_document():
|
|
doc_id = doc["id"]
|
|
if doc_id in self.selected_docs:
|
|
self.selected_docs.remove(doc_id)
|
|
else:
|
|
self.selected_docs.add(doc_id)
|
|
|
|
self.run_worker(self.update_table())
|
|
self.update_selection_status()
|
|
|
|
def action_select_all(self) -> None:
|
|
"""Select all documents on current page."""
|
|
for doc in self.documents:
|
|
self.selected_docs.add(doc["id"])
|
|
self.run_worker(self.update_table())
|
|
self.update_selection_status()
|
|
|
|
def action_select_none(self) -> None:
|
|
"""Clear all selections."""
|
|
self.selected_docs.clear()
|
|
self.run_worker(self.update_table())
|
|
self.update_selection_status()
|
|
|
|
def action_delete_selected(self) -> None:
|
|
"""Delete selected documents."""
|
|
if self.selected_docs:
|
|
from .dialogs import ConfirmDocumentDeleteScreen
|
|
|
|
self.app.push_screen(
|
|
ConfirmDocumentDeleteScreen(list(self.selected_docs), self.collection, self)
|
|
)
|
|
else:
|
|
self.notify("No documents selected", severity="warning")
|
|
|
|
def action_next_page(self) -> None:
|
|
"""Go to next page."""
|
|
if self.current_offset + self.page_size < self.collection["count"]:
|
|
self.current_offset += self.page_size
|
|
self.run_worker(self.load_documents())
|
|
|
|
def action_prev_page(self) -> None:
|
|
"""Go to previous page."""
|
|
if self.current_offset >= self.page_size:
|
|
self.current_offset -= self.page_size
|
|
self.run_worker(self.load_documents())
|
|
|
|
def action_first_page(self) -> None:
|
|
"""Go to first page."""
|
|
if self.current_offset > 0:
|
|
self.current_offset = 0
|
|
self.run_worker(self.load_documents())
|
|
|
|
def action_last_page(self) -> None:
|
|
"""Go to last page."""
|
|
total_docs = self.collection["count"]
|
|
last_offset = ((total_docs - 1) // self.page_size) * self.page_size
|
|
if self.current_offset != last_offset:
|
|
self.current_offset = last_offset
|
|
self.run_worker(self.load_documents())
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id == "refresh_docs_btn":
|
|
self.action_refresh()
|
|
elif event.button.id == "delete_selected_btn":
|
|
self.action_delete_selected()
|
|
elif event.button.id == "select_all_btn":
|
|
self.action_select_all()
|
|
elif event.button.id == "clear_selection_btn":
|
|
self.action_select_none()
|
|
elif event.button.id == "next_page_btn":
|
|
self.action_next_page()
|
|
elif event.button.id == "prev_page_btn":
|
|
self.action_prev_page()
|
|
|
|
def on_enhanced_data_table_row_toggled(self, event: EnhancedDataTable.RowToggled) -> None:
|
|
"""Handle row toggle from enhanced table."""
|
|
if 0 <= event.row_index < len(self.documents):
|
|
doc = self.documents[event.row_index]
|
|
doc_id = doc["id"]
|
|
|
|
if doc_id in self.selected_docs:
|
|
self.selected_docs.remove(doc_id)
|
|
else:
|
|
self.selected_docs.add(doc_id)
|
|
|
|
self.run_worker(self.update_table())
|
|
self.update_selection_status()
|
|
|
|
def on_enhanced_data_table_select_all(self, event: EnhancedDataTable.SelectAll) -> None:
|
|
"""Handle select all from enhanced table."""
|
|
self.action_select_all()
|
|
|
|
def on_enhanced_data_table_clear_selection(
|
|
self, event: EnhancedDataTable.ClearSelection
|
|
) -> None:
|
|
"""Handle clear selection from enhanced table."""
|
|
self.action_select_none()
|
|
|
|
def action_view_document(self) -> None:
|
|
"""View the content of the currently selected document."""
|
|
if doc := self.get_current_document():
|
|
if self.storage:
|
|
self.app.push_screen(
|
|
DocumentContentModal(doc, self.storage, self.collection["name"])
|
|
)
|
|
else:
|
|
self.notify("No storage backend available", severity="error")
|
|
else:
|
|
self.notify("No document selected", severity="warning")
|
|
|
|
|
|
class DocumentContentModal(ModalScreen[None]):
|
|
"""Modal screen for viewing document content."""
|
|
|
|
DEFAULT_CSS = """
|
|
DocumentContentModal {
|
|
align: center middle;
|
|
}
|
|
|
|
DocumentContentModal > Container {
|
|
width: 90%;
|
|
height: 85%;
|
|
background: $surface;
|
|
border: thick $primary;
|
|
}
|
|
|
|
DocumentContentModal .modal-header {
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
dock: top;
|
|
height: 3;
|
|
}
|
|
|
|
DocumentContentModal .modal-content {
|
|
padding: 1;
|
|
height: 1fr;
|
|
}
|
|
"""
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "app.pop_screen", "Close"),
|
|
Binding("q", "app.pop_screen", "Close"),
|
|
]
|
|
|
|
def __init__(self, document: DocumentInfo, storage: BaseStorage, collection_name: str):
|
|
super().__init__()
|
|
self.document = document
|
|
self.storage = storage
|
|
self.collection_name = collection_name
|
|
|
|
def compose(self) -> ComposeResult:
|
|
yield Container(
|
|
Static(
|
|
f"📄 Document: {self.document['title'][:60]}{'...' if len(self.document['title']) > 60 else ''}",
|
|
classes="modal-header",
|
|
),
|
|
ScrollableContainer(
|
|
Markdown("Loading document content...", id="document_content"),
|
|
LoadingIndicator(id="content_loading"),
|
|
classes="modal-content",
|
|
),
|
|
)
|
|
|
|
async def on_mount(self) -> None:
|
|
"""Load and display the document content."""
|
|
content_widget = self.query_one("#document_content", Markdown)
|
|
loading = self.query_one("#content_loading")
|
|
|
|
try:
|
|
# Get full document content
|
|
doc_content = await self.storage.retrieve(
|
|
self.document["id"], collection_name=self.collection_name
|
|
)
|
|
|
|
# Format content for display
|
|
if isinstance(doc_content, str):
|
|
formatted_content = f"""# {self.document["title"]}
|
|
|
|
**Source:** {self.document.get("source_url", "N/A")}
|
|
**Type:** {self.document.get("content_type", "text/plain")}
|
|
**Words:** {self.document.get("word_count", 0):,}
|
|
**Timestamp:** {self.document.get("timestamp", "N/A")}
|
|
|
|
---
|
|
|
|
{doc_content}
|
|
"""
|
|
else:
|
|
formatted_content = f"""# {self.document["title"]}
|
|
|
|
**Source:** {self.document.get("source_url", "N/A")}
|
|
**Type:** {self.document.get("content_type", "text/plain")}
|
|
**Words:** {self.document.get("word_count", 0):,}
|
|
**Timestamp:** {self.document.get("timestamp", "N/A")}
|
|
|
|
---
|
|
|
|
*Content format not supported for display*
|
|
"""
|
|
|
|
content_widget.update(formatted_content)
|
|
|
|
except Exception as e:
|
|
content_widget.update(
|
|
f"# Error Loading Document\n\nFailed to load document content: {e}"
|
|
)
|
|
finally:
|
|
loading.display = False
|
|
</file>
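
The paging actions in DocumentManagementScreen above all reduce to offset arithmetic against page_size (50 by default) and the collection's total document count. The same calculations, restated as standalone helpers for quick reference (the helper names are illustrative, not part of the package):

def page_bounds(offset: int, loaded: int, total: int, page_size: int) -> tuple[int, int, int, int]:
    """Return (start, end, page_number, total_pages) as shown in the page-info line."""
    start = offset + 1
    end = min(offset + loaded, total)
    page_number = (offset // page_size) + 1
    total_pages = (total + page_size - 1) // page_size
    return start, end, page_number, total_pages


def last_page_offset(total: int, page_size: int) -> int:
    """Offset that action_last_page() jumps to."""
    return ((total - 1) // page_size) * page_size


# page_bounds(100, 50, 1234, 50)  -> (101, 150, 3, 25)
# last_page_offset(1234, 50)      -> 1200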
|
|
|
|
<file path="ingest_pipeline/cli/tui/utils/runners.py">
|
|
"""TUI runner functions and initialization."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from logging import Logger
|
|
from logging.handlers import QueueHandler, RotatingFileHandler
|
|
from pathlib import Path
|
|
from queue import Queue
|
|
from typing import NamedTuple
|
|
|
|
import platformdirs
|
|
|
|
from ....config import configure_prefect, get_settings
|
|
from .storage_manager import StorageManager
|
|
|
|
|
|
class _TuiLoggingContext(NamedTuple):
|
|
"""Container describing configured logging outputs for the TUI."""
|
|
|
|
queue: Queue[logging.LogRecord]
|
|
formatter: logging.Formatter
|
|
log_file: Path | None
|
|
|
|
|
|
_logging_context: _TuiLoggingContext | None = None
|
|
|
|
|
|
def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext:
|
|
"""Configure logging so that messages do not break the TUI output."""
|
|
|
|
global _logging_context
|
|
if _logging_context is not None:
|
|
return _logging_context
|
|
|
|
resolved_level = getattr(logging, log_level.upper(), logging.INFO)
|
|
log_queue: Queue[logging.LogRecord] = Queue()
|
|
formatter = logging.Formatter(
|
|
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
|
datefmt="%Y-%m-%d %H:%M:%S",
|
|
)
|
|
|
|
root_logger = logging.getLogger()
|
|
root_logger.setLevel(resolved_level)
|
|
|
|
# Remove existing stream handlers to prevent console flicker inside the TUI
|
|
for handler in list(root_logger.handlers):
|
|
root_logger.removeHandler(handler)
|
|
|
|
queue_handler = QueueHandler(log_queue)
|
|
queue_handler.setLevel(resolved_level)
|
|
root_logger.addHandler(queue_handler)
|
|
|
|
log_file: Path | None = None
|
|
try:
|
|
# Try current directory first for development
|
|
log_dir = Path.cwd() / "logs"
|
|
log_dir.mkdir(parents=True, exist_ok=True)
|
|
log_file = log_dir / "tui.log"
|
|
except OSError:
|
|
# Fall back to user log directory
|
|
try:
|
|
log_dir = Path(platformdirs.user_log_dir("ingest-pipeline", "ingest-pipeline"))
|
|
log_dir.mkdir(parents=True, exist_ok=True)
|
|
log_file = log_dir / "tui.log"
|
|
except OSError as exc:
|
|
fallback = logging.getLogger(__name__)
|
|
fallback.warning("Failed to create log directory, file logging disabled: %s", exc)
|
|
log_file = None
|
|
|
|
if log_file:
|
|
try:
|
|
file_handler = RotatingFileHandler(
|
|
log_file,
|
|
maxBytes=2_000_000,
|
|
backupCount=5,
|
|
encoding="utf-8",
|
|
)
|
|
file_handler.setLevel(resolved_level)
|
|
file_handler.setFormatter(formatter)
|
|
root_logger.addHandler(file_handler)
|
|
except OSError as exc:
|
|
fallback = logging.getLogger(__name__)
|
|
fallback.warning("Failed to configure file logging for TUI: %s", exc)
|
|
log_file = None
|
|
|
|
_logging_context = _TuiLoggingContext(log_queue, formatter, log_file)
|
|
return _logging_context
|
|
|
|
|
|
LOGGER: Logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def run_textual_tui() -> None:
|
|
"""Run the enhanced modern TUI with better error handling and initialization."""
|
|
settings = get_settings()
|
|
configure_prefect(settings)
|
|
|
|
logging_context = _configure_tui_logging(log_level=settings.log_level)
|
|
|
|
LOGGER.info("Initializing collection management TUI")
|
|
LOGGER.info("Scanning available storage backends")
|
|
|
|
# Create storage manager without initialization - let TUI handle it asynchronously
|
|
storage_manager = StorageManager(settings)
|
|
|
|
LOGGER.info("Launching TUI - storage backends will initialize in background")
|
|
|
|
# Import here to avoid circular import
|
|
from ..app import CollectionManagementApp
|
|
|
|
app = CollectionManagementApp(
|
|
storage_manager,
|
|
None, # weaviate - will be available after initialization
|
|
None, # openwebui - will be available after initialization
|
|
None, # r2r_backend - will be available after initialization
|
|
log_queue=logging_context.queue,
|
|
log_formatter=logging_context.formatter,
|
|
log_file=logging_context.log_file,
|
|
)
|
|
try:
|
|
await app.run_async()
|
|
finally:
|
|
LOGGER.info("Shutting down storage connections")
|
|
await storage_manager.close_all()
|
|
LOGGER.info("All storage connections closed gracefully")
|
|
|
|
|
|
def dashboard() -> None:
|
|
"""Launch the modern collection dashboard."""
|
|
asyncio.run(run_textual_tui())
|
|
</file>
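
_configure_tui_logging() above replaces the root logger's stream handlers with a QueueHandler so log output cannot corrupt the Textual display, then hands the queue, formatter, and log file to the app. The same queue-based pattern outside the TUI, sketched with the standard-library QueueListener as the consumer (the file name is illustrative, and the TUI app consumes the queue itself rather than using a listener):

import logging
from logging.handlers import QueueHandler, QueueListener
from queue import Queue

# Route all records through a queue instead of writing directly from the caller.
log_queue: Queue[logging.LogRecord] = Queue()
root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(QueueHandler(log_queue))

# A background listener drains the queue and writes to the real handler.
listener = QueueListener(log_queue, logging.FileHandler("tui.log", encoding="utf-8"))
listener.start()
try:
    logging.getLogger("demo").info("routed through the queue")
finally:
    listener.stop()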
|
|
|
|
<file path="ingest_pipeline/cli/tui/styles.py">
|
|
"""Comprehensive theming system for TUI applications with WCAG AA accessibility compliance."""
|
|
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Protocol
|
|
|
|
from textual.app import App
|
|
|
|
# Type alias for Textual apps with unknown return type
|
|
TextualApp = App[object]
|
|
|
|
|
|
class AppProtocol(Protocol):
|
|
"""Protocol for apps that support CSS and refresh."""
|
|
|
|
CSS: str
|
|
|
|
def refresh(self) -> None:
|
|
"""Refresh the app."""
|
|
...
|
|
|
|
|
|
class ThemeType(Enum):
|
|
"""Available theme types."""
|
|
|
|
DARK = "dark"
|
|
LIGHT = "light"
|
|
HIGH_CONTRAST = "high_contrast"
|
|
GITHUB_DARK = "github_dark"
|
|
|
|
|
|
@dataclass
|
|
class ColorPalette:
|
|
"""Color palette with WCAG AA compliant contrast ratios."""
|
|
|
|
# Background colors
|
|
bg_primary: str
|
|
bg_secondary: str
|
|
bg_tertiary: str
|
|
bg_elevated: str
|
|
|
|
# Text colors (all tested for WCAG AA compliance)
|
|
text_primary: str # 4.5:1+ contrast ratio
|
|
text_secondary: str # 4.5:1+ contrast ratio
|
|
text_tertiary: str # 4.5:1+ contrast ratio
|
|
text_inverse: str
|
|
|
|
# Semantic colors
|
|
primary: str
|
|
primary_hover: str
|
|
success: str
|
|
warning: str
|
|
error: str
|
|
info: str
|
|
|
|
# Interactive states
|
|
border_default: str
|
|
border_focus: str
|
|
border_hover: str
|
|
|
|
# Surface colors
|
|
surface_1: str
|
|
surface_2: str
|
|
surface_3: str
|
|
|
|
|
|
class ThemeRegistry:
|
|
"""Registry for managing application themes."""
|
|
|
|
@staticmethod
|
|
def get_enhanced_dark() -> ColorPalette:
|
|
"""Enhanced dark theme with superior contrast ratios."""
|
|
return ColorPalette(
|
|
# Backgrounds - darker for better contrast
|
|
bg_primary="#0a0c10",
|
|
bg_secondary="#151821",
|
|
bg_tertiary="#1f2329",
|
|
bg_elevated="#252932",
|
|
# Text - brighter for better visibility (WCAG AA compliant)
|
|
text_primary="#ffffff", # 21:1 contrast ratio
|
|
text_secondary="#e6edf3", # 14.8:1 contrast ratio
|
|
text_tertiary="#c9d1d9", # 9.6:1 contrast ratio
|
|
text_inverse="#0a0c10",
|
|
# Semantic colors - enhanced for visibility
|
|
primary="#1f6feb",
|
|
primary_hover="#388bfd",
|
|
success="#238636",
|
|
warning="#d29922",
|
|
error="#f85149",
|
|
info="#58a6ff",
|
|
# Interactive states
|
|
border_default="#444c56",
|
|
border_focus="#58a6ff",
|
|
border_hover="#58a6ff",
|
|
# Surface elevation
|
|
surface_1="#161b22",
|
|
surface_2="#21262d",
|
|
surface_3="#30363d",
|
|
)
|
|
|
|
@staticmethod
|
|
def get_light() -> ColorPalette:
|
|
"""Light theme with excellent readability."""
|
|
return ColorPalette(
|
|
# Backgrounds
|
|
bg_primary="#ffffff",
|
|
bg_secondary="#f6f8fa",
|
|
bg_tertiary="#f1f3f4",
|
|
bg_elevated="#ffffff",
|
|
# Text (WCAG AA compliant)
|
|
text_primary="#1f2328", # 12.6:1 contrast ratio
|
|
text_secondary="#424a53", # 7.1:1 contrast ratio
|
|
text_tertiary="#636c76", # 4.7:1 contrast ratio
|
|
text_inverse="#ffffff",
|
|
# Semantic colors
|
|
primary="#0969da",
|
|
primary_hover="#0860ca",
|
|
success="#1a7f37",
|
|
warning="#9a6700",
|
|
error="#d1242f",
|
|
info="#0969da",
|
|
# Interactive states
|
|
border_default="#d1d9e0",
|
|
border_focus="#fd7e14",
|
|
border_hover="#0969da",
|
|
# Surface elevation
|
|
surface_1="#f6f8fa",
|
|
surface_2="#eaeef2",
|
|
surface_3="#d1d9e0",
|
|
)
|
|
|
|
@staticmethod
|
|
def get_high_contrast() -> ColorPalette:
|
|
"""High contrast theme for maximum accessibility."""
|
|
return ColorPalette(
|
|
# Backgrounds
|
|
bg_primary="#000000",
|
|
bg_secondary="#1a1a1a",
|
|
bg_tertiary="#262626",
|
|
bg_elevated="#333333",
|
|
# Text (Maximum contrast)
|
|
text_primary="#ffffff", # 21:1 contrast ratio
|
|
text_secondary="#ffffff", # 21:1 contrast ratio
|
|
text_tertiary="#cccccc", # 11.8:1 contrast ratio
|
|
text_inverse="#000000",
|
|
# Semantic colors - high contrast variants
|
|
primary="#00aaff",
|
|
primary_hover="#66ccff",
|
|
success="#00ff00",
|
|
warning="#ffaa00",
|
|
error="#ff4444",
|
|
info="#00aaff",
|
|
# Interactive states
|
|
border_default="#666666",
|
|
border_focus="#ffff00",
|
|
border_hover="#ffffff",
|
|
# Surface elevation
|
|
surface_1="#1a1a1a",
|
|
surface_2="#333333",
|
|
surface_3="#4d4d4d",
|
|
)
|
|
|
|
@staticmethod
|
|
def get_github_dark() -> ColorPalette:
|
|
"""Enhanced GitHub dark theme with improved contrast."""
|
|
return ColorPalette(
|
|
# Backgrounds
|
|
bg_primary="#0d1117",
|
|
bg_secondary="#161b22",
|
|
bg_tertiary="#21262d",
|
|
bg_elevated="#2d333b",
|
|
# Text (Enhanced for better visibility)
|
|
text_primary="#f0f6fc", # 13.6:1 contrast ratio
|
|
text_secondary="#e6edf3", # 11.9:1 contrast ratio
|
|
text_tertiary="#c9d1d9", # 8.2:1 contrast ratio
|
|
text_inverse="#0d1117",
|
|
# Semantic colors
|
|
primary="#58a6ff",
|
|
primary_hover="#79c0ff",
|
|
success="#3fb950",
|
|
warning="#d29922",
|
|
error="#f85149",
|
|
info="#58a6ff",
|
|
# Interactive states
|
|
border_default="#30363d",
|
|
border_focus="#f78166",
|
|
border_hover="#58a6ff",
|
|
# Surface elevation
|
|
surface_1="#161b22",
|
|
surface_2="#21262d",
|
|
surface_3="#30363d",
|
|
)
|
|
|
|
|
|
class ThemeManager:
|
|
"""Manages theme selection and CSS generation."""
|
|
|
|
def __init__(self, default_theme: ThemeType = ThemeType.DARK):
|
|
self.current_theme: ThemeType = default_theme
|
|
self._themes: dict[ThemeType, ColorPalette] = {
|
|
ThemeType.DARK: ThemeRegistry.get_enhanced_dark(),
|
|
ThemeType.LIGHT: ThemeRegistry.get_light(),
|
|
ThemeType.HIGH_CONTRAST: ThemeRegistry.get_high_contrast(),
|
|
ThemeType.GITHUB_DARK: ThemeRegistry.get_github_dark(),
|
|
}
|
|
|
|
def set_theme(self, theme: ThemeType) -> None:
|
|
"""Switch to a different theme."""
|
|
self.current_theme = theme
|
|
|
|
def get_current_palette(self) -> ColorPalette:
|
|
"""Get the current theme's color palette."""
|
|
return self._themes[self.current_theme]
|
|
|
|
def generate_css(self) -> str:
|
|
"""Generate Textual CSS for the current theme."""
|
|
palette = self.get_current_palette()
|
|
|
|
return f"""
|
|
/* ===============================================
|
|
ENHANCED THEMING SYSTEM - {self.current_theme.value.upper()}
|
|
WCAG AA Compliant with Superior Text Visibility
|
|
=============================================== */
|
|
|
|
/* Base Application Styling */
|
|
Screen {{
|
|
background: {palette.bg_primary};
|
|
}}
|
|
|
|
* {{
|
|
color: {palette.text_primary};
|
|
}}
|
|
|
|
/* ===============================================
|
|
LAYOUT & CONTAINERS
|
|
=============================================== */
|
|
|
|
/* Enhanced title styling with superior contrast */
|
|
.title {{
|
|
text-align: center;
|
|
margin: 0;
|
|
color: {palette.text_primary};
|
|
text-style: bold;
|
|
background: {palette.bg_secondary};
|
|
padding: 0 1;
|
|
height: 3;
|
|
min-height: 3;
|
|
max-height: 3;
|
|
border: solid {palette.primary};
|
|
}}
|
|
|
|
.subtitle {{
|
|
text-align: center;
|
|
margin: 0;
|
|
color: {palette.text_secondary};
|
|
text-style: italic;
|
|
background: {palette.bg_secondary};
|
|
padding: 0 1;
|
|
height: 2;
|
|
min-height: 2;
|
|
max-height: 2;
|
|
}}
|
|
|
|
/* Main container with elevated surface */
|
|
.main_container {{
|
|
margin: 0;
|
|
padding: 1 0;
|
|
background: {palette.bg_secondary};
|
|
}}
|
|
|
|
/* Enhanced card components with better elevation */
|
|
.card {{
|
|
background: {palette.surface_2};
|
|
padding: 1;
|
|
margin: 0 1;
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_default};
|
|
height: auto;
|
|
min-height: 4;
|
|
}}
|
|
|
|
.card:focus-within {{
|
|
border: thick {palette.border_focus};
|
|
background: {palette.bg_elevated};
|
|
}}
|
|
|
|
/* ===============================================
|
|
INTERACTIVE ELEMENTS
|
|
=============================================== */
|
|
|
|
/* Base button with superior contrast */
|
|
Button {{
|
|
background: {palette.surface_2};
|
|
color: {palette.text_primary};
|
|
margin: 0 1;
|
|
border: solid {palette.border_default};
|
|
}}
|
|
|
|
Button:hover {{
|
|
background: {palette.surface_3};
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_hover};
|
|
}}
|
|
|
|
Button:focus {{
|
|
border: thick {palette.border_focus};
|
|
background: {palette.primary};
|
|
color: {palette.text_inverse};
|
|
}}
|
|
|
|
/* Semantic button variants with enhanced visibility */
|
|
Button.-primary {{
|
|
background: {palette.primary};
|
|
color: {palette.text_inverse};
|
|
border: solid {palette.primary};
|
|
}}
|
|
|
|
Button.-primary:hover {{
|
|
background: {palette.primary_hover};
|
|
border: solid {palette.primary_hover};
|
|
}}
|
|
|
|
Button.-primary:focus {{
|
|
border: thick {palette.border_focus};
|
|
background: {palette.primary_hover};
|
|
}}
|
|
|
|
Button.-success {{
|
|
background: {palette.success};
|
|
color: {palette.text_inverse};
|
|
border: solid {palette.success};
|
|
}}
|
|
|
|
Button.-success:hover {{
|
|
background: {palette.success};
|
|
opacity: 0.9;
|
|
}}
|
|
|
|
Button.-error {{
|
|
background: {palette.error};
|
|
color: {palette.text_inverse};
|
|
border: solid {palette.error};
|
|
}}
|
|
|
|
Button.-error:hover {{
|
|
background: {palette.error};
|
|
opacity: 0.9;
|
|
}}
|
|
|
|
Button.-warning {{
|
|
background: {palette.warning};
|
|
color: {palette.text_inverse};
|
|
border: solid {palette.warning};
|
|
}}
|
|
|
|
Button.-warning:hover {{
|
|
background: {palette.warning};
|
|
opacity: 0.9;
|
|
}}
|
|
|
|
/* ===============================================
|
|
DATA DISPLAY - ENHANCED READABILITY
|
|
=============================================== */
|
|
|
|
/* DataTable with superior contrast and accessibility */
|
|
DataTable {{
|
|
background: {palette.surface_2};
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_default};
|
|
}}
|
|
|
|
DataTable:focus {{
|
|
border: thick {palette.border_focus};
|
|
}}
|
|
|
|
DataTable > .datatable--header {{
|
|
background: {palette.bg_secondary};
|
|
color: {palette.primary};
|
|
text-style: bold;
|
|
}}
|
|
|
|
DataTable > .datatable--cursor {{
|
|
background: {palette.primary};
|
|
color: {palette.text_inverse};
|
|
}}
|
|
|
|
DataTable > .datatable--cursor-row {{
|
|
background: {palette.primary_hover};
|
|
color: {palette.text_inverse};
|
|
}}
|
|
|
|
DataTable > .datatable--row-odd {{
|
|
background: {palette.surface_2};
|
|
color: {palette.text_primary};
|
|
}}
|
|
|
|
DataTable > .datatable--row-even {{
|
|
background: {palette.bg_tertiary};
|
|
color: {palette.text_primary};
|
|
}}
|
|
|
|
/* ===============================================
|
|
FORM ELEMENTS - ACCESSIBLE INPUT DESIGN
|
|
=============================================== */
|
|
|
|
/* Enhanced input with superior visibility */
|
|
Input {{
|
|
background: {palette.surface_1};
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_default};
|
|
}}
|
|
|
|
Input:focus {{
|
|
border: thick {palette.border_focus};
|
|
background: {palette.bg_elevated};
|
|
color: {palette.text_primary};
|
|
}}
|
|
|
|
Input.-invalid {{
|
|
border: solid {palette.error};
|
|
background: {palette.surface_1};
|
|
}}
|
|
|
|
Input.-invalid:focus {{
|
|
border: thick {palette.error};
|
|
background: {palette.bg_elevated};
|
|
}}
|
|
|
|
/* ===============================================
|
|
NAVIGATION - ENHANCED CLARITY
|
|
=============================================== */
|
|
|
|
/* Header and Footer with improved contrast */
|
|
Header, Footer {{
|
|
background: {palette.bg_secondary};
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_default};
|
|
}}
|
|
|
|
/* Simple Tab styling to ensure text visibility */
|
|
Tab {{
|
|
color: {palette.text_primary};
|
|
background: {palette.surface_2};
|
|
}}
|
|
|
|
Tab:hover {{
|
|
color: {palette.text_primary};
|
|
background: {palette.surface_3};
|
|
}}
|
|
|
|
Tab:focus {{
|
|
color: {palette.text_primary};
|
|
background: {palette.bg_elevated};
|
|
}}
|
|
|
|
Tab.-active {{
|
|
background: {palette.primary};
|
|
color: {palette.text_inverse};
|
|
text-style: bold;
|
|
}}
|
|
|
|
/* ===============================================
|
|
TYPOGRAPHY - WCAG AA COMPLIANT
|
|
=============================================== */
|
|
|
|
/* Label hierarchy with enhanced readability */
|
|
Label {{
|
|
color: {palette.text_primary};
|
|
}}
|
|
|
|
.label-secondary {{
|
|
color: {palette.text_secondary};
|
|
}}
|
|
|
|
.label-muted {{
|
|
color: {palette.text_tertiary};
|
|
}}
|
|
|
|
/* ===============================================
|
|
STATUS INDICATORS - ENHANCED VISIBILITY
|
|
=============================================== */
|
|
|
|
/* Semantic status colors with superior contrast */
|
|
.status-active, .status-success {{
|
|
color: {palette.success};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.status-error, .status-failed {{
|
|
color: {palette.error};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.status-warning, .status-pending {{
|
|
color: {palette.warning};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.status-info {{
|
|
color: {palette.info};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.status-inactive, .status-disabled {{
|
|
color: {palette.text_tertiary};
|
|
}}
|
|
|
|
/* ===============================================
|
|
VISUAL EFFECTS - ACCESSIBLE ANIMATIONS
|
|
=============================================== */
|
|
|
|
/* Animation classes with accessibility considerations */
|
|
.pulse {{
|
|
text-style: blink;
|
|
}}
|
|
|
|
.glow {{
|
|
background: {palette.primary};
|
|
color: {palette.text_inverse};
|
|
}}
|
|
|
|
.shimmer {{
|
|
text-style: italic;
|
|
color: {palette.text_secondary};
|
|
}}
|
|
|
|
.highlight {{
|
|
background: {palette.border_focus};
|
|
color: {palette.text_inverse};
|
|
}}
|
|
|
|
/* ===============================================
|
|
METRICS - ENHANCED DASHBOARD VISIBILITY
|
|
=============================================== */
|
|
|
|
/* Enhanced metrics with superior readability */
|
|
.metrics-value {{
|
|
text-style: bold;
|
|
text-align: center;
|
|
color: {palette.primary};
|
|
height: 1;
|
|
margin: 0;
|
|
}}
|
|
|
|
.metrics-label {{
|
|
text-align: center;
|
|
color: {palette.text_primary};
|
|
text-style: bold;
|
|
height: 1;
|
|
margin: 0;
|
|
}}
|
|
|
|
.metrics-description {{
|
|
text-align: center;
|
|
color: {palette.text_secondary};
|
|
text-style: italic;
|
|
height: 1;
|
|
margin: 0;
|
|
}}
|
|
|
|
/* MetricsCard container optimization */
|
|
MetricsCard {{
|
|
background: {palette.surface_2};
|
|
border: solid {palette.border_default};
|
|
padding: 0 1;
|
|
margin: 0;
|
|
height: auto;
|
|
min-height: 3;
|
|
max-height: 5;
|
|
align: center middle;
|
|
}}
|
|
|
|
/* Section organization with enhanced hierarchy */
|
|
.section-title {{
|
|
text-style: bold;
|
|
color: {palette.primary};
|
|
margin: 0;
|
|
border-left: thick {palette.primary};
|
|
padding-left: 1;
|
|
height: auto;
|
|
min-height: 2;
|
|
max-height: 3;
|
|
}}
|
|
|
|
.section-subtitle {{
|
|
color: {palette.text_secondary};
|
|
text-style: italic;
|
|
margin: 0 0 1 0;
|
|
}}
|
|
|
|
/* ===============================================
|
|
LAYOUT SYSTEMS - IMPROVED READABILITY
|
|
=============================================== */
|
|
|
|
/* Enhanced text styling with better contrast */
|
|
.status-text {{
|
|
color: {palette.text_secondary};
|
|
text-align: center;
|
|
margin: 1 0;
|
|
text-style: italic;
|
|
}}
|
|
|
|
.help-text {{
|
|
color: {palette.text_tertiary};
|
|
text-style: italic;
|
|
}}
|
|
|
|
/* Button organization with enhanced backgrounds */
|
|
.button_bar {{
|
|
margin: 0;
|
|
background: {palette.bg_secondary};
|
|
padding: 1;
|
|
height: auto;
|
|
min-height: 5;
|
|
max-height: 6;
|
|
}}
|
|
|
|
.action_buttons {{
|
|
margin: 0;
|
|
text-align: center;
|
|
padding: 1;
|
|
height: auto;
|
|
background: {palette.surface_2};
|
|
border-top: solid {palette.border_default};
|
|
}}
|
|
|
|
/* Enhanced progress indicators */
|
|
.progress-label {{
|
|
color: {palette.text_primary};
|
|
margin: 1 0;
|
|
text-style: bold;
|
|
text-align: center;
|
|
}}
|
|
|
|
.progress-complete {{
|
|
color: {palette.success};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.progress-error {{
|
|
color: {palette.error};
|
|
text-style: bold;
|
|
}}
|
|
|
|
/* ===============================================
|
|
RESPONSIVE GRID SYSTEMS
|
|
=============================================== */
|
|
|
|
/* Enhanced grid layouts */
|
|
.responsive-grid {{
|
|
grid-size: 4;
|
|
grid-gutter: 1;
|
|
background: {palette.bg_primary};
|
|
margin: 0;
|
|
padding: 0;
|
|
height: auto;
|
|
}}
|
|
|
|
.metrics-grid {{
|
|
grid-size: 4;
|
|
grid-gutter: 1;
|
|
margin: 0;
|
|
padding: 0;
|
|
background: {palette.bg_primary};
|
|
align: center middle;
|
|
height: auto;
|
|
min-height: 5;
|
|
max-height: 7;
|
|
}}
|
|
|
|
.analytics-grid {{
|
|
grid-size: 2;
|
|
grid-gutter: 1;
|
|
background: {palette.bg_primary};
|
|
}}
|
|
|
|
/* ===============================================
|
|
MODAL & OVERLAY - ENHANCED ACCESSIBILITY
|
|
=============================================== */
|
|
|
|
/* Accessible modal design */
|
|
IngestionScreen {{
|
|
align: center middle;
|
|
}}
|
|
|
|
.modal-container {{
|
|
background: {palette.surface_2};
|
|
border: thick {palette.primary};
|
|
padding: 1;
|
|
width: 90%;
|
|
height: 80%;
|
|
max-width: 80;
|
|
min-width: 40;
|
|
overflow-y: auto;
|
|
layout: vertical;
|
|
}}
|
|
|
|
/* Backend selection responsive layout */
|
|
.backend-selection {{
|
|
layout: horizontal;
|
|
padding: 1;
|
|
height: auto;
|
|
align: center middle;
|
|
}}
|
|
|
|
.backend-actions {{
|
|
layout: horizontal;
|
|
padding: 1;
|
|
height: auto;
|
|
align: center middle;
|
|
}}
|
|
|
|
/* Responsive adjustments for horizontal layout */
|
|
.backend-actions Button {{
|
|
margin: 0 1;
|
|
width: auto;
|
|
min-width: 12;
|
|
text-overflow: ellipsis;
|
|
}}
|
|
|
|
/* Backend selection checkboxes horizontal layout */
|
|
.backend-selection Checkbox {{
|
|
margin: 0 2;
|
|
width: auto;
|
|
text-overflow: ellipsis;
|
|
}}
|
|
|
|
/* Input section responsive improvements */
|
|
.input-section {{
|
|
margin: 1 0;
|
|
padding: 1;
|
|
background: {palette.surface_2};
|
|
border: solid {palette.border_default};
|
|
height: auto;
|
|
width: 100%;
|
|
}}
|
|
|
|
.modal-header {{
|
|
background: {palette.bg_secondary};
|
|
color: {palette.primary};
|
|
text-style: bold;
|
|
padding: 1;
|
|
border-bottom: solid {palette.border_default};
|
|
}}
|
|
|
|
.modal-body {{
|
|
padding: 1;
|
|
color: {palette.text_primary};
|
|
}}
|
|
|
|
.modal-footer {{
|
|
background: {palette.bg_secondary};
|
|
padding: 1;
|
|
border-top: solid {palette.border_default};
|
|
}}
|
|
|
|
/* ===============================================
|
|
SPECIALIZED COMPONENTS - ENHANCED VISIBILITY
|
|
=============================================== */
|
|
|
|
/* Enhanced chart and analytics */
|
|
.chart-title {{
|
|
text-style: bold;
|
|
color: {palette.primary};
|
|
margin: 1 0;
|
|
}}
|
|
|
|
.chart-placeholder {{
|
|
color: {palette.text_tertiary};
|
|
text-style: italic;
|
|
text-align: center;
|
|
padding: 2;
|
|
background: {palette.bg_secondary};
|
|
border: dashed {palette.border_default};
|
|
}}
|
|
|
|
/* Enhanced table variants with superior contrast */
|
|
.enhanced-table {{
|
|
background: {palette.surface_2};
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_default};
|
|
}}
|
|
|
|
.enhanced-table:focus {{
|
|
border: thick {palette.border_focus};
|
|
}}
|
|
|
|
.enhanced-table-header {{
|
|
background: {palette.bg_secondary};
|
|
color: {palette.primary};
|
|
text-style: bold;
|
|
}}
|
|
|
|
/* Enhanced status and info bars */
|
|
.status-bar {{
|
|
background: {palette.bg_secondary};
|
|
color: {palette.text_secondary};
|
|
padding: 0 1;
|
|
border: solid {palette.border_default};
|
|
}}
|
|
|
|
.info-bar {{
|
|
background: {palette.info};
|
|
color: {palette.text_inverse};
|
|
padding: 0 1;
|
|
}}
|
|
|
|
/* ===============================================
|
|
FORM SECTIONS - ACCESSIBLE INPUT DESIGN
|
|
=============================================== */
|
|
|
|
/* Enhanced input organization */
|
|
.input-section {{
|
|
margin: 0;
|
|
padding: 1;
|
|
background: {palette.surface_2};
|
|
border: solid {palette.border_default};
|
|
height: auto;
|
|
}}
|
|
|
|
.input-label {{
|
|
color: {palette.text_primary};
|
|
margin: 0 0 1 0;
|
|
text-style: bold;
|
|
}}
|
|
|
|
.input-help {{
|
|
color: {palette.text_secondary};
|
|
text-style: italic;
|
|
margin: 0 0 1 0;
|
|
}}
|
|
|
|
.modern-input {{
|
|
background: {palette.surface_1};
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_default};
|
|
margin: 1 0;
|
|
}}
|
|
|
|
.modern-input:focus {{
|
|
border: thick {palette.border_focus};
|
|
background: {palette.bg_elevated};
|
|
}}
|
|
|
|
/* Enhanced type selection buttons */
|
|
.type_buttons {{
|
|
margin: 0;
|
|
height: auto;
|
|
}}
|
|
|
|
.type-button {{
|
|
margin: 0 1;
|
|
background: {palette.surface_2};
|
|
color: {palette.text_primary};
|
|
border: solid {palette.border_default};
|
|
}}
|
|
|
|
.type-button:hover {{
|
|
background: {palette.surface_3};
|
|
border: solid {palette.border_hover};
|
|
}}
|
|
|
|
.type-button.-selected {{
|
|
background: {palette.primary};
|
|
color: {palette.text_inverse};
|
|
border: solid {palette.primary};
|
|
}}
|
|
|
|
/* ===============================================
|
|
UTILITY CLASSES - ENHANCED CONSISTENCY
|
|
=============================================== */
|
|
|
|
/* Enhanced progress sections */
|
|
.progress-section {{
|
|
margin: 1 0;
|
|
padding: 1;
|
|
background: {palette.surface_2};
|
|
border: solid {palette.border_default};
|
|
height: auto;
|
|
}}
|
|
|
|
/* Alignment utilities */
|
|
.center {{
|
|
text-align: center;
|
|
margin: 0;
|
|
padding: 0;
|
|
}}
|
|
|
|
.text-left {{
|
|
text-align: left;
|
|
}}
|
|
|
|
.text-right {{
|
|
text-align: right;
|
|
}}
|
|
|
|
/* Dashboard container spacing optimization */
|
|
Container.center {{
|
|
margin: 0;
|
|
padding: 0;
|
|
height: auto;
|
|
min-height: 0;
|
|
}}
|
|
|
|
/* Grid spacing optimization */
|
|
Grid {{
|
|
margin: 0;
|
|
padding: 0;
|
|
height: auto;
|
|
}}
|
|
|
|
/* Rule spacing optimization */
|
|
Rule {{
|
|
margin: 0;
|
|
padding: 0;
|
|
height: 1;
|
|
min-height: 1;
|
|
max-height: 1;
|
|
}}
|
|
|
|
/* Specific spacing elimination for dashboard */
|
|
.main_container Rule {{
|
|
margin: 0;
|
|
height: 0;
|
|
display: none;
|
|
}}
|
|
|
|
.main_container Container {{
|
|
margin: 0;
|
|
padding: 0;
|
|
}}
|
|
|
|
/* Enhanced state utilities */
|
|
.warning {{
|
|
color: {palette.warning};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.error {{
|
|
color: {palette.error};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.success {{
|
|
color: {palette.success};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.info {{
|
|
color: {palette.info};
|
|
text-style: bold;
|
|
}}
|
|
|
|
/* Enhanced interactive state utilities */
|
|
.pressed {{
|
|
background: {palette.primary_hover};
|
|
color: {palette.text_inverse};
|
|
}}
|
|
|
|
.selected {{
|
|
background: {palette.primary};
|
|
color: {palette.text_inverse};
|
|
border: solid {palette.primary};
|
|
}}
|
|
|
|
.disabled {{
|
|
color: {palette.text_tertiary};
|
|
background: {palette.bg_secondary};
|
|
}}
|
|
|
|
/* ===============================================
|
|
ACCESSIBILITY - WCAG AA COMPLIANCE
|
|
=============================================== */
|
|
|
|
/* Enhanced global focus indicator system */
|
|
*:focus {{
|
|
outline: solid {palette.border_focus};
|
|
}}
|
|
|
|
/* Improved high contrast mode support */
|
|
.high-contrast {{
|
|
color: #ffffff;
|
|
background: #000000;
|
|
}}
|
|
|
|
.high-contrast Button {{
|
|
border: thick #ffffff;
|
|
color: #ffffff;
|
|
background: #000000;
|
|
}}
|
|
|
|
.high-contrast Button:focus {{
|
|
background: #ffffff;
|
|
color: #000000;
|
|
border: thick #000000;
|
|
}}
|
|
|
|
/* Enhanced reduced motion support */
|
|
.reduced-motion .pulse {{
|
|
text-style: none;
|
|
}}
|
|
|
|
.reduced-motion .shimmer {{
|
|
text-style: none;
|
|
}}
|
|
|
|
|
|
|
|
/* ===============================================
|
|
COMPONENT ENHANCEMENTS - IMPROVED VISIBILITY
|
|
=============================================== */
|
|
|
|
/* Enhanced loading states */
|
|
.loading {{
|
|
color: {palette.text_secondary};
|
|
text-style: italic;
|
|
}}
|
|
|
|
.loading-dots {{
|
|
color: {palette.primary};
|
|
text-style: blink;
|
|
}}
|
|
|
|
/* Enhanced empty states */
|
|
.empty-state {{
|
|
color: {palette.text_tertiary};
|
|
text-style: italic;
|
|
text-align: center;
|
|
padding: 4;
|
|
}}
|
|
|
|
.empty-state-icon {{
|
|
color: {palette.border_default};
|
|
text-align: center;
|
|
}}
|
|
|
|
/* Enhanced search and filter components */
|
|
.search-highlight {{
|
|
background: {palette.warning};
|
|
color: {palette.text_inverse};
|
|
text-style: bold;
|
|
}}
|
|
|
|
.filter-active {{
|
|
color: {palette.primary};
|
|
text-style: bold;
|
|
}}
|
|
|
|
/* Enhanced breadcrumb navigation */
|
|
.breadcrumb {{
|
|
color: {palette.text_secondary};
|
|
}}
|
|
|
|
.breadcrumb-separator {{
|
|
color: {palette.text_tertiary};
|
|
}}
|
|
|
|
.breadcrumb-current {{
|
|
color: {palette.text_primary};
|
|
text-style: bold;
|
|
}}
|
|
|
|
/* ===============================================
|
|
THEME-SPECIFIC CUSTOMIZATIONS
|
|
=============================================== */
|
|
|
|
/* Additional theme-specific styling can be added here */
|
|
.theme-indicator {{
|
|
color: {palette.primary};
|
|
text-style: italic;
|
|
}}
|
|
|
|
.accessibility-notice {{
|
|
color: {palette.text_primary};
|
|
background: {palette.bg_elevated};
|
|
padding: 1;
|
|
border: solid {palette.border_default};
|
|
}}
|
|
"""


# Initialize the theme manager with enhanced dark theme as default
theme_manager = ThemeManager(ThemeType.DARK)

# Generate CSS for the current theme
TUI_CSS = theme_manager.generate_css()  # pyright: ignore[reportConstantRedefinition]


# Convenience functions for easy theme switching
def set_theme(theme_type: ThemeType) -> str:
    """Switch to a different theme and return the new CSS."""
    global TUI_CSS
    theme_manager.set_theme(theme_type)
    TUI_CSS = theme_manager.generate_css()  # pyright: ignore[reportConstantRedefinition]
    return TUI_CSS


def get_available_themes() -> list[ThemeType]:
    """Get list of available themes."""
    return list(ThemeType)


def get_current_theme() -> ThemeType:
    """Get the currently active theme."""
    return theme_manager.current_theme


def get_theme_palette() -> ColorPalette:
    """Get the color palette for the current theme."""
    return theme_manager.get_current_palette()


def get_css_for_theme(theme_type: ThemeType) -> str:
    """Get CSS for a specific theme without changing the current theme."""
    current = theme_manager.current_theme
    theme_manager.set_theme(theme_type)
    css = theme_manager.generate_css()
    theme_manager.set_theme(current)  # Restore original theme
    return css


def apply_theme_to_app(app: TextualApp | AppProtocol, theme_type: ThemeType) -> None:
|
|
"""Apply a theme to a Textual app instance."""
|
|
try:
|
|
# Note: CSS class variable cannot be changed at runtime
|
|
# This function would need to be called during app initialization
|
|
# or implement a different approach for dynamic theming
|
|
_ = set_theme(theme_type) # Keep for future implementation
|
|
if hasattr(app, "refresh"):
|
|
app.refresh()
|
|
except Exception as e:
|
|
# Graceful fallback - log but don't crash the UI
|
|
import logging
|
|
|
|
logging.debug(f"Failed to apply theme to app: {e}")
|
|
|
|
|
|
class ThemeSwitcher:
|
|
"""Helper class for managing theme switching in TUI applications."""
|
|
|
|
def __init__(self, app: TextualApp | AppProtocol | None = None) -> None:
|
|
self.app: TextualApp | AppProtocol | None = app
|
|
self.theme_history: list[ThemeType] = [ThemeType.DARK]
|
|
|
|
def switch_theme(self, theme_type: ThemeType) -> str:
|
|
"""Switch to a new theme and apply it to the app if available."""
|
|
css = set_theme(theme_type)
|
|
self.theme_history.append(theme_type)
|
|
|
|
if self.app:
|
|
apply_theme_to_app(self.app, theme_type)
|
|
|
|
return css
|
|
|
|
def toggle_dark_light(self) -> str:
|
|
"""Toggle between dark and light themes."""
|
|
current = get_current_theme()
|
|
if current in [ThemeType.DARK, ThemeType.GITHUB_DARK, ThemeType.HIGH_CONTRAST]:
|
|
return self.switch_theme(ThemeType.LIGHT)
|
|
else:
|
|
return self.switch_theme(ThemeType.DARK)
|
|
|
|
def cycle_themes(self) -> str:
|
|
"""Cycle through all available themes."""
|
|
themes = get_available_themes()
|
|
current = get_current_theme()
|
|
current_index = themes.index(current)
|
|
next_theme = themes[(current_index + 1) % len(themes)]
|
|
return self.switch_theme(next_theme)
|
|
|
|
def get_theme_info(self) -> dict[str, str | list[str] | dict[str, str]]:
|
|
"""Get information about the current theme."""
|
|
palette = get_theme_palette()
|
|
return {
|
|
"current_theme": get_current_theme().value,
|
|
"available_themes": [t.value for t in get_available_themes()],
|
|
"palette": {
|
|
"bg_primary": palette.bg_primary,
|
|
"text_primary": palette.text_primary,
|
|
"primary": palette.primary,
|
|
"contrast_info": "WCAG AA compliant colors",
|
|
},
|
|
}
|
|
|
|
|
|
# Responsive breakpoints for dynamic layout adaptation
|
|
RESPONSIVE_BREAKPOINTS = {
|
|
"xs": 40, # Extra small terminals
|
|
"sm": 60, # Small terminals
|
|
"md": 100, # Medium terminals
|
|
"lg": 140, # Large terminals
|
|
"xl": 180, # Extra large terminals
|
|
}
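# Illustrative reading of the table above (an assumption about how the bands are meant to be used,
# not taken from the source): a terminal 120 cells wide falls in the "md" band, since 100 <= 120 < 140,
# while an 80-column terminal lands in "sm" (60 <= 80 < 100).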
|
|
|
|
|
|
def get_responsive_css() -> str:
|
|
"""Generate responsive CSS with breakpoint-based adaptations."""
|
|
return """
|
|
/* Responsive Grid System */
|
|
.responsive-grid {
|
|
layout: grid;
|
|
grid-gutter: 1;
|
|
padding: 1;
|
|
}
|
|
|
|
.responsive-grid.auto-fit {
|
|
grid-columns: repeat(auto-fit, minmax(20, 1fr));
|
|
}
|
|
|
|
.responsive-grid.compact {
|
|
grid-gutter: 0;
|
|
padding: 0;
|
|
}
|
|
|
|
/* Breakpoint-specific styles */
|
|
@media (max-width: 60) {
|
|
.responsive-grid {
|
|
grid-size: 1;
|
|
grid-columns: 1fr;
|
|
}
|
|
|
|
.collapsible-sidebar {
|
|
width: 100%;
|
|
height: auto;
|
|
dock: top;
|
|
}
|
|
|
|
.form-row {
|
|
layout: vertical;
|
|
}
|
|
|
|
.form-label {
|
|
width: 100%;
|
|
text-align: left;
|
|
padding-bottom: 1;
|
|
}
|
|
|
|
.form-input {
|
|
width: 100%;
|
|
}
|
|
}
|
|
|
|
@media (min-width: 61) and (max-width: 100) {
|
|
.responsive-grid {
|
|
grid-size: 2;
|
|
grid-columns: 1fr 1fr;
|
|
}
|
|
}
|
|
|
|
@media (min-width: 101) {
|
|
.responsive-grid {
|
|
grid-size: 3;
|
|
grid-columns: 1fr 1fr 1fr;
|
|
}
|
|
}
|
|
|
|
/* Enhanced Layout Components */
|
|
.split-pane {
|
|
layout: horizontal;
|
|
height: 100%;
|
|
}
|
|
|
|
.split-pane.vertical {
|
|
layout: vertical;
|
|
}
|
|
|
|
.split-pane .pane {
|
|
background: $surface;
|
|
border: solid $border;
|
|
}
|
|
|
|
.split-pane .splitter {
|
|
width: 1;
|
|
background: $border;
|
|
cursor: col-resize;
|
|
}
|
|
|
|
.split-pane.vertical .splitter {
|
|
height: 1;
|
|
width: 100%;
|
|
cursor: row-resize;
|
|
}
|
|
|
|
/* Card Layout System */
|
|
.card-layout {
|
|
layout: grid;
|
|
grid-gutter: 2;
|
|
padding: 2;
|
|
}
|
|
|
|
.card {
|
|
background: $surface;
|
|
border: solid $border;
|
|
border-radius: 1;
|
|
padding: 2;
|
|
height: auto;
|
|
min-height: 10;
|
|
transition: border 200ms, background 200ms;
|
|
}
|
|
|
|
.card:hover {
|
|
border: solid $accent;
|
|
background: $surface-lighten-1;
|
|
}
|
|
|
|
.card:focus {
|
|
border: solid $primary;
|
|
box-shadow: 0 0 0 1 $primary-lighten-1;
|
|
}
|
|
|
|
.card-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary-lighten-1;
|
|
color: $text;
|
|
padding: 1;
|
|
margin: -2 -2 1 -2;
|
|
border-radius: 1 1 0 0;
|
|
}
|
|
|
|
.card-content {
|
|
height: 1fr;
|
|
overflow: auto;
|
|
}
|
|
|
|
.card-footer {
|
|
dock: bottom;
|
|
height: 3;
|
|
background: $surface-darken-1;
|
|
padding: 1;
|
|
margin: 1 -2 -2 -2;
|
|
border-radius: 0 0 1 1;
|
|
}
|
|
|
|
/* Collapsible Sidebar */
|
|
.collapsible-sidebar {
|
|
dock: left;
|
|
width: 25%;
|
|
min-width: 20;
|
|
max-width: 40;
|
|
background: $surface;
|
|
border-right: solid $border;
|
|
padding: 1;
|
|
transition: width 300ms ease-in-out;
|
|
}
|
|
|
|
.collapsible-sidebar.collapsed {
|
|
width: 3;
|
|
min-width: 3;
|
|
overflow: hidden;
|
|
}
|
|
|
|
.collapsible-sidebar.collapsed > * {
|
|
display: none;
|
|
}
|
|
|
|
.collapsible-sidebar .sidebar-toggle {
|
|
dock: top;
|
|
height: 1;
|
|
background: $primary;
|
|
color: $text;
|
|
text-align: center;
|
|
margin-bottom: 1;
|
|
cursor: pointer;
|
|
}
|
|
|
|
.collapsible-sidebar .sidebar-content {
|
|
height: 1fr;
|
|
overflow-y: auto;
|
|
}
|
|
|
|
/* Tabular Layout */
|
|
.tabular-layout {
|
|
layout: horizontal;
|
|
height: 100%;
|
|
}
|
|
|
|
.tabular-layout .main-content {
|
|
width: 1fr;
|
|
height: 100%;
|
|
layout: vertical;
|
|
}
|
|
|
|
.tabular-layout .table-container {
|
|
height: 1fr;
|
|
overflow: auto;
|
|
border: solid $border;
|
|
background: $surface;
|
|
}
|
|
|
|
.tabular-layout .table-header {
|
|
dock: top;
|
|
height: 3;
|
|
background: $primary;
|
|
color: $text;
|
|
padding: 1;
|
|
}
|
|
|
|
.tabular-layout .table-footer {
|
|
dock: bottom;
|
|
height: 3;
|
|
background: $surface-lighten-1;
|
|
padding: 1;
|
|
border-top: solid $border;
|
|
}
|
|
|
|
/* Form Styling Enhancements */
|
|
.form-container {
|
|
background: $surface;
|
|
border: solid $border;
|
|
padding: 2;
|
|
border-radius: 1;
|
|
}
|
|
|
|
.form-title {
|
|
color: $primary;
|
|
text-style: bold;
|
|
margin-bottom: 2;
|
|
text-align: center;
|
|
}
|
|
|
|
.form-section {
|
|
margin-bottom: 2;
|
|
padding: 1;
|
|
border: solid $border-lighten-1;
|
|
background: $surface-lighten-1;
|
|
border-radius: 1;
|
|
}
|
|
|
|
.section-title {
|
|
color: $primary;
|
|
text-style: bold;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
.form-row {
|
|
layout: horizontal;
|
|
align-items: center;
|
|
height: auto;
|
|
margin-bottom: 1;
|
|
}
|
|
|
|
.form-label {
|
|
width: 30%;
|
|
min-width: 15;
|
|
text-align: right;
|
|
padding-right: 2;
|
|
color: $text-secondary;
|
|
}
|
|
|
|
.form-input {
|
|
width: 70%;
|
|
}
|
|
|
|
.form-actions {
|
|
layout: horizontal;
|
|
align: center;
|
|
margin-top: 2;
|
|
padding-top: 2;
|
|
border-top: solid $border;
|
|
}
|
|
|
|
.form-actions Button {
|
|
margin: 0 1;
|
|
min-width: 10;
|
|
}
|
|
|
|
/* Button Enhancements */
|
|
Button {
|
|
transition: background 200ms, color 200ms;
|
|
}
|
|
|
|
Button:hover {
|
|
background: $primary-hover;
|
|
}
|
|
|
|
Button:focus {
|
|
border: solid $primary;
|
|
box-shadow: 0 0 0 1 $primary-lighten-1;
|
|
}
|
|
|
|
.button-group {
|
|
layout: horizontal;
|
|
align: center;
|
|
}
|
|
|
|
.button-group Button {
|
|
margin-right: 1;
|
|
}
|
|
|
|
.button-group Button:last-child {
|
|
margin-right: 0;
|
|
}
|
|
|
|
/* Data Table Enhancements */
|
|
DataTable {
|
|
border: solid $border;
|
|
background: $surface;
|
|
}
|
|
|
|
DataTable .datatable--header {
|
|
background: $primary;
|
|
color: $text;
|
|
text-style: bold;
|
|
}
|
|
|
|
DataTable .datatable--odd-row {
|
|
background: $surface-lighten-1;
|
|
}
|
|
|
|
DataTable .datatable--even-row {
|
|
background: $surface;
|
|
}
|
|
|
|
DataTable .datatable--cursor {
|
|
background: $primary-lighten-2;
|
|
color: $text;
|
|
}
|
|
|
|
/* Loading and Progress Indicators */
|
|
LoadingIndicator {
|
|
color: $primary;
|
|
background: transparent;
|
|
}
|
|
|
|
ProgressBar {
|
|
border: solid $border;
|
|
background: $surface-darken-1;
|
|
}
|
|
|
|
ProgressBar .bar--bar {
|
|
color: $primary;
|
|
}
|
|
|
|
ProgressBar .bar--percentage {
|
|
color: $text;
|
|
text-style: bold;
|
|
}
|
|
|
|
/* Modal and Dialog Styling */
|
|
.modal-container {
|
|
background: $surface;
|
|
border: thick $accent;
|
|
border-radius: 1;
|
|
padding: 2;
|
|
box-shadow: 0 4 8 0 rgba(0, 0, 0, 0.3);
|
|
}
|
|
|
|
.dialog-container {
|
|
background: $surface;
|
|
border: solid $border;
|
|
border-radius: 1;
|
|
padding: 2;
|
|
min-width: 40;
|
|
max-width: 80;
|
|
}
|
|
|
|
/* Animation Classes */
|
|
.fade-in {
|
|
opacity: 0;
|
|
transition: opacity 300ms ease-in;
|
|
}
|
|
|
|
.fade-in.visible {
|
|
opacity: 1;
|
|
}
|
|
|
|
.slide-in-left {
|
|
transform: translateX(-100%);
|
|
transition: transform 300ms ease-in-out;
|
|
}
|
|
|
|
.slide-in-left.visible {
|
|
transform: translateX(0);
|
|
}
|
|
|
|
.slide-in-right {
|
|
transform: translateX(100%);
|
|
transition: transform 300ms ease-in-out;
|
|
}
|
|
|
|
.slide-in-right.visible {
|
|
transform: translateX(0);
|
|
}
|
|
|
|
/* Accessibility Enhancements */
|
|
.screen-reader-only {
|
|
position: absolute;
|
|
width: 1px;
|
|
height: 1px;
|
|
padding: 0;
|
|
margin: -1px;
|
|
overflow: hidden;
|
|
clip: rect(0, 0, 0, 0);
|
|
border: 0;
|
|
}
|
|
|
|
.focus-visible {
|
|
outline: 2px solid $primary;
|
|
outline-offset: 2px;
|
|
}
|
|
|
|
/* Print Styles (for export functionality) */
|
|
@media print {
|
|
* {
|
|
background: white !important;
|
|
color: black !important;
|
|
}
|
|
|
|
.no-print {
|
|
display: none !important;
|
|
}
|
|
}
|
|
|
|
/* High Contrast Mode Support */
|
|
@media (prefers-contrast: high) {
|
|
* {
|
|
border-color: currentColor;
|
|
}
|
|
|
|
Button {
|
|
border: 2px solid currentColor;
|
|
}
|
|
|
|
Input, Select, TextArea {
|
|
border: 2px solid currentColor;
|
|
}
|
|
}
|
|
|
|
/* Dark Mode Detection */
|
|
@media (prefers-color-scheme: dark) {
|
|
:root {
|
|
--primary-color: #1f6feb;
|
|
--background-color: #0a0c10;
|
|
--text-color: #ffffff;
|
|
}
|
|
}
|
|
|
|
/* Light Mode Detection */
|
|
@media (prefers-color-scheme: light) {
|
|
:root {
|
|
--primary-color: #0969da;
|
|
--background-color: #ffffff;
|
|
--text-color: #1f2328;
|
|
}
|
|
}
|
|
"""
|
|
|
|
|
|
def get_css_custom_properties() -> str:
|
|
"""Generate CSS custom properties for dynamic theming."""
|
|
palette = get_theme_palette()
|
|
|
|
return f"""
|
|
:root {{
|
|
/* Color Palette */
|
|
--bg-primary: {palette.bg_primary};
|
|
--bg-secondary: {palette.bg_secondary};
|
|
--bg-tertiary: {palette.bg_tertiary};
|
|
--bg-elevated: {palette.bg_elevated};
|
|
|
|
--text-primary: {palette.text_primary};
|
|
--text-secondary: {palette.text_secondary};
|
|
--text-tertiary: {palette.text_tertiary};
|
|
--text-inverse: {palette.text_inverse};
|
|
|
|
--primary: {palette.primary};
|
|
--primary-hover: {palette.primary_hover};
|
|
--success: {palette.success};
|
|
--warning: {palette.warning};
|
|
--error: {palette.error};
|
|
--info: {palette.info};
|
|
|
|
--border-default: {palette.border_default};
|
|
--border-focus: {palette.border_focus};
|
|
--border-hover: {palette.border_hover};
|
|
|
|
--surface-1: {palette.surface_1};
|
|
--surface-2: {palette.surface_2};
|
|
--surface-3: {palette.surface_3};
|
|
|
|
/* Spacing Scale */
|
|
--space-xs: 0.25rem;
|
|
--space-sm: 0.5rem;
|
|
--space-md: 1rem;
|
|
--space-lg: 1.5rem;
|
|
--space-xl: 2rem;
|
|
|
|
/* Typography Scale */
|
|
--text-xs: 0.75rem;
|
|
--text-sm: 0.875rem;
|
|
--text-base: 1rem;
|
|
--text-lg: 1.125rem;
|
|
--text-xl: 1.25rem;
|
|
|
|
/* Border Radius */
|
|
--radius-sm: 0.25rem;
|
|
--radius-md: 0.5rem;
|
|
--radius-lg: 1rem;
|
|
|
|
/* Shadows */
|
|
--shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.1);
|
|
--shadow-md: 0 4px 6px rgba(0, 0, 0, 0.1);
|
|
--shadow-lg: 0 10px 15px rgba(0, 0, 0, 0.1);
|
|
|
|
/* Transitions */
|
|
--transition-fast: 150ms ease-in-out;
|
|
--transition-normal: 250ms ease-in-out;
|
|
--transition-slow: 350ms ease-in-out;
|
|
}}
|
|
"""
|
|
|
|
|
|
def get_enhanced_dark_theme_css() -> str:
|
|
"""Generate CSS for the enhanced dark theme."""
|
|
theme_manager = ThemeManager(default_theme=ThemeType.DARK)
|
|
return theme_manager.generate_css()
|
|
|
|
|
|
def apply_responsive_theme() -> str:
|
|
"""Apply complete responsive theme with custom properties."""
|
|
base_css = get_enhanced_dark_theme_css()
|
|
responsive_css = get_responsive_css()
|
|
custom_properties = get_css_custom_properties()
|
|
|
|
return f"{custom_properties}\n{base_css}\n{responsive_css}"
|
|
</file>
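A minimal usage sketch for the theme helpers defined in the file above. It assumes the module resolves to ingest_pipeline.cli.tui.styles (per the directory structure); the script name, and the choice to switch to the light theme, are purely illustrative and not part of the repository.

# usage_sketch_theme_switching.py (hypothetical script, illustrative only)
from ingest_pipeline.cli.tui.styles import (
    ThemeSwitcher,
    ThemeType,
    get_current_theme,
    get_theme_palette,
    set_theme,
)

# Regenerate the CSS for a specific theme; the module notes that Textual's CSS
# class variable cannot be swapped at runtime, so this belongs at start-up.
css = set_theme(ThemeType.LIGHT)
print(get_current_theme())           # ThemeType.LIGHT
print(get_theme_palette().primary)   # the light palette's primary colour

# ThemeSwitcher records a history and offers toggle/cycle helpers.
switcher = ThemeSwitcher()
_ = switcher.toggle_dark_light()     # LIGHT is not a dark theme, so this returns the DARK CSS
info = switcher.get_theme_info()
print(info["current_theme"], info["available_themes"])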
|
|
|
|
<file path="ingest_pipeline/flows/ingestion.py">
|
|
"""Prefect flow for ingestion pipeline."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from datetime import UTC, datetime
|
|
from typing import TYPE_CHECKING, Literal, TypeAlias, assert_never, cast
|
|
|
|
from prefect import flow, get_run_logger, task
|
|
from prefect.blocks.core import Block
|
|
from prefect.variables import Variable
|
|
from pydantic import SecretStr
|
|
|
|
from ..config.settings import Settings
|
|
from ..core.exceptions import IngestionError
|
|
from ..core.models import (
|
|
Document,
|
|
FirecrawlConfig,
|
|
IngestionJob,
|
|
IngestionResult,
|
|
IngestionSource,
|
|
IngestionStatus,
|
|
RepomixConfig,
|
|
StorageBackend,
|
|
StorageConfig,
|
|
)
|
|
from ..ingestors import BaseIngestor, FirecrawlIngestor, FirecrawlPage, RepomixIngestor
|
|
from ..storage import OpenWebUIStorage, WeaviateStorage
|
|
from ..storage import R2RStorage as RuntimeR2RStorage
|
|
from ..storage.base import BaseStorage
|
|
from ..utils.metadata_tagger import MetadataTagger
|
|
|
|
SourceTypeLiteral = Literal["web", "repository", "documentation"]
|
|
StorageBackendLiteral = Literal["weaviate", "open_webui", "r2r"]
|
|
SourceTypeLike: TypeAlias = IngestionSource | SourceTypeLiteral
|
|
StorageBackendLike: TypeAlias = StorageBackend | StorageBackendLiteral
|
|
|
|
|
|
def _safe_cache_key(prefix: str, params: dict[str, object], key: str) -> str:
    """Create a type-safe cache key from task parameters."""
    value = params.get(key, "")
    return f"{prefix}_{hash(str(value))}"
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from ..storage.r2r.storage import R2RStorage as R2RStorageType
|
|
else:
|
|
R2RStorageType = BaseStorage
|
|
|
|
|
|
@task(name="validate_source", retries=2, retry_delay_seconds=10, tags=["validation"])
|
|
async def validate_source_task(source_url: str, source_type: IngestionSource) -> bool:
|
|
"""
|
|
Validate that a source is accessible.
|
|
|
|
Args:
|
|
source_url: URL or path to source
|
|
source_type: Type of source
|
|
|
|
Returns:
|
|
True if valid
|
|
"""
|
|
if source_type == IngestionSource.WEB:
|
|
ingestor = FirecrawlIngestor()
|
|
elif source_type == IngestionSource.REPOSITORY:
|
|
ingestor = RepomixIngestor()
|
|
else:
|
|
raise ValueError(f"Unsupported source type: {source_type}")
|
|
|
|
result = await ingestor.validate_source(source_url)
|
|
return bool(result)
|
|
|
|
|
|
@task(name="initialize_storage", retries=3, retry_delay_seconds=5, tags=["storage"])
|
|
async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
|
|
"""
|
|
Initialize storage backend.
|
|
|
|
Args:
|
|
config: Storage configuration block or block name
|
|
|
|
Returns:
|
|
Initialized storage adapter
|
|
"""
|
|
# Load block if string provided
|
|
if isinstance(config, str):
|
|
# Use Block.aload with type slug for better type inference
|
|
loaded_block = await Block.aload(f"storage-config/{config}")
|
|
config = cast(StorageConfig, loaded_block)
|
|
|
|
if config.backend == StorageBackend.WEAVIATE:
|
|
storage = WeaviateStorage(config)
|
|
elif config.backend == StorageBackend.OPEN_WEBUI:
|
|
storage = OpenWebUIStorage(config)
|
|
elif config.backend == StorageBackend.R2R:
|
|
if RuntimeR2RStorage is None:
|
|
raise ValueError("R2R storage not available. Check dependencies.")
|
|
storage = RuntimeR2RStorage(config)
|
|
else:
|
|
raise ValueError(f"Unsupported backend: {config.backend}")
|
|
|
|
await storage.initialize()
|
|
return storage
|
|
|
|
|
|
@task(
|
|
name="map_firecrawl_site",
|
|
retries=2,
|
|
retry_delay_seconds=15,
|
|
tags=["firecrawl", "map"],
|
|
cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"),
|
|
)
|
|
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
|
|
"""Map a site using Firecrawl and return discovered URLs."""
|
|
# Load block if string provided
|
|
if isinstance(config, str):
|
|
# Use Block.aload with type slug for better type inference
|
|
loaded_block = await Block.aload(f"firecrawl-config/{config}")
|
|
config = cast(FirecrawlConfig, loaded_block)
|
|
|
|
ingestor = FirecrawlIngestor(config)
|
|
mapped = await ingestor.map_site(source_url)
|
|
return mapped or [source_url]
|
|
|
|
|
|
@task(
|
|
name="filter_existing_documents",
|
|
retries=1,
|
|
retry_delay_seconds=5,
|
|
tags=["dedup"],
|
|
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls"),
|
|
) # Cache based on URL list
|
|
async def filter_existing_documents_task(
|
|
urls: list[str],
|
|
storage_client: BaseStorage,
|
|
stale_after_days: int = 30,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> list[str]:
|
|
"""Filter URLs to only those that need scraping (missing or stale in storage)."""
|
|
import asyncio
|
|
|
|
logger = get_run_logger()
|
|
|
|
# Use semaphore to limit concurrent existence checks
|
|
semaphore = asyncio.Semaphore(20)
|
|
|
|
async def check_url_exists(url: str) -> tuple[str, bool]:
|
|
async with semaphore:
|
|
try:
|
|
document_id = str(FirecrawlIngestor.compute_document_id(url))
|
|
exists = await storage_client.check_exists(
|
|
document_id, collection_name=collection_name, stale_after_days=stale_after_days
|
|
)
|
|
return url, exists
|
|
except Exception as e:
|
|
logger.warning("Error checking existence for URL %s: %s", url, e)
|
|
# Assume doesn't exist on error to ensure we scrape it
|
|
return url, False
|
|
|
|
# Check all URLs in parallel - use return_exceptions=True for partial failure handling
|
|
results = await asyncio.gather(*[check_url_exists(url) for url in urls], return_exceptions=True)
|
|
|
|
# Collect URLs that need scraping, handling any exceptions
|
|
eligible = []
|
|
for result in results:
|
|
if isinstance(result, Exception):
|
|
logger.error("Unexpected error in parallel existence check: %s", result)
|
|
continue
|
|
# Type narrowing: result is now known to be tuple[str, bool]
|
|
if isinstance(result, tuple) and len(result) == 2:
|
|
url, exists = result
|
|
if not exists:
|
|
eligible.append(url)
|
|
|
|
skipped = len(urls) - len(eligible)
|
|
if skipped > 0:
|
|
logger.info("Skipping %s up-to-date documents in %s", skipped, storage_client.display_name)
|
|
|
|
return eligible
|
|
|
|
|
|
@task(
|
|
name="scrape_firecrawl_batch", retries=2, retry_delay_seconds=20, tags=["firecrawl", "scrape"]
|
|
)
|
|
async def scrape_firecrawl_batch_task(
|
|
batch_urls: list[str], config: FirecrawlConfig
|
|
) -> list[FirecrawlPage]:
|
|
"""Scrape a batch of URLs via Firecrawl."""
|
|
ingestor = FirecrawlIngestor(config)
|
|
result: list[FirecrawlPage] = await ingestor.scrape_pages(batch_urls)
|
|
return result
|
|
|
|
|
|
@task(name="annotate_firecrawl_metadata", retries=1, retry_delay_seconds=10, tags=["metadata"])
|
|
async def annotate_firecrawl_metadata_task(
|
|
pages: list[FirecrawlPage], job: IngestionJob
|
|
) -> list[Document]:
|
|
"""Annotate scraped pages with standardized metadata."""
|
|
if not pages:
|
|
return []
|
|
|
|
ingestor = FirecrawlIngestor()
|
|
documents = [ingestor.create_document(page, job) for page in pages]
|
|
|
|
try:
|
|
from ..config import get_settings
|
|
|
|
settings = get_settings()
|
|
async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
|
|
tagged_documents: list[Document] = await tagger.tag_batch(documents)
|
|
return tagged_documents
|
|
except IngestionError as exc: # pragma: no cover - logging side effect
|
|
logger = get_run_logger()
|
|
logger.warning("Metadata tagging failed: %s", exc)
|
|
return documents
|
|
except Exception as exc: # pragma: no cover - defensive
|
|
logger = get_run_logger()
|
|
logger.warning("Metadata tagging unavailable, using base metadata: %s", exc)
|
|
return documents
|
|
|
|
|
|
@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"])
|
|
async def upsert_r2r_documents_task(
|
|
storage_client: R2RStorageType,
|
|
documents: list[Document],
|
|
collection_name: str | None,
|
|
) -> tuple[int, int]:
|
|
"""Upsert documents into R2R storage."""
|
|
if not documents:
|
|
return 0, 0
|
|
|
|
stored_ids: list[str] = await storage_client.store_batch(
|
|
documents, collection_name=collection_name
|
|
)
|
|
processed = len(stored_ids)
|
|
failed = len(documents) - processed
|
|
|
|
if failed:
|
|
logger = get_run_logger()
|
|
logger.warning("Failed to upsert %s documents to R2R", failed)
|
|
|
|
return processed, failed
|
|
|
|
|
|
@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"])
|
|
async def ingest_documents_task(
|
|
job: IngestionJob,
|
|
collection_name: str | None = None,
|
|
batch_size: int | None = None,
|
|
storage_client: BaseStorage | None = None,
|
|
storage_block_name: str | None = None,
|
|
ingestor_config_block_name: str | None = None,
|
|
progress_callback: Callable[[int, str], None] | None = None,
|
|
) -> tuple[int, int]:
|
|
"""
|
|
Ingest documents from source with optional pre-initialized storage client.
|
|
|
|
Args:
|
|
job: Ingestion job configuration
|
|
collection_name: Target collection name
|
|
batch_size: Number of documents per batch (uses Variable if None)
|
|
storage_client: Optional pre-initialized storage client
|
|
storage_block_name: Optional storage block name to load
|
|
ingestor_config_block_name: Optional ingestor config block name to load
|
|
progress_callback: Optional callback for progress updates
|
|
|
|
Returns:
|
|
Tuple of (processed_count, failed_count)
|
|
"""
|
|
if progress_callback:
|
|
progress_callback(35, "Creating ingestor and storage clients...")
|
|
|
|
# Use Variable for batch size if not provided
|
|
if batch_size is None:
|
|
try:
|
|
batch_size_var = await Variable.aget("default_batch_size", default="50")
|
|
# Convert Variable result to int, handling various types
|
|
if isinstance(batch_size_var, int):
|
|
batch_size = batch_size_var
|
|
elif isinstance(batch_size_var, (str, float)):
|
|
batch_size = int(float(str(batch_size_var)))
|
|
else:
|
|
batch_size = 50
|
|
except Exception:
|
|
batch_size = 50
|
|
|
|
ingestor = await _create_ingestor(job, ingestor_config_block_name)
|
|
storage = storage_client or await _create_storage(job, collection_name, storage_block_name)
|
|
|
|
if progress_callback:
|
|
progress_callback(40, "Starting document processing...")
|
|
|
|
return await _process_documents(
|
|
ingestor, storage, job, batch_size, collection_name, progress_callback
|
|
)
|
|
|
|
|
|
async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
|
|
"""Create appropriate ingestor based on job source type."""
|
|
if job.source_type == IngestionSource.WEB:
|
|
if config_block_name:
|
|
# Use Block.aload with type slug for better type inference
|
|
loaded_block = await Block.aload(f"firecrawl-config/{config_block_name}")
|
|
config = cast(FirecrawlConfig, loaded_block)
|
|
else:
|
|
# Fallback to default configuration
|
|
config = FirecrawlConfig()
|
|
return FirecrawlIngestor(config)
|
|
elif job.source_type == IngestionSource.REPOSITORY:
|
|
if config_block_name:
|
|
# Use Block.aload with type slug for better type inference
|
|
loaded_block = await Block.aload(f"repomix-config/{config_block_name}")
|
|
config = cast(RepomixConfig, loaded_block)
|
|
else:
|
|
# Fallback to default configuration
|
|
config = RepomixConfig()
|
|
return RepomixIngestor(config)
|
|
else:
|
|
raise ValueError(f"Unsupported source: {job.source_type}")
|
|
|
|
|
|
async def _create_storage(
|
|
job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None
|
|
) -> BaseStorage:
|
|
"""Create and initialize storage client."""
|
|
if collection_name is None:
|
|
# Use variable for default collection prefix
|
|
prefix = await Variable.aget("default_collection_prefix", default="docs")
|
|
collection_name = f"{prefix}_{job.source_type.value}"
|
|
|
|
if storage_block_name:
|
|
# Load storage config from block
|
|
loaded_block = await Block.aload(f"storage-config/{storage_block_name}")
|
|
storage_config = cast(StorageConfig, loaded_block)
|
|
# Override collection name if provided
|
|
storage_config.collection_name = collection_name
|
|
else:
|
|
# Fallback to building config from settings
|
|
from ..config import get_settings
|
|
|
|
settings = get_settings()
|
|
storage_config = _build_storage_config(job, settings, collection_name)
|
|
|
|
storage = _instantiate_storage(job.storage_backend, storage_config)
|
|
await storage.initialize()
|
|
return storage
|
|
|
|
|
|
def _build_storage_config(
|
|
job: IngestionJob, settings: Settings, collection_name: str
|
|
) -> StorageConfig:
|
|
"""Build storage configuration from job and settings."""
|
|
storage_endpoints = {
|
|
StorageBackend.WEAVIATE: settings.weaviate_endpoint,
|
|
StorageBackend.OPEN_WEBUI: settings.openwebui_endpoint,
|
|
StorageBackend.R2R: settings.get_storage_endpoint("r2r"),
|
|
}
|
|
storage_api_keys: dict[StorageBackend, str | None] = {
|
|
StorageBackend.WEAVIATE: settings.get_api_key("weaviate"),
|
|
StorageBackend.OPEN_WEBUI: settings.get_api_key("openwebui"),
|
|
StorageBackend.R2R: None, # R2R is self-hosted, no API key needed
|
|
}
|
|
|
|
api_key_raw: str | None = storage_api_keys[job.storage_backend]
|
|
api_key: SecretStr | None = SecretStr(api_key_raw) if api_key_raw is not None else None
|
|
|
|
return StorageConfig(
|
|
backend=job.storage_backend,
|
|
endpoint=storage_endpoints[job.storage_backend],
|
|
api_key=api_key,
|
|
collection_name=collection_name,
|
|
)
|
|
|
|
|
|
def _instantiate_storage(backend: StorageBackend, config: StorageConfig) -> BaseStorage:
|
|
"""Instantiate storage based on backend type."""
|
|
if backend == StorageBackend.WEAVIATE:
|
|
return WeaviateStorage(config)
|
|
elif backend == StorageBackend.OPEN_WEBUI:
|
|
return OpenWebUIStorage(config)
|
|
elif backend == StorageBackend.R2R:
|
|
if RuntimeR2RStorage is None:
|
|
raise ValueError("R2R storage not available. Check dependencies.")
|
|
return RuntimeR2RStorage(config)
|
|
|
|
assert_never(backend)
|
|
|
|
|
|
def _chunk_urls(urls: list[str], chunk_size: int) -> list[list[str]]:
    """Group URLs into fixed-size chunks for batch processing."""

    if chunk_size <= 0:
        raise ValueError("chunk_size must be greater than zero")

    return [urls[i : i + chunk_size] for i in range(0, len(urls), chunk_size)]


def _deduplicate_urls(urls: list[str]) -> list[str]:
    """Return the URLs with order preserved and duplicates removed."""

    seen: set[str] = set()
    unique: list[str] = []
    for url in urls:
        if url not in seen:
            seen.add(url)
            unique.append(url)
    return unique
|
|
|
|
|
|
async def _process_documents(
|
|
ingestor: BaseIngestor,
|
|
storage: BaseStorage,
|
|
job: IngestionJob,
|
|
batch_size: int,
|
|
collection_name: str | None,
|
|
progress_callback: Callable[[int, str], None] | None = None,
|
|
) -> tuple[int, int]:
|
|
"""Process documents in batches."""
|
|
processed = 0
|
|
failed = 0
|
|
batch: list[Document] = []
|
|
total_documents = 0
|
|
batch_count = 0
|
|
|
|
if progress_callback:
|
|
progress_callback(45, "Ingesting documents from source...")
|
|
|
|
# Use smart ingestion with deduplication if storage supports it
|
|
if hasattr(storage, "check_exists"):
|
|
try:
|
|
# Try to use the smart ingestion method
|
|
document_generator = ingestor.ingest_with_dedup(
|
|
job, storage, collection_name=collection_name
|
|
)
|
|
except Exception:
|
|
# Fall back to regular ingestion if smart method fails
|
|
document_generator = ingestor.ingest(job)
|
|
else:
|
|
document_generator = ingestor.ingest(job)
|
|
|
|
async for document in document_generator:
|
|
batch.append(document)
|
|
total_documents += 1
|
|
|
|
if len(batch) >= batch_size:
|
|
batch_count += 1
|
|
if progress_callback:
|
|
progress_callback(
|
|
45 + min(35, (batch_count * 10)),
|
|
f"Processing batch {batch_count} ({total_documents} documents so far)...",
|
|
)
|
|
|
|
batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
|
|
processed += batch_processed
|
|
failed += batch_failed
|
|
batch = []
|
|
|
|
# Process remaining batch
|
|
if batch:
|
|
batch_count += 1
|
|
if progress_callback:
|
|
progress_callback(80, f"Processing final batch ({total_documents} total documents)...")
|
|
|
|
batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
|
|
processed += batch_processed
|
|
failed += batch_failed
|
|
|
|
if progress_callback:
|
|
progress_callback(85, f"Completed processing {total_documents} documents")
|
|
|
|
return processed, failed
|
|
|
|
|
|
async def _store_batch(
|
|
storage: BaseStorage,
|
|
batch: list[Document],
|
|
collection_name: str | None,
|
|
) -> tuple[int, int]:
|
|
"""Store a batch of documents and return processed/failed counts."""
|
|
try:
|
|
# Apply metadata tagging for backends that benefit from it
|
|
processed_batch = batch
|
|
if hasattr(storage, "config") and storage.config.backend in (
|
|
StorageBackend.R2R,
|
|
StorageBackend.WEAVIATE,
|
|
):
|
|
try:
|
|
from ..config import get_settings
|
|
|
|
settings = get_settings()
|
|
async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
|
|
processed_batch = await tagger.tag_batch(batch)
|
|
except Exception as exc:
|
|
print(f"Metadata tagging failed, using original documents: {exc}")
|
|
processed_batch = batch
|
|
|
|
stored_ids = await storage.store_batch(processed_batch, collection_name=collection_name)
|
|
processed_count = len(stored_ids)
|
|
failed_count = len(processed_batch) - processed_count
|
|
|
|
        batch_type = (
            "final " if len(processed_batch) < 50 else ""
        )  # Assume standard batch size is 50
        print(f"Successfully stored {processed_count} documents in {batch_type}batch")
|
|
|
|
return processed_count, failed_count
|
|
except Exception as e:
|
|
batch_type = "Final" if len(batch) < 50 else "Batch"
|
|
print(f"{batch_type} storage failed: {e}")
|
|
return 0, len(batch)
|
|
|
|
|
|
@flow(
|
|
name="firecrawl_to_r2r",
|
|
description="Ingest Firecrawl pages into R2R with metadata annotation",
|
|
persist_result=False,
|
|
log_prints=True,
|
|
)
|
|
async def firecrawl_to_r2r_flow(
|
|
job: IngestionJob,
|
|
collection_name: str | None = None,
|
|
progress_callback: Callable[[int, str], None] | None = None,
|
|
) -> tuple[int, int]:
|
|
"""Specialized flow for Firecrawl ingestion into R2R."""
|
|
logger = get_run_logger()
|
|
from ..config import get_settings
|
|
|
|
if progress_callback:
|
|
progress_callback(35, "Initializing Firecrawl and R2R storage...")
|
|
|
|
settings = get_settings()
|
|
firecrawl_config = FirecrawlConfig()
|
|
resolved_collection = collection_name or f"docs_{job.source_type.value}"
|
|
|
|
storage_config = _build_storage_config(job, settings, resolved_collection)
|
|
storage_client = await initialize_storage_task(storage_config)
|
|
|
|
if RuntimeR2RStorage is None or not isinstance(storage_client, RuntimeR2RStorage):
|
|
raise IngestionError("Firecrawl to R2R flow requires an R2R storage backend")
|
|
|
|
r2r_storage = cast("R2RStorageType", storage_client)
|
|
|
|
if progress_callback:
|
|
progress_callback(45, "Checking for existing content before mapping...")
|
|
|
|
# Smart mapping: try single URL first to avoid expensive map operation
|
|
base_url = str(job.source_url)
|
|
single_url_id = str(FirecrawlIngestor.compute_document_id(base_url))
|
|
base_exists = await r2r_storage.check_exists(
|
|
single_url_id, collection_name=resolved_collection, stale_after_days=30
|
|
)
|
|
|
|
if base_exists:
|
|
# Check if this is a recent single-page update
|
|
logger.info("Base URL %s exists and is fresh, skipping expensive mapping", base_url)
|
|
if progress_callback:
|
|
progress_callback(100, "Content is up to date, no processing needed")
|
|
return 0, 0
|
|
|
|
if progress_callback:
|
|
progress_callback(50, "Discovering pages with Firecrawl...")
|
|
|
|
discovered_urls = await map_firecrawl_site_task(base_url, firecrawl_config)
|
|
unique_urls = _deduplicate_urls(discovered_urls)
|
|
logger.info("Discovered %s unique URLs from Firecrawl map", len(unique_urls))
|
|
|
|
if progress_callback:
|
|
progress_callback(60, f"Found {len(unique_urls)} pages, filtering existing content...")
|
|
|
|
eligible_urls = await filter_existing_documents_task(
|
|
unique_urls, r2r_storage, collection_name=resolved_collection
|
|
)
|
|
|
|
if not eligible_urls:
|
|
logger.info("All Firecrawl pages are up to date for %s", job.source_url)
|
|
if progress_callback:
|
|
progress_callback(100, "All pages are up to date, no processing needed")
|
|
return 0, 0
|
|
|
|
if progress_callback:
|
|
progress_callback(70, f"Scraping {len(eligible_urls)} new/updated pages...")
|
|
|
|
batch_size = min(settings.default_batch_size, firecrawl_config.limit)
|
|
url_batches = _chunk_urls(eligible_urls, batch_size)
|
|
logger.info("Scraping %s batches of Firecrawl pages", len(url_batches))
|
|
|
|
# Use asyncio.gather for concurrent scraping
|
|
import asyncio
|
|
|
|
scrape_tasks = [scrape_firecrawl_batch_task(batch, firecrawl_config) for batch in url_batches]
|
|
batch_results = await asyncio.gather(*scrape_tasks)
|
|
|
|
scraped_pages: list[FirecrawlPage] = []
|
|
for batch_pages in batch_results:
|
|
scraped_pages.extend(batch_pages)
|
|
|
|
if progress_callback:
|
|
progress_callback(80, f"Processing {len(scraped_pages)} scraped pages...")
|
|
|
|
documents = await annotate_firecrawl_metadata_task(scraped_pages, job)
|
|
|
|
if not documents:
|
|
logger.warning("No documents produced after scraping for %s", job.source_url)
|
|
return 0, len(eligible_urls)
|
|
|
|
if progress_callback:
|
|
progress_callback(90, f"Storing {len(documents)} documents in R2R...")
|
|
|
|
processed, failed = await upsert_r2r_documents_task(r2r_storage, documents, resolved_collection)
|
|
|
|
logger.info("Upserted %s documents into R2R (%s failed)", processed, failed)
|
|
|
|
return processed, failed
|
|
|
|
|
|
@task(name="update_job_status", tags=["tracking"])
|
|
async def update_job_status_task(
|
|
job: IngestionJob,
|
|
status: IngestionStatus,
|
|
processed: int = 0,
|
|
_failed: int = 0,
|
|
error: str | None = None,
|
|
) -> IngestionJob:
|
|
"""
|
|
Update job status.
|
|
|
|
Args:
|
|
job: Ingestion job
|
|
status: New status
|
|
processed: Documents processed
|
|
_failed: Documents failed (currently unused)
|
|
error: Error message if any
|
|
|
|
Returns:
|
|
Updated job
|
|
"""
|
|
job.status = status
|
|
job.updated_at = datetime.now(UTC)
|
|
job.document_count = processed
|
|
|
|
if status == IngestionStatus.COMPLETED:
|
|
job.completed_at = datetime.now(UTC)
|
|
|
|
if error:
|
|
job.error_message = error
|
|
|
|
return job
|
|
|
|
|
|
@flow(
|
|
name="ingestion_pipeline",
|
|
description="Main ingestion pipeline for documents",
|
|
retries=1,
|
|
retry_delay_seconds=60,
|
|
persist_result=True,
|
|
log_prints=True,
|
|
)
|
|
async def create_ingestion_flow(
|
|
source_url: str,
|
|
source_type: SourceTypeLike,
|
|
storage_backend: StorageBackendLike = StorageBackend.WEAVIATE,
|
|
collection_name: str | None = None,
|
|
validate_first: bool = True,
|
|
progress_callback: Callable[[int, str], None] | None = None,
|
|
) -> IngestionResult:
|
|
"""
|
|
Main ingestion flow.
|
|
|
|
Args:
|
|
source_url: URL or path to source
|
|
source_type: Type of source
|
|
        storage_backend: Storage backend to use
        collection_name: Target collection name (defaults to a prefix plus the source type)
        validate_first: Whether to validate source first
|
|
progress_callback: Optional callback for progress updates
|
|
|
|
Returns:
|
|
Ingestion result
|
|
"""
|
|
print(f"Starting ingestion from {source_url}")
|
|
|
|
source_enum = IngestionSource(source_type)
|
|
backend_enum = StorageBackend(storage_backend)
|
|
|
|
# Create job
|
|
job = IngestionJob(
|
|
source_url=source_url,
|
|
source_type=source_enum,
|
|
storage_backend=backend_enum,
|
|
status=IngestionStatus.PENDING,
|
|
)
|
|
|
|
start_time = datetime.now(UTC)
|
|
error_messages: list[str] = []
|
|
processed = 0
|
|
failed = 0
|
|
|
|
try:
|
|
# Validate source if requested
|
|
if validate_first:
|
|
if progress_callback:
|
|
progress_callback(10, "Validating source...")
|
|
print("Validating source...")
|
|
is_valid = await validate_source_task(source_url, job.source_type)
|
|
|
|
if not is_valid:
|
|
raise IngestionError(f"Source validation failed: {source_url}")
|
|
|
|
# Update status to in progress
|
|
if progress_callback:
|
|
progress_callback(20, "Initializing storage...")
|
|
job = await update_job_status_task(job, IngestionStatus.IN_PROGRESS)
|
|
|
|
# Run ingestion
|
|
if progress_callback:
|
|
progress_callback(30, "Starting document ingestion...")
|
|
print("Ingesting documents...")
|
|
if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R:
|
|
processed, failed = await firecrawl_to_r2r_flow(
|
|
job, collection_name, progress_callback=progress_callback
|
|
)
|
|
else:
|
|
processed, failed = await ingest_documents_task(
|
|
job, collection_name, progress_callback=progress_callback
|
|
)
|
|
|
|
if progress_callback:
|
|
progress_callback(90, "Finalizing ingestion...")
|
|
|
|
# Update final status
|
|
if failed > 0:
|
|
error_messages.append(f"{failed} documents failed to process")
|
|
|
|
# Set status based on results
|
|
if processed == 0 and failed > 0:
|
|
final_status = IngestionStatus.FAILED
|
|
elif failed > 0:
|
|
final_status = IngestionStatus.PARTIAL
|
|
else:
|
|
final_status = IngestionStatus.COMPLETED
|
|
|
|
job = await update_job_status_task(job, final_status, processed=processed, _failed=failed)
|
|
|
|
print(f"Ingestion completed: {processed} processed, {failed} failed")
|
|
|
|
except Exception as e:
|
|
print(f"Ingestion failed: {e}")
|
|
error_messages.append(str(e))
|
|
|
|
# Don't reset counts - keep whatever was processed before the error
|
|
job = await update_job_status_task(
|
|
job, IngestionStatus.FAILED, processed=processed, _failed=failed, error=str(e)
|
|
)
|
|
|
|
# Calculate duration
|
|
duration = (datetime.now(UTC) - start_time).total_seconds()
|
|
|
|
return IngestionResult(
|
|
job_id=job.id,
|
|
status=job.status,
|
|
documents_processed=processed,
|
|
documents_failed=failed,
|
|
duration_seconds=duration,
|
|
error_messages=error_messages,
|
|
)
|
|
</file>
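A minimal sketch of invoking the main flow above from a script. The URL and collection name are hypothetical placeholders, the import paths follow the directory structure, and an actual run still requires the Prefect runtime plus the configured storage backend to be reachable.

import asyncio

from ingest_pipeline.core.models import IngestionSource, StorageBackend
from ingest_pipeline.flows.ingestion import create_ingestion_flow


async def main() -> None:
    # Enum members or their string literals ("web", "r2r", ...) are both accepted.
    result = await create_ingestion_flow(
        source_url="https://docs.example.com",   # hypothetical source
        source_type=IngestionSource.WEB,
        storage_backend=StorageBackend.R2R,
        collection_name="docs_web",
        validate_first=True,
    )
    print(result.status, result.documents_processed, result.documents_failed)


if __name__ == "__main__":
    asyncio.run(main())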
|
|
|
|
<file path="ingest_pipeline/flows/scheduler.py">
"""Scheduler for Prefect deployments."""

from datetime import timedelta
from typing import Literal, Protocol, cast

from prefect.deployments.runner import RunnerDeployment
from prefect.flows import serve as prefect_serve
from prefect.schedules import Cron, Interval
from prefect.variables import Variable

from ..core.models import IngestionSource, StorageBackend
from .ingestion import SourceTypeLike, StorageBackendLike, create_ingestion_flow


class FlowWithDeployment(Protocol):
    """Protocol for flows that have deployment methods."""

    def to_deployment(
        self,
        name: str,
        **kwargs: object,
    ) -> RunnerDeployment:
        """Create a deployment from this flow."""
        ...


def create_scheduled_deployment(
    name: str,
    source_url: str,
    source_type: SourceTypeLike,
    storage_backend: StorageBackendLike = StorageBackend.WEAVIATE,
    schedule_type: Literal["cron", "interval"] = "interval",
    cron_expression: str | None = None,
    interval_minutes: int | None = None,
    tags: list[str] | None = None,
    storage_block_name: str | None = None,
    ingestor_config_block_name: str | None = None,
) -> RunnerDeployment:
    """
    Create a scheduled deployment for ingestion with block support.

    Args:
        name: Deployment name
        source_url: Source to ingest from
        source_type: Type of source
        storage_backend: Storage backend
        schedule_type: Type of schedule
        cron_expression: Cron expression if using cron
        interval_minutes: Interval in minutes (uses Variable if None)
        tags: Optional tags for deployment
        storage_block_name: Optional storage block name
        ingestor_config_block_name: Optional ingestor config block name

    Returns:
        Deployment configuration
    """
    # Use Variable for interval if not provided
    if interval_minutes is None:
        try:
            interval_var = Variable.get("default_schedule_interval", default="60")
            # Convert Variable result to int, handling various types
            if isinstance(interval_var, int):
                interval_minutes = interval_var
            elif isinstance(interval_var, (str, float)):
                interval_minutes = int(float(str(interval_var)))
            else:
                interval_minutes = 60
        except Exception:
            interval_minutes = 60

    # Create schedule
    if schedule_type == "cron" and cron_expression:
        schedule = Cron(cron_expression, timezone="UTC")
    else:
        schedule = Interval(timedelta(minutes=interval_minutes), timezone="UTC")

    # Default tags
    source_enum = IngestionSource(source_type)
    backend_enum = StorageBackend(storage_backend)

    if tags is None:
        tags = [source_enum.value, backend_enum.value]

    # Create deployment parameters with block support
    parameters: dict[str, str | bool] = {
        "source_url": source_url,
        "source_type": source_enum.value,
        "storage_backend": backend_enum.value,
        "validate_first": True,
    }

    # Add block names if provided
    if storage_block_name:
        parameters["storage_block_name"] = storage_block_name
    if ingestor_config_block_name:
        parameters["ingestor_config_block_name"] = ingestor_config_block_name

    # Create deployment
    # The flow decorator adds the to_deployment method at runtime
    flow_with_deployment = cast(FlowWithDeployment, create_ingestion_flow)
    return flow_with_deployment.to_deployment(
        name=name,
        schedule=schedule,
        parameters=parameters,
        tags=tags,
        description=f"Scheduled ingestion from {source_url}",
    )


def serve_deployments(deployments: list[RunnerDeployment]) -> None:
    """
    Serve multiple deployments.

    Args:
        deployments: List of deployment configurations
    """
    prefect_serve(*deployments, limit=10)
</file>
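A sketch of wiring the scheduler helpers above into a long-running worker process. The URL, cron expression, and Prefect block names are placeholders; create_scheduled_deployment only builds the deployment object, and serve_deployments blocks while Prefect executes the scheduled runs (capped at 10 concurrent runs by the limit above).

from ingest_pipeline.core.models import StorageBackend
from ingest_pipeline.flows.scheduler import create_scheduled_deployment, serve_deployments

# Nightly crawl of a documentation site into Weaviate (all names hypothetical).
nightly_docs = create_scheduled_deployment(
    name="docs-nightly",
    source_url="https://docs.example.com",
    source_type="web",
    storage_backend=StorageBackend.WEAVIATE,
    schedule_type="cron",
    cron_expression="0 2 * * *",
    storage_block_name="weaviate-default",          # assumed Prefect block name
    ingestor_config_block_name="firecrawl-default",  # assumed Prefect block name
)

# Hourly refresh of a repository using the interval-based schedule and default backend.
hourly_repo = create_scheduled_deployment(
    name="repo-hourly",
    source_url="https://github.com/example/repo",
    source_type="repository",
    interval_minutes=60,
)

serve_deployments([nightly_docs, hourly_repo])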
|
|
|
|
<file path="ingest_pipeline/ingestors/firecrawl.py">
|
|
"""Firecrawl ingestor for web and documentation sites."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from typing import TYPE_CHECKING, Protocol, cast
|
|
from urllib.parse import urlparse
|
|
from uuid import NAMESPACE_URL, UUID, uuid5
|
|
|
|
from firecrawl import AsyncFirecrawl
|
|
from typing_extensions import override
|
|
|
|
from ..config import get_settings
|
|
from ..core.exceptions import IngestionError
|
|
from ..core.models import (
|
|
Document,
|
|
DocumentMetadata,
|
|
FirecrawlConfig,
|
|
IngestionJob,
|
|
IngestionSource,
|
|
)
|
|
from .base import BaseIngestor
|
|
|
|
if TYPE_CHECKING:
|
|
from ..storage.base import BaseStorage
|
|
|
|
|
|
class FirecrawlMetadata(Protocol):
|
|
"""Protocol for Firecrawl metadata objects."""
|
|
|
|
title: str | None
|
|
description: str | None
|
|
author: str | None
|
|
language: str | None
|
|
sitemap_last_modified: str | None
|
|
sourceURL: str | None
|
|
keywords: str | list[str] | None
|
|
robots: str | None
|
|
ogTitle: str | None
|
|
ogDescription: str | None
|
|
ogUrl: str | None
|
|
ogImage: str | None
|
|
twitterCard: str | None
|
|
twitterSite: str | None
|
|
twitterCreator: str | None
|
|
favicon: str | None
|
|
statusCode: int | None
|
|
|
|
|
|
class FirecrawlResult(Protocol):
|
|
"""Protocol for Firecrawl scrape result objects."""
|
|
|
|
metadata: FirecrawlMetadata | None
|
|
markdown: str | None
|
|
|
|
|
|
class FirecrawlMapLink(Protocol):
|
|
"""Protocol for Firecrawl map link objects."""
|
|
|
|
url: str
|
|
|
|
|
|
class FirecrawlMapResult(Protocol):
|
|
"""Protocol for Firecrawl map result objects."""
|
|
|
|
links: list[FirecrawlMapLink] | None
|
|
|
|
|
|
class AsyncFirecrawlSession(Protocol):
|
|
"""Protocol for AsyncFirecrawl session objects."""
|
|
|
|
async def close(self) -> None: ...
|
|
|
|
|
|
class AsyncFirecrawlClient(Protocol):
|
|
"""Protocol for AsyncFirecrawl client objects."""
|
|
|
|
_session: AsyncFirecrawlSession | None
|
|
|
|
async def close(self) -> None: ...
|
|
|
|
async def scrape(self, url: str, formats: list[str]) -> FirecrawlResult: ...
|
|
|
|
async def map(self, url: str, limit: int | None = None) -> "FirecrawlMapResult": ...
|
|
|
|
|
|
class FirecrawlError(IngestionError):
    """Base exception for Firecrawl-related errors."""

    status_code: int | None

    def __init__(self, message: str, status_code: int | None = None) -> None:
        super().__init__(message)
        self.status_code = status_code


class FirecrawlConnectionError(FirecrawlError):
    """Connection error with Firecrawl service."""

    pass


class FirecrawlRateLimitError(FirecrawlError):
    """Rate limit exceeded error."""

    pass


class FirecrawlUnauthorizedError(FirecrawlError):
    """Unauthorized access error."""

    pass


async def retry_with_backoff(
    operation: Callable[[], Awaitable[object]], max_retries: int = 3
) -> object:
    """Retry operation with exponential backoff following Firecrawl best practices."""
    for attempt in range(max_retries):
        try:
            return await operation()
        except Exception as e:
            if attempt == max_retries - 1:
                raise e
            delay: float = 1.0 * (2**attempt)
            logging.warning(
                f"Firecrawl operation failed (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {delay:.1f}s..."
            )
            await asyncio.sleep(delay)

    # This should never be reached due to the exception handling above,
    # but mypy requires a return statement for all code paths
    raise RuntimeError("Retry loop completed without return or exception")
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class FirecrawlPage:
|
|
"""Structured representation of a scraped Firecrawl page."""
|
|
|
|
url: str
|
|
content: str
|
|
title: str | None
|
|
description: str | None
|
|
author: str | None = None
|
|
language: str | None = None
|
|
sitemap_last_modified: str | None = None
|
|
source_url: str | None = None
|
|
keywords: list[str] | None = None
|
|
robots: str | None = None
|
|
og_title: str | None = None
|
|
og_description: str | None = None
|
|
og_url: str | None = None
|
|
og_image: str | None = None
|
|
twitter_card: str | None = None
|
|
twitter_site: str | None = None
|
|
twitter_creator: str | None = None
|
|
favicon: str | None = None
|
|
status_code: int | None = None
|
|
|
|
|
|
class FirecrawlIngestor(BaseIngestor):
|
|
"""Ingestor for web and documentation sites using Firecrawl."""
|
|
|
|
config: FirecrawlConfig
|
|
client: AsyncFirecrawlClient
|
|
|
|
def __init__(self, config: FirecrawlConfig | None = None):
|
|
"""
|
|
Initialize Firecrawl ingestor.
|
|
|
|
Args:
|
|
config: Firecrawl configuration (for operational params only)
|
|
"""
|
|
self.config = config or FirecrawlConfig()
|
|
settings = get_settings()
|
|
|
|
# All connection details come from settings/.env
|
|
# For self-hosted instances, use a dummy API key if none is provided
|
|
# The SDK requires an API key even for self-hosted instances
|
|
api_key = settings.firecrawl_api_key or "no-key-required"
|
|
|
|
# Initialize AsyncFirecrawl following official pattern
|
|
# Note: api_url parameter may not be supported in all versions
|
|
# Default to standard initialization for cloud instances
|
|
try:
|
|
endpoint_str = str(settings.firecrawl_endpoint).rstrip("/")
|
|
if endpoint_str.startswith("http://crawl.lab") or endpoint_str.startswith(
|
|
"http://localhost"
|
|
):
|
|
# Self-hosted instance - try with api_url if supported
|
|
self.client = cast(
|
|
AsyncFirecrawlClient,
|
|
AsyncFirecrawl(api_key=api_key, api_url=str(settings.firecrawl_endpoint)),
|
|
)
|
|
else:
|
|
# Cloud instance - use standard initialization
|
|
self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key))
|
|
except Exception:
|
|
# Fallback to standard initialization
|
|
self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key))
|
|
|
|
@override
|
|
async def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]:
|
|
"""
|
|
Ingest documents from a web source.
|
|
|
|
Args:
|
|
job: The ingestion job configuration
|
|
|
|
Yields:
|
|
Documents from the web source
|
|
"""
|
|
url = str(job.source_url)
|
|
|
|
# First, map the site to understand its structure
|
|
site_map = await self.map_site(url) or [url]
|
|
|
|
# Process pages in batches
|
|
batch_size = 10
|
|
for i in range(0, len(site_map), batch_size):
|
|
batch_urls = site_map[i : i + batch_size]
|
|
pages = await self.scrape_pages(batch_urls)
|
|
|
|
for page in pages:
|
|
yield self.create_document(page, job)
|
|
|
|
async def ingest_with_dedup(
|
|
self,
|
|
job: IngestionJob,
|
|
storage_client: "BaseStorage",
|
|
*,
|
|
collection_name: str | None = None,
|
|
stale_after_days: int = 30,
|
|
) -> AsyncGenerator[Document, None]:
|
|
"""
|
|
Ingest documents with duplicate detection to avoid unnecessary scraping.
|
|
|
|
Args:
|
|
job: The ingestion job configuration
|
|
storage_client: Storage client to check for existing documents
|
|
collection_name: Collection to check for duplicates
|
|
stale_after_days: Consider documents stale after this many days
|
|
|
|
Yields:
|
|
Documents from the web source (only new/stale ones)
|
|
"""
|
|
url = str(job.source_url)
|
|
|
|
# First, map the site to understand its structure
|
|
site_map = await self.map_site(url) or [url]
|
|
|
|
# Filter out URLs that already exist in storage and are fresh
|
|
eligible_urls: list[str] = []
|
|
for check_url in site_map:
|
|
document_id = str(self.compute_document_id(check_url))
|
|
exists = await storage_client.check_exists(
|
|
document_id, collection_name=collection_name, stale_after_days=stale_after_days
|
|
)
|
|
if not exists:
|
|
eligible_urls.append(check_url)
|
|
|
|
if not eligible_urls:
|
|
return # No new documents to scrape
|
|
|
|
# Process eligible pages in batches
|
|
batch_size = 10
|
|
for i in range(0, len(eligible_urls), batch_size):
|
|
batch_urls = eligible_urls[i : i + batch_size]
|
|
pages = await self.scrape_pages(batch_urls)
|
|
|
|
for page in pages:
|
|
yield self.create_document(page, job)
|
|
|
|
async def map_site(self, url: str) -> list[str]:
|
|
"""Public wrapper for mapping a site."""
|
|
|
|
return await self._map_site(url)
|
|
|
|
async def scrape_pages(self, urls: list[str]) -> list[FirecrawlPage]:
|
|
"""Scrape a batch of URLs and return structured page data."""
|
|
|
|
return await self._scrape_batch(urls)
|
|
|
|
@override
|
|
async def validate_source(self, source_url: str) -> bool:
|
|
"""
|
|
Validate if the web source is accessible.
|
|
|
|
Args:
|
|
source_url: URL to validate
|
|
|
|
Returns:
|
|
True if source is accessible
|
|
"""
|
|
try:
|
|
# Use SDK v2 endpoints following official pattern with retry
|
|
async def validate_operation() -> bool:
|
|
result = await self.client.scrape(source_url, formats=["markdown"])
|
|
return result is not None and getattr(result, "markdown", None) is not None
|
|
|
|
result = await retry_with_backoff(validate_operation)
|
|
return bool(result)
|
|
except Exception as e:
|
|
logging.warning(f"Failed to validate source {source_url}: {e}")
|
|
return False
|
|
|
|
@override
|
|
async def estimate_size(self, source_url: str) -> int:
|
|
"""
|
|
Estimate the number of pages in the website.
|
|
|
|
Args:
|
|
source_url: URL of the website
|
|
|
|
Returns:
|
|
Estimated number of pages
|
|
"""
|
|
try:
|
|
site_map = await self._map_site(source_url)
|
|
return len(site_map) if site_map else 0
|
|
except Exception as e:
|
|
logging.warning(f"Failed to estimate size for {source_url}: {e}")
|
|
return 0
|
|
|
|
async def _map_site(self, url: str) -> list[str]:
|
|
"""
|
|
Map a website to get all URLs.
|
|
|
|
Args:
|
|
url: Base URL to map
|
|
|
|
Returns:
|
|
List of URLs found
|
|
"""
|
|
try:
|
|
# Use SDK v2 map endpoint following official pattern
|
|
result: FirecrawlMapResult = await self.client.map(url=url, limit=self.config.limit)
|
|
|
|
if result and result.links:
|
|
# Extract URLs from the result following official pattern
|
|
return [link.url for link in result.links]
|
|
return []
|
|
except Exception as e:
|
|
# If map fails (might not be available in all versions), fall back to single URL
|
|
logging.warning(f"Map endpoint not available or failed: {e}. Using single URL.")
|
|
return [url]
|
|
|
|
async def _scrape_batch(self, urls: list[str]) -> list[FirecrawlPage]:
|
|
"""
|
|
Scrape a batch of URLs.
|
|
|
|
Args:
|
|
urls: List of URLs to scrape
|
|
|
|
Returns:
|
|
List of scraped documents
|
|
"""
|
|
tasks = [self._scrape_single(url) for url in urls]
|
|
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
pages: list[FirecrawlPage] = []
|
|
for result in results:
|
|
if isinstance(result, FirecrawlPage):
|
|
pages.append(result)
|
|
elif isinstance(result, BaseException):
|
|
continue
|
|
|
|
return pages
|
|
|
|
async def _scrape_single(self, url: str) -> FirecrawlPage | None:
|
|
"""
|
|
Scrape a single URL and extract rich metadata.
|
|
|
|
Args:
|
|
url: URL to scrape
|
|
|
|
Returns:
|
|
Scraped document data with enhanced metadata
|
|
"""
|
|
try:
|
|
# Use SDK v2 scrape endpoint following official pattern with retry
|
|
async def scrape_operation() -> FirecrawlPage | None:
|
|
result: FirecrawlResult = await self.client.scrape(url, formats=self.config.formats)
|
|
|
|
# Extract data from the result following official response handling
|
|
if result:
|
|
# The SDK returns a ScrapeData object with typed metadata
|
|
metadata: FirecrawlMetadata | None = getattr(result, "metadata", None)
|
|
|
|
# Extract basic metadata
|
|
title: str | None = getattr(metadata, "title", None) if metadata else None
|
|
description: str | None = (
|
|
getattr(metadata, "description", None) if metadata else None
|
|
)
|
|
|
|
# Extract enhanced metadata if available
|
|
author: str | None = getattr(metadata, "author", None) if metadata else None
|
|
language: str | None = getattr(metadata, "language", None) if metadata else None
|
|
sitemap_last_modified: str | None = (
|
|
getattr(metadata, "sitemap_last_modified", None) if metadata else None
|
|
)
|
|
source_url: str | None = (
|
|
getattr(metadata, "sourceURL", None) if metadata else None
|
|
)
|
|
keywords: str | list[str] | None = (
|
|
getattr(metadata, "keywords", None) if metadata else None
|
|
)
|
|
robots: str | None = getattr(metadata, "robots", None) if metadata else None
|
|
|
|
# Open Graph metadata
|
|
og_title: str | None = getattr(metadata, "ogTitle", None) if metadata else None
|
|
og_description: str | None = (
|
|
getattr(metadata, "ogDescription", None) if metadata else None
|
|
)
|
|
og_url: str | None = getattr(metadata, "ogUrl", None) if metadata else None
|
|
og_image: str | None = getattr(metadata, "ogImage", None) if metadata else None
|
|
|
|
# Twitter metadata
|
|
twitter_card: str | None = (
|
|
getattr(metadata, "twitterCard", None) if metadata else None
|
|
)
|
|
twitter_site: str | None = (
|
|
getattr(metadata, "twitterSite", None) if metadata else None
|
|
)
|
|
twitter_creator: str | None = (
|
|
getattr(metadata, "twitterCreator", None) if metadata else None
|
|
)
|
|
|
|
# Additional metadata
|
|
favicon: str | None = getattr(metadata, "favicon", None) if metadata else None
|
|
status_code: int | None = (
|
|
getattr(metadata, "statusCode", None) if metadata else None
|
|
)
|
|
|
|
return FirecrawlPage(
|
|
url=url,
|
|
content=getattr(result, "markdown", "") or "",
|
|
title=title,
|
|
description=description,
|
|
author=author,
|
|
language=language,
|
|
sitemap_last_modified=sitemap_last_modified,
|
|
source_url=source_url,
|
|
keywords=keywords.split(",")
|
|
if keywords and isinstance(keywords, str)
|
|
else (keywords if isinstance(keywords, list) else None),
|
|
robots=robots,
|
|
og_title=og_title,
|
|
og_description=og_description,
|
|
og_url=og_url,
|
|
og_image=og_image,
|
|
twitter_card=twitter_card,
|
|
twitter_site=twitter_site,
|
|
twitter_creator=twitter_creator,
|
|
favicon=favicon,
|
|
status_code=status_code,
|
|
)
|
|
return None
|
|
|
|
result = await retry_with_backoff(scrape_operation)
|
|
return result if isinstance(result, FirecrawlPage) else None
|
|
except Exception as e:
|
|
logging.debug(f"Failed to scrape {url}: {e}")
|
|
return None
|
|
|
|
@staticmethod
|
|
def compute_document_id(source_url: str) -> UUID:
|
|
"""Derive a deterministic UUID for a document based on its source URL."""
|
|
return uuid5(NAMESPACE_URL, source_url)
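# Note: uuid5(NAMESPACE_URL, url) is deterministic, so the same source URL always
# maps to the same document id across runs.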
|
|
|
|
@staticmethod
|
|
def _analyze_content_structure(content: str) -> dict[str, str | int | bool | list[str]]:
|
|
"""Analyze markdown content to extract structural information."""
|
|
# Extract heading hierarchy
|
|
heading_pattern = r"^(#{1,6})\s+(.+)$"
|
|
headings: list[str] = []
|
|
for match in re.finditer(heading_pattern, content, re.MULTILINE):
|
|
level = len(match.group(1))
|
|
text = match.group(2).strip()
|
|
headings.append(f"{' ' * (level - 1)}{text}")
|
|
|
|
# Check for various content types
|
|
has_code_blocks = bool(re.search(r"```[\s\S]*?```", content))
|
|
has_images = bool(re.search(r"!\[.*?\]\(.*?\)", content))
|
|
has_links = bool(re.search(r"\[.*?\]\(.*?\)", content))
|
|
|
|
# Calculate section depth
|
|
max_depth = 0
|
|
if headings:
|
|
for heading in headings:
|
|
heading_str: str = str(heading)
|
|
depth = (len(heading_str) - len(heading_str.lstrip())) + 1  # headings above use one leading space per level
|
|
max_depth = max(max_depth, depth)
|
|
|
|
return {
|
|
"heading_hierarchy": headings,
|
|
"section_depth": max_depth,
|
|
"has_code_blocks": has_code_blocks,
|
|
"has_images": has_images,
|
|
"has_links": has_links,
|
|
}
|
|
|
|
@staticmethod
|
|
def _calculate_content_quality(content: str, title: str | None) -> dict[str, float | None]:
|
|
"""Calculate basic content quality metrics."""
|
|
if not content:
|
|
return {"readability_score": None, "completeness_score": None}
|
|
|
|
# Simple readability approximation (Flesch-like)
|
|
sentences = len(re.findall(r"[.!?]+", content))
|
|
words = len(content.split())
|
|
|
|
if sentences == 0 or words == 0:
|
|
readability_score = None
|
|
else:
|
|
avg_sentence_length = words / sentences
|
|
# Simplified readability score (0-100, higher is more readable)
|
|
readability_score = max(0.0, min(100.0, 100.0 - (avg_sentence_length - 15.0) * 2.0))
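# Illustrative values: 15 words/sentence -> 100, 25 -> 80, 10 -> clipped to 100, 65 or more -> 0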
|
|
|
|
# Completeness score based on structure
|
|
completeness_factors = 0
|
|
total_factors = 5
|
|
|
|
if title:
|
|
completeness_factors += 1
|
|
if len(content) > 500:
|
|
completeness_factors += 1
|
|
if re.search(r"^#{1,6}\s+", content, re.MULTILINE):
|
|
completeness_factors += 1
|
|
if len(content.split()) > 100:
|
|
completeness_factors += 1
|
|
if not re.search(r"(error|404|not found|page not found)", content, re.IGNORECASE):
|
|
completeness_factors += 1
|
|
|
|
completeness_score = (completeness_factors / total_factors) * 100
|
|
|
|
return {
|
|
"readability_score": readability_score,
|
|
"completeness_score": completeness_score,
|
|
}
|
|
|
|
@staticmethod
|
|
def _extract_domain_info(url: str) -> dict[str, str]:
|
|
"""Extract domain and site information from URL."""
|
|
parsed = urlparse(url)
|
|
domain = parsed.netloc.lower()
|
|
|
|
# Remove www. prefix
|
|
if domain.startswith("www."):
|
|
domain = domain[4:]
|
|
|
|
# Extract site name from domain
|
|
domain_parts = domain.split(".")
|
|
site_name = domain_parts[0].replace("-", " ").replace("_", " ").title()
|
|
|
|
return {
|
|
"domain": domain,
|
|
"site_name": site_name,
|
|
}
|
|
|
|
def create_document(self, page: FirecrawlPage, job: IngestionJob) -> Document:
|
|
"""
|
|
Create a Document from scraped data with enriched metadata.
|
|
|
|
Args:
|
|
page: Scraped document data
|
|
job: The ingestion job
|
|
|
|
Returns:
|
|
Document instance with rich metadata
|
|
"""
|
|
content = page.content
|
|
source_url = page.url
|
|
|
|
# Analyze content structure
|
|
structure_info = self._analyze_content_structure(content)
|
|
|
|
# Calculate quality metrics
|
|
quality_info = self._calculate_content_quality(content, page.title)
|
|
|
|
# Extract domain information
|
|
domain_info = self._extract_domain_info(source_url)
|
|
|
|
# Build rich metadata
|
|
metadata: DocumentMetadata = {
|
|
# Core required fields
|
|
"source_url": source_url,
|
|
"timestamp": datetime.now(UTC),
|
|
"content_type": "text/markdown",
|
|
"word_count": len(content.split()),
|
|
"char_count": len(content),
|
|
# Basic optional fields
|
|
"title": page.title or f"Page from {source_url}",
|
|
"description": page.description
|
|
or page.og_description
|
|
or f"Content scraped from {source_url}",
|
|
# Content categorization
|
|
"tags": page.keywords or [],
|
|
"language": page.language or "en",
|
|
# Authorship and source info
|
|
"author": page.author or page.twitter_creator or "Unknown",
|
|
"domain": domain_info["domain"],
|
|
"site_name": domain_info["site_name"],
|
|
# Document structure
|
|
"heading_hierarchy": (
|
|
list(hierarchy_val)
|
|
if (hierarchy_val := structure_info.get("heading_hierarchy"))
|
|
and isinstance(hierarchy_val, (list, tuple))
|
|
else []
|
|
),
|
|
"section_depth": (
|
|
int(depth_val)
|
|
if (depth_val := structure_info.get("section_depth"))
|
|
and isinstance(depth_val, (int, str))
|
|
else 0
|
|
),
|
|
"has_code_blocks": bool(structure_info.get("has_code_blocks", False)),
|
|
"has_images": bool(structure_info.get("has_images", False)),
|
|
"has_links": bool(structure_info.get("has_links", False)),
|
|
# Processing metadata
|
|
"extraction_method": "firecrawl",
|
|
"last_modified": datetime.fromisoformat(page.sitemap_last_modified)
|
|
if page.sitemap_last_modified
|
|
else None,
|
|
# Content quality indicators
|
|
"readability_score": quality_info["readability_score"],
|
|
"completeness_score": quality_info["completeness_score"],
|
|
}
|
|
|
|
# Note: Additional web-specific metadata like og_title, twitter_card etc.
|
|
# would need to be added to DocumentMetadata TypedDict if needed
|
|
|
|
return Document(
|
|
id=self.compute_document_id(source_url),
|
|
content=content,
|
|
metadata=metadata,
|
|
source=IngestionSource.WEB,
|
|
collection=job.storage_backend.value,
|
|
)
|
|
|
|
async def close(self) -> None:
|
|
"""Close the Firecrawl client and cleanup resources."""
|
|
# AsyncFirecrawl may not have explicit close method in all versions
|
|
# This is defensive cleanup following best practices
|
|
if hasattr(self.client, "close"):
|
|
try:
|
|
await self.client.close()
|
|
except Exception as e:
|
|
logging.debug(f"Error closing Firecrawl client: {e}")
|
|
elif (
|
|
hasattr(self.client, "_session")
|
|
and self.client._session
|
|
and hasattr(self.client._session, "close")
|
|
):
|
|
try:
|
|
await self.client._session.close()
|
|
except Exception as e:
|
|
logging.debug(f"Error closing Firecrawl session: {e}")
|
|
|
|
async def __aenter__(self) -> "FirecrawlIngestor":
|
|
"""Async context manager entry."""
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: object | None,
|
|
) -> None:
|
|
"""Async context manager exit with cleanup."""
|
|
await self.close()
|
|
</file>
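
A minimal usage sketch for FirecrawlIngestor above (illustration only, not a file from the repository): it assumes the constructor accepts a single config object exposing `formats` and `limit`, and that an IngestionJob is built elsewhere in the pipeline.

from ingest_pipeline.ingestors.firecrawl import FirecrawlIngestor

async def scrape_site(config, job):
    # config and job construction is assumed; only calls shown in the class above are used.
    async with FirecrawlIngestor(config) as ingestor:
        if not await ingestor.validate_source("https://example.com"):
            return []
        urls = await ingestor.map_site("https://example.com")
        pages = await ingestor.scrape_pages(urls[:5])
        return [ingestor.create_document(page, job) for page in pages]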
|
|
|
|
<file path="ingest_pipeline/utils/metadata_tagger.py">
|
|
"""Metadata tagger for enriching documents with AI-generated tags and metadata."""
|
|
|
|
import json
|
|
from datetime import UTC, datetime
|
|
from typing import Final, Protocol, TypedDict, cast
|
|
|
|
import httpx
|
|
|
|
from ..config import get_settings
|
|
from ..core.exceptions import IngestionError
|
|
from ..core.models import Document
|
|
|
|
JSON_CONTENT_TYPE: Final[str] = "application/json"
|
|
AUTHORIZATION_HEADER: Final[str] = "Authorization"
|
|
|
|
|
|
class HttpResponse(Protocol):
|
|
"""Protocol for HTTP response."""
|
|
|
|
def raise_for_status(self) -> None: ...
|
|
def json(self) -> dict[str, object]: ...
|
|
|
|
|
|
class AsyncHttpClient(Protocol):
|
|
"""Protocol for async HTTP client."""
|
|
|
|
async def post(self, url: str, *, json: dict[str, object] | None = None) -> HttpResponse: ...
|
|
|
|
async def aclose(self) -> None: ...
|
|
|
|
async def __aenter__(self) -> "AsyncHttpClient": ...
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: object | None,
|
|
) -> None: ...
|
|
|
|
|
|
class LlmResponse(TypedDict):
|
|
"""Type for LLM API response structure."""
|
|
|
|
choices: list[dict[str, object]]
|
|
|
|
|
|
class LlmChoice(TypedDict):
|
|
"""Type for individual choice in LLM response."""
|
|
|
|
message: dict[str, object]
|
|
|
|
|
|
class LlmMessage(TypedDict):
|
|
"""Type for message in LLM choice."""
|
|
|
|
content: str
|
|
|
|
|
|
class DocumentMetadata(TypedDict, total=False):
|
|
"""Structured metadata for documents."""
|
|
|
|
tags: list[str]
|
|
category: str
|
|
summary: str
|
|
key_topics: list[str]
|
|
document_type: str
|
|
language: str
|
|
technical_level: str
|
|
|
|
|
|
class MetadataTagger:
|
|
"""Generates metadata tags for documents using language models."""
|
|
|
|
endpoint: str
|
|
model: str
|
|
client: AsyncHttpClient
|
|
|
|
def __init__(
|
|
self,
|
|
llm_endpoint: str | None = None,
|
|
model: str | None = None,
|
|
api_key: str | None = None,
|
|
*,
|
|
timeout: float | None = None,
|
|
):
|
|
"""
|
|
Initialize metadata tagger.
|
|
|
|
Args:
|
|
llm_endpoint: LLM API endpoint
|
|
model: Model to use for tagging
|
|
api_key: Explicit API key override
|
|
timeout: Optional request timeout override in seconds
|
|
"""
|
|
settings = get_settings()
|
|
endpoint_value = llm_endpoint or str(settings.llm_endpoint)
|
|
self.endpoint = endpoint_value.rstrip("/")
|
|
self.model = model or settings.metadata_model
|
|
|
|
resolved_timeout = timeout if timeout is not None else float(settings.request_timeout)
|
|
resolved_api_key = api_key or settings.get_llm_api_key() or ""
|
|
|
|
headers: dict[str, str] = {"Content-Type": JSON_CONTENT_TYPE}
|
|
if resolved_api_key:
|
|
headers[AUTHORIZATION_HEADER] = f"Bearer {resolved_api_key}"
|
|
|
|
# Create client with proper typing - httpx.AsyncClient implements AsyncHttpClient protocol
|
|
self.client = cast(
|
|
AsyncHttpClient,
|
|
httpx.AsyncClient(timeout=resolved_timeout, headers=headers),
|
|
)
|
|
|
|
async def tag_document(
|
|
self, document: Document, custom_instructions: str | None = None
|
|
) -> Document:
|
|
"""
|
|
Analyze document and generate metadata tags.
|
|
|
|
Args:
|
|
document: Document to tag
|
|
custom_instructions: Optional custom instructions for tagging
|
|
|
|
Returns:
|
|
Document with enriched metadata
|
|
"""
|
|
if not document.content:
|
|
return document
|
|
|
|
try:
|
|
# Generate metadata using LLM
|
|
metadata = await self._generate_metadata(
|
|
document.content,
|
|
document.metadata.get("title") if document.metadata else None,
|
|
custom_instructions,
|
|
)
|
|
|
|
# Merge with existing metadata - preserve ALL existing fields and add LLM-generated ones
|
|
|
|
from ..core.models import DocumentMetadata as CoreDocumentMetadata
|
|
|
|
# Start with a copy of existing metadata to preserve all fields
|
|
updated_metadata = dict(document.metadata)
|
|
|
|
# Fill in missing fields from the LLM-generated metadata; existing non-empty values are kept as-is
|
|
if title_val := metadata.get("title"):
|
|
if not updated_metadata.get("title") and isinstance(title_val, str):
|
|
updated_metadata["title"] = title_val
|
|
if summary_val := metadata.get("summary"):
|
|
if not updated_metadata.get("description"):
|
|
updated_metadata["description"] = str(summary_val)
|
|
|
|
# Ensure required fields have values
|
|
_ = updated_metadata.setdefault("source_url", "")
|
|
_ = updated_metadata.setdefault("timestamp", datetime.now(UTC))
|
|
_ = updated_metadata.setdefault("content_type", "text/plain")
|
|
_ = updated_metadata.setdefault("word_count", len(document.content.split()))
|
|
_ = updated_metadata.setdefault("char_count", len(document.content))
|
|
|
|
# Build a proper DocumentMetadata instance with only valid keys
|
|
new_metadata: CoreDocumentMetadata = {
|
|
"source_url": str(updated_metadata.get("source_url", "")),
|
|
"timestamp": (lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC))(
|
|
updated_metadata.get("timestamp", datetime.now(UTC))
|
|
),
|
|
"content_type": str(updated_metadata.get("content_type", "text/plain")),
|
|
"word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(
|
|
updated_metadata.get("word_count", 0)
|
|
),
|
|
"char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(
|
|
updated_metadata.get("char_count", 0)
|
|
),
|
|
}
|
|
|
|
# Add optional fields if they exist
|
|
if "title" in updated_metadata and updated_metadata["title"]:
|
|
new_metadata["title"] = str(updated_metadata["title"])
|
|
if "description" in updated_metadata and updated_metadata["description"]:
|
|
new_metadata["description"] = str(updated_metadata["description"])
|
|
if "tags" in updated_metadata and isinstance(updated_metadata["tags"], list):
|
|
tags_list = cast(list[object], updated_metadata["tags"])
|
|
new_metadata["tags"] = [str(tag) for tag in tags_list if tag is not None]
|
|
if "category" in updated_metadata and updated_metadata["category"]:
|
|
new_metadata["category"] = str(updated_metadata["category"])
|
|
if "language" in updated_metadata and updated_metadata["language"]:
|
|
new_metadata["language"] = str(updated_metadata["language"])
|
|
|
|
document.metadata = new_metadata
|
|
|
|
return document
|
|
|
|
except Exception as e:
|
|
raise IngestionError(f"Failed to tag document: {e}") from e
|
|
|
|
async def tag_batch(
|
|
self,
|
|
documents: list[Document],
|
|
custom_instructions: str | None = None,
|
|
) -> list[Document]:
|
|
"""
|
|
Tag multiple documents with metadata.
|
|
|
|
Args:
|
|
documents: Documents to tag
|
|
custom_instructions: Optional custom instructions
|
|
|
|
Returns:
|
|
Documents with enriched metadata
|
|
"""
|
|
tagged_docs: list[Document] = []
|
|
|
|
for doc in documents:
|
|
tagged_doc = await self.tag_document(doc, custom_instructions)
|
|
tagged_docs.append(tagged_doc)
|
|
|
|
return tagged_docs
|
|
|
|
async def _generate_metadata(
|
|
self,
|
|
content: str,
|
|
title: str | None = None,
|
|
custom_instructions: str | None = None,
|
|
) -> DocumentMetadata:
|
|
"""
|
|
Generate metadata using LLM.
|
|
|
|
Args:
|
|
content: Document content
|
|
title: Document title
|
|
custom_instructions: Optional custom instructions
|
|
|
|
Returns:
|
|
Generated metadata dictionary
|
|
"""
|
|
# Prepare the prompt
|
|
system_prompt = """You are a document metadata tagger. Analyze the given content and generate relevant metadata.
|
|
|
|
Return a JSON object with the following structure:
|
|
{
|
|
"tags": ["tag1", "tag2", ...], # 3-7 relevant topic tags
|
|
"category": "string", # Main category
|
|
"summary": "string", # 1-2 sentence summary
|
|
"key_topics": ["topic1", "topic2", ...], # Main topics discussed
|
|
"document_type": "string", # Type of document (e.g., "technical", "tutorial", "reference")
|
|
"language": "string", # Primary language (e.g., "en", "es")
|
|
"technical_level": "string" # One of: "beginner", "intermediate", "advanced"
|
|
}"""
|
|
|
|
if custom_instructions:
|
|
system_prompt += f"\n\nAdditional instructions: {custom_instructions}"
|
|
|
|
# Prepare user prompt
|
|
user_prompt = "Document to analyze:\n"
|
|
if title:
|
|
user_prompt += f"Title: {title}\n"
|
|
user_prompt += f"Content:\n{content[:3000]}" # Limit content length
|
|
|
|
# Call LLM
|
|
response = await self.client.post(
|
|
f"{self.endpoint}/v1/chat/completions",
|
|
json={
|
|
"model": self.model,
|
|
"messages": [
|
|
{"role": "system", "content": system_prompt},
|
|
{"role": "user", "content": user_prompt},
|
|
],
|
|
"temperature": 0.3,
|
|
"max_tokens": 500,
|
|
"response_format": {"type": "json_object"},
|
|
},
|
|
)
|
|
response.raise_for_status()
|
|
|
|
result_raw = response.json()
|
|
if not isinstance(result_raw, dict):
|
|
raise IngestionError("Invalid response format from LLM")
|
|
|
|
result = cast(LlmResponse, result_raw)
|
|
|
|
# Extract content from response
|
|
choices = result.get("choices", [])
|
|
if not choices or not isinstance(choices, list):
|
|
raise IngestionError("No response from LLM")
|
|
|
|
first_choice_raw = choices[0]
|
|
if not isinstance(first_choice_raw, dict):
|
|
raise IngestionError("Invalid choice format")
|
|
|
|
first_choice = cast(LlmChoice, first_choice_raw)
|
|
message_raw = first_choice.get("message", {})
|
|
if not isinstance(message_raw, dict):
|
|
raise IngestionError("Invalid message format")
|
|
|
|
message = cast(LlmMessage, message_raw)
|
|
content_str = str(message.get("content", "{}"))
|
|
|
|
try:
|
|
raw_metadata = cast(dict[str, object], json.loads(content_str))
|
|
except json.JSONDecodeError as e:
|
|
raise IngestionError(f"Failed to parse LLM response: {e}") from e
|
|
|
|
# Ensure it's a dict before processing
|
|
if not isinstance(raw_metadata, dict):
|
|
raise IngestionError("LLM response is not a valid JSON object")
|
|
|
|
# Validate and sanitize metadata
|
|
return self._sanitize_metadata(raw_metadata)
|
|
|
|
def _sanitize_metadata(self, metadata: dict[str, object]) -> DocumentMetadata:
|
|
"""
|
|
Sanitize and validate metadata.
|
|
|
|
Args:
|
|
metadata: Raw metadata from LLM
|
|
|
|
Returns:
|
|
Sanitized metadata
|
|
"""
|
|
sanitized: DocumentMetadata = {}
|
|
|
|
# Tags
|
|
if "tags" in metadata and isinstance(metadata["tags"], list):
|
|
tags_list = cast(list[object], metadata["tags"])
|
|
tags_raw = tags_list[:10] if len(tags_list) > 10 else tags_list
|
|
tags = [str(tag).lower().strip() for tag in tags_raw]
|
|
sanitized["tags"] = [tag for tag in tags if tag]
|
|
|
|
# Category
|
|
if "category" in metadata:
|
|
sanitized["category"] = str(metadata["category"]).strip()
|
|
|
|
# Summary
|
|
if "summary" in metadata:
|
|
if summary := str(metadata["summary"]).strip():
|
|
sanitized["summary"] = summary[:500] # Limit length
|
|
|
|
# Key topics
|
|
if "key_topics" in metadata and isinstance(metadata["key_topics"], list):
|
|
topics_list = cast(list[object], metadata["key_topics"])
|
|
topics_raw = topics_list[:10] if len(topics_list) > 10 else topics_list
|
|
topics = [str(topic).strip() for topic in topics_raw]
|
|
sanitized["key_topics"] = [topic for topic in topics if topic]
|
|
|
|
# Document type
|
|
if "document_type" in metadata:
|
|
sanitized["document_type"] = str(metadata["document_type"]).strip()
|
|
|
|
# Language
|
|
if "language" in metadata:
|
|
lang = str(metadata["language"]).strip().lower()
|
|
if len(lang) == 2: # Basic validation for ISO 639-1
|
|
sanitized["language"] = lang
|
|
|
|
# Technical level
|
|
if "technical_level" in metadata:
|
|
level = str(metadata["technical_level"]).strip().lower()
|
|
if level in {"beginner", "intermediate", "advanced"}:
|
|
sanitized["technical_level"] = level
|
|
|
|
return sanitized
|
|
|
|
async def __aenter__(self) -> "MetadataTagger":
|
|
"""Async context manager entry."""
|
|
return self
|
|
|
|
async def __aexit__(self, *args: object) -> None:
|
|
"""Async context manager exit."""
|
|
await self.client.aclose()
|
|
</file>
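
A usage sketch for MetadataTagger above (illustration only, not a file from the repository): constructor arguments are optional and fall back to settings, as in __init__; `doc` stands in for a Document instance created elsewhere.

from ingest_pipeline.utils.metadata_tagger import MetadataTagger

async def tag_one(doc):
    # doc: an existing Document; construction is not shown here.
    async with MetadataTagger() as tagger:
        return await tagger.tag_document(doc, custom_instructions="Prefer short, lowercase tags")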
|
|
|
|
<file path="ingest_pipeline/utils/vectorizer.py">
|
|
"""Vectorizer utility for generating embeddings."""
|
|
|
|
import asyncio
|
|
from types import TracebackType
|
|
from typing import Final, NotRequired, Self, TypedDict
|
|
|
|
import httpx
|
|
|
|
from ..config import get_settings
|
|
from ..core.exceptions import VectorizationError
|
|
from ..core.models import StorageConfig, VectorConfig
|
|
|
|
JSON_CONTENT_TYPE: Final[str] = "application/json"
|
|
AUTHORIZATION_HEADER: Final[str] = "Authorization"
|
|
|
|
|
|
class EmbeddingData(TypedDict):
|
|
"""Structure for embedding data from providers."""
|
|
|
|
embedding: list[float]
|
|
index: NotRequired[int]
|
|
object: NotRequired[str]
|
|
|
|
|
|
class EmbeddingResponse(TypedDict):
|
|
"""Embedding response format for multiple providers."""
|
|
|
|
data: list[EmbeddingData]
|
|
model: NotRequired[str]
|
|
object: NotRequired[str]
|
|
usage: NotRequired[dict[str, int]]
|
|
# Alternative formats
|
|
embedding: NotRequired[list[float]]
|
|
vector: NotRequired[list[float]]
|
|
embeddings: NotRequired[list[list[float]]]
|
|
|
|
|
|
def _extract_embedding_from_response(response_data: dict[str, object]) -> list[float]:
|
|
"""Extract embedding vector from provider response."""
|
|
# OpenAI/Ollama format: {"data": [{"embedding": [...]}]}
|
|
if "data" in response_data:
|
|
data_list = response_data["data"]
|
|
if isinstance(data_list, list) and data_list:
|
|
first_item = data_list[0]
|
|
if isinstance(first_item, dict) and "embedding" in first_item:
|
|
embedding = first_item["embedding"]
|
|
if isinstance(embedding, list) and all(
|
|
isinstance(x, (int, float)) for x in embedding
|
|
):
|
|
return [float(x) for x in embedding]
|
|
|
|
# Direct embedding format: {"embedding": [...]}
|
|
if "embedding" in response_data:
|
|
embedding = response_data["embedding"]
|
|
if isinstance(embedding, list) and all(isinstance(x, (int, float)) for x in embedding):
|
|
return [float(x) for x in embedding]
|
|
|
|
# Vector format: {"vector": [...]}
|
|
if "vector" in response_data:
|
|
vector = response_data["vector"]
|
|
if isinstance(vector, list) and all(isinstance(x, (int, float)) for x in vector):
|
|
return [float(x) for x in vector]
|
|
|
|
# Embeddings array format: {"embeddings": [[...]]}
|
|
if "embeddings" in response_data:
|
|
embeddings = response_data["embeddings"]
|
|
if isinstance(embeddings, list) and embeddings:
|
|
first_embedding = embeddings[0]
|
|
if isinstance(first_embedding, list) and all(
|
|
isinstance(x, (int, float)) for x in first_embedding
|
|
):
|
|
return [float(x) for x in first_embedding]
|
|
|
|
raise VectorizationError("Unrecognized embedding response format")
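# Payload shapes the helper above accepts (illustrative):
#   {"data": [{"embedding": [...]}]}    OpenAI / Ollama style
#   {"embedding": [...]}                direct embedding
#   {"vector": [...]}                   "vector" key
#   {"embeddings": [[...], ...]}        batch style (first entry is used)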
|
|
|
|
|
|
class Vectorizer:
|
|
"""Handles text vectorization using LLM endpoints."""
|
|
|
|
endpoint: str
|
|
model: str
|
|
dimension: int
|
|
|
|
def __init__(self, config: StorageConfig | VectorConfig):
|
|
"""
|
|
Initialize vectorizer.
|
|
|
|
Args:
|
|
config: Configuration with embedding details
|
|
"""
|
|
settings = get_settings()
|
|
if isinstance(config, StorageConfig):
|
|
# Extract vector config from global settings when storage config is provided
|
|
self.endpoint = str(settings.llm_endpoint).rstrip("/")
|
|
self.model = settings.embedding_model
|
|
self.dimension = settings.embedding_dimension
|
|
else:
|
|
self.endpoint = str(config.embedding_endpoint).rstrip("/")
|
|
self.model = config.model
|
|
self.dimension = config.dimension
|
|
|
|
resolved_api_key = settings.get_llm_api_key() or ""
|
|
headers: dict[str, str] = {"Content-Type": JSON_CONTENT_TYPE}
|
|
if resolved_api_key:
|
|
headers[AUTHORIZATION_HEADER] = f"Bearer {resolved_api_key}"
|
|
|
|
timeout_seconds = float(settings.request_timeout)
|
|
self.client = httpx.AsyncClient(timeout=timeout_seconds, headers=headers)
|
|
|
|
async def vectorize(self, text: str) -> list[float]:
|
|
"""
|
|
Generate embedding vector for text.
|
|
|
|
Args:
|
|
text: Text to vectorize
|
|
|
|
Returns:
|
|
Embedding vector
|
|
"""
|
|
if not text:
|
|
raise VectorizationError("Cannot vectorize empty text")
|
|
|
|
try:
|
|
return (
|
|
await self._ollama_embed(text)
|
|
if "ollama" in self.model
|
|
else await self._openai_embed(text)
|
|
)
|
|
except Exception as e:
|
|
raise VectorizationError(f"Vectorization failed: {e}") from e
|
|
|
|
async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
|
|
"""
|
|
Generate embeddings for multiple texts in parallel.
|
|
|
|
Args:
|
|
texts: List of texts to vectorize
|
|
|
|
Returns:
|
|
List of embedding vectors
|
|
|
|
Raises:
|
|
VectorizationError: If any vectorization fails
|
|
"""
|
|
|
|
if not texts:
|
|
return []
|
|
|
|
# Use semaphore to limit concurrent requests and prevent overwhelming the endpoint
|
|
semaphore = asyncio.Semaphore(20)
|
|
|
|
async def vectorize_with_semaphore(text: str) -> list[float]:
|
|
async with semaphore:
|
|
return await self.vectorize(text)
|
|
|
|
try:
|
|
# Execute all vectorization requests concurrently
|
|
vectors = await asyncio.gather(*[vectorize_with_semaphore(text) for text in texts])
|
|
return list(vectors)
|
|
except Exception as e:
|
|
raise VectorizationError(f"Batch vectorization failed: {e}") from e
|
|
|
|
async def _ollama_embed(self, text: str) -> list[float]:
|
|
"""
|
|
Generate embedding using Ollama via OpenAI-compatible endpoint.
|
|
|
|
Args:
|
|
text: Text to embed
|
|
|
|
Returns:
|
|
Embedding vector
|
|
"""
|
|
# Use the full model name as it appears in the API
|
|
model_name = self.model
|
|
|
|
# Use OpenAI-compatible endpoint for ollama models
|
|
response = await self.client.post(
|
|
f"{self.endpoint}/v1/embeddings",
|
|
json={
|
|
"model": model_name,
|
|
"input": text,
|
|
},
|
|
)
|
|
_ = response.raise_for_status()
|
|
|
|
response_json = response.json()
|
|
if not isinstance(response_json, dict):
|
|
raise VectorizationError("Invalid JSON response format")
|
|
|
|
# Extract embedding using type-safe helper
|
|
embedding = _extract_embedding_from_response(response_json)
|
|
|
|
# Ensure correct dimension
|
|
if len(embedding) != self.dimension:
|
|
raise VectorizationError(
|
|
f"Embedding dimension mismatch: expected {self.dimension}, received {len(embedding)}"
|
|
)
|
|
|
|
return embedding
|
|
|
|
async def _openai_embed(self, text: str) -> list[float]:
|
|
"""
|
|
Generate embedding using OpenAI-compatible API.
|
|
|
|
Args:
|
|
text: Text to embed
|
|
|
|
Returns:
|
|
Embedding vector
|
|
"""
|
|
response = await self.client.post(
|
|
f"{self.endpoint}/v1/embeddings",
|
|
json={
|
|
"model": self.model,
|
|
"input": text,
|
|
},
|
|
)
|
|
_ = response.raise_for_status()
|
|
|
|
response_json = response.json()
|
|
if not isinstance(response_json, dict):
|
|
raise VectorizationError("Invalid JSON response format")
|
|
|
|
# Extract embedding using type-safe helper
|
|
embedding = _extract_embedding_from_response(response_json)
|
|
|
|
# Ensure correct dimension
|
|
if len(embedding) != self.dimension:
|
|
raise VectorizationError(
|
|
f"Embedding dimension mismatch: expected {self.dimension}, received {len(embedding)}"
|
|
)
|
|
|
|
return embedding
|
|
|
|
async def __aenter__(self) -> Self:
|
|
"""Async context manager entry."""
|
|
return self
|
|
|
|
async def close(self) -> None:
|
|
"""Close the HTTP client connection."""
|
|
try:
|
|
await self.client.aclose()
|
|
except Exception:
|
|
# Already closed or connection lost
|
|
pass
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> None:
|
|
"""Async context manager exit."""
|
|
await self.close()
|
|
</file>
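
A usage sketch for Vectorizer above (illustration only, not a file from the repository): `vector_config` stands in for a VectorConfig or StorageConfig instance whose construction is not shown.

from ingest_pipeline.utils.vectorizer import Vectorizer

async def embed_texts(vector_config, texts: list[str]) -> list[list[float]]:
    async with Vectorizer(vector_config) as vectorizer:
        return await vectorizer.vectorize_batch(texts)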
|
|
|
|
<file path="ingest_pipeline/cli/tui/screens/dashboard.py">
|
|
"""Main dashboard screen with collections overview."""
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING, Final
|
|
|
|
from textual import work
|
|
from textual.app import ComposeResult
|
|
from textual.binding import Binding
|
|
from textual.containers import Container, Grid, Horizontal
|
|
from textual.css.query import NoMatches
|
|
from textual.reactive import reactive, var
|
|
from textual.screen import Screen
|
|
from textual.widgets import (
|
|
Button,
|
|
Footer,
|
|
Header,
|
|
LoadingIndicator,
|
|
Rule,
|
|
Static,
|
|
TabbedContent,
|
|
TabPane,
|
|
)
|
|
from typing_extensions import override
|
|
|
|
from ....core.models import StorageBackend
|
|
from ....storage.base import BaseStorage
|
|
from ....storage.openwebui import OpenWebUIStorage
|
|
from ....storage.weaviate import WeaviateStorage
|
|
from ..models import CollectionInfo
|
|
from ..utils.storage_manager import StorageManager
|
|
from ..widgets import EnhancedDataTable, MetricsCard, StatusIndicator
|
|
|
|
if TYPE_CHECKING:
|
|
from ....storage.r2r.storage import R2RStorage
|
|
else: # pragma: no cover - optional dependency fallback
|
|
R2RStorage = BaseStorage
|
|
|
|
|
|
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
|
|
|
|
|
class CollectionOverviewScreen(Screen[None]):
|
|
"""Enhanced dashboard with modern design and metrics."""
|
|
|
|
total_documents: int = 0
|
|
total_collections: int = 0
|
|
active_backends: int = 0
|
|
|
|
BINDINGS = [
|
|
Binding("q", "quit", "Quit"),
|
|
Binding("r", "refresh", "Refresh"),
|
|
Binding("i", "ingest", "Ingest"),
|
|
Binding("m", "manage", "Manage"),
|
|
Binding("s", "search", "Search"),
|
|
Binding("ctrl+d", "delete", "Delete"),
|
|
Binding("ctrl+1", "tab_dashboard", "Dashboard"),
|
|
Binding("ctrl+2", "tab_collections", "Collections"),
|
|
Binding("ctrl+3", "tab_analytics", "Analytics"),
|
|
Binding("tab", "next_tab", "Next Tab"),
|
|
Binding("shift+tab", "prev_tab", "Prev Tab"),
|
|
Binding("f1", "help", "Help"),
|
|
]
|
|
|
|
collections: var[list[CollectionInfo]] = var([])
|
|
is_loading: var[bool] = var(False)
|
|
selected_collection: reactive[CollectionInfo | None] = reactive(None)
|
|
storage_manager: StorageManager
|
|
weaviate: WeaviateStorage | None
|
|
openwebui: OpenWebUIStorage | None
|
|
r2r: R2RStorage | BaseStorage | None
|
|
|
|
def __init__(
|
|
self,
|
|
storage_manager: StorageManager,
|
|
weaviate: WeaviateStorage | None,
|
|
openwebui: OpenWebUIStorage | None,
|
|
r2r: R2RStorage | BaseStorage | None,
|
|
) -> None:
|
|
super().__init__()
|
|
self.storage_manager = storage_manager
|
|
self.weaviate = weaviate
|
|
self.openwebui = openwebui
|
|
self.r2r = r2r
|
|
self.total_documents = 0
|
|
self.total_collections = 0
|
|
self.active_backends = 0
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
yield Header(show_clock=True)
|
|
|
|
with TabbedContent():
|
|
# Dashboard Tab
|
|
with TabPane("Dashboard", id="dashboard"):
|
|
yield Container(
|
|
Static("🚀 Collection Management System", classes="title"),
|
|
Static("Modern document ingestion and management platform", classes="subtitle"),
|
|
Rule(line_style="heavy"),
|
|
# Metrics Grid
|
|
Container(
|
|
Grid(
|
|
MetricsCard(
|
|
"Collections", str(self.total_collections), "Active collections"
|
|
),
|
|
MetricsCard("Documents", str(self.total_documents), "Total indexed"),
|
|
MetricsCard(
|
|
"Backends", str(self.active_backends), "Connected services"
|
|
),
|
|
MetricsCard("Status", "Online", "System health"),
|
|
classes="responsive-grid metrics-grid",
|
|
),
|
|
classes="center",
|
|
),
|
|
Rule(line_style="dashed"),
|
|
# Quick Actions
|
|
Container(
|
|
Static("⚡ Quick Actions", classes="section-title"),
|
|
Horizontal(
|
|
Button("🔄 Refresh Data", id="quick_refresh", variant="primary"),
|
|
Button("📥 New Ingestion", id="quick_ingest", variant="success"),
|
|
Button("🔍 Search All", id="quick_search", variant="default"),
|
|
Button("⚙️ Settings", id="quick_settings", variant="default"),
|
|
classes="action_buttons",
|
|
),
|
|
classes="card",
|
|
),
|
|
# Recent Activity
|
|
Container(
|
|
Static("📊 Recent Activity", classes="section-title"),
|
|
Static(
|
|
"Loading recent activity...", id="activity_feed", classes="status-text"
|
|
),
|
|
classes="card",
|
|
),
|
|
classes="main_container",
|
|
)
|
|
|
|
# Collections Tab
|
|
with TabPane("Collections", id="collections"):
|
|
yield Container(
|
|
Static("📚 Collection Overview", classes="title"),
|
|
# Collection controls
|
|
Horizontal(
|
|
Button("🔄 Refresh", id="refresh_btn", variant="primary"),
|
|
Button("📥 Ingest", id="ingest_btn", variant="success"),
|
|
Button("🔧 Manage", id="manage_btn", variant="warning"),
|
|
Button("🗑️ Delete", id="delete_btn", variant="error"),
|
|
Button("🔍 Search", id="search_btn", variant="default"),
|
|
classes="button_bar",
|
|
),
|
|
# Collection table with enhanced navigation
|
|
EnhancedDataTable(id="collections_table", classes="enhanced-table"),
|
|
# Status bar
|
|
Container(
|
|
Static("Ready", id="status_text", classes="status-text"),
|
|
StatusIndicator("Ready", id="connection_status"),
|
|
classes="status-bar",
|
|
),
|
|
LoadingIndicator(id="loading", classes="pulse"),
|
|
classes="main_container",
|
|
)
|
|
|
|
# Analytics Tab
|
|
with TabPane("Analytics", id="analytics"):
|
|
yield Container(
|
|
Static("📈 Analytics & Insights", classes="title"),
|
|
# Analytics content
|
|
Container(
|
|
Static("🚧 Analytics Dashboard", classes="section-title"),
|
|
Static("Advanced analytics and insights coming soon!", classes="subtitle"),
|
|
# Placeholder charts area
|
|
Container(
|
|
Static("📊 Document Distribution", classes="chart-title"),
|
|
Static(
|
|
"Chart placeholder - integrate with visualization library",
|
|
classes="chart-placeholder",
|
|
),
|
|
classes="card",
|
|
),
|
|
Container(
|
|
Static("⏱️ Ingestion Timeline", classes="chart-title"),
|
|
Static("Timeline chart placeholder", classes="chart-placeholder"),
|
|
classes="card",
|
|
),
|
|
classes="analytics-grid",
|
|
),
|
|
classes="main_container",
|
|
)
|
|
|
|
yield Footer()
|
|
|
|
async def on_mount(self) -> None:
|
|
"""Initialize the screen with enhanced loading."""
|
|
self.query_one("#loading").display = False
|
|
self.update_metrics()
|
|
self.refresh_collections() # Don't await, let it run as a worker
|
|
|
|
def update_metrics(self) -> None:
|
|
"""Update dashboard metrics with enhanced calculations."""
|
|
self._calculate_metrics()
|
|
self._update_metrics_cards()
|
|
self._update_activity_feed()
|
|
|
|
def _calculate_metrics(self) -> None:
|
|
"""Calculate basic metrics from collections."""
|
|
self.total_collections = len(self.collections)
|
|
self.total_documents = sum(col["count"] for col in self.collections)
|
|
# Calculate active backends from storage manager if individual storages are None
|
|
if self.weaviate is None and self.openwebui is None and self.r2r is None:
|
|
self.active_backends = len(self.storage_manager.get_available_backends())
|
|
else:
|
|
self.active_backends = sum([bool(self.weaviate), bool(self.openwebui), bool(self.r2r)])
|
|
|
|
def _update_metrics_cards(self) -> None:
|
|
"""Update the metrics cards display."""
|
|
try:
|
|
dashboard_tab = self.query_one("#dashboard")
|
|
metrics_cards_query = dashboard_tab.query(MetricsCard)
|
|
if len(metrics_cards_query) >= 4:
|
|
metrics_cards = list(metrics_cards_query)
|
|
self._update_card_values(metrics_cards)
|
|
self._update_status_card(metrics_cards[3])
|
|
except NoMatches:
|
|
return
|
|
except Exception as exc:
|
|
LOGGER.exception("Failed to update dashboard metrics", exc_info=exc)
|
|
|
|
def _update_card_values(self, metrics_cards: list[MetricsCard]) -> None:
|
|
"""Update individual metric card values."""
|
|
metrics_cards[0].query_one(".metrics-value", Static).update(f"{self.total_collections:,}")
|
|
metrics_cards[1].query_one(".metrics-value", Static).update(f"{self.total_documents:,}")
|
|
metrics_cards[2].query_one(".metrics-value", Static).update(str(self.active_backends))
|
|
|
|
def _update_status_card(self, status_card: MetricsCard) -> None:
|
|
"""Update the system status card."""
|
|
if self.active_backends > 0 and self.total_collections > 0:
|
|
status_text, status_class = "🟢 Healthy", "status-active"
|
|
elif self.active_backends > 0:
|
|
status_text, status_class = "🟡 Ready", "status-warning"
|
|
else:
|
|
status_text, status_class = "🔴 Offline", "status-error"
|
|
|
|
status_card.query_one(".metrics-value", Static).update(status_text)
|
|
status_card.add_class(status_class)
|
|
|
|
def _update_activity_feed(self) -> None:
|
|
"""Update the activity feed with collection data."""
|
|
try:
|
|
dashboard_tab = self.query_one("#dashboard")
|
|
activity_feed = dashboard_tab.query_one("#activity_feed", Static)
|
|
activity_text = self._generate_activity_text()
|
|
activity_feed.update(activity_text)
|
|
except NoMatches:
|
|
return
|
|
except Exception as exc:
|
|
LOGGER.exception("Failed to update dashboard activity feed", exc_info=exc)
|
|
|
|
def _generate_activity_text(self) -> str:
|
|
"""Generate activity feed text from collections."""
|
|
if not self.collections:
|
|
return "🚀 No collections found. Start by creating your first ingestion!\n💡 Press 'I' to begin or use the Quick Actions above."
|
|
|
|
recent_activity = [self._format_collection_item(col) for col in self.collections[:3]]
|
|
activity_text = "\n".join(recent_activity)
|
|
|
|
if len(self.collections) > 3:
|
|
total_docs = sum(c["count"] for c in self.collections)
|
|
activity_text += (
|
|
f"\n📊 Total: {len(self.collections)} collections with {total_docs:,} documents"
|
|
)
|
|
|
|
return activity_text
|
|
|
|
def _format_collection_item(self, col: CollectionInfo) -> str:
|
|
"""Format a single collection item for the activity feed."""
|
|
content_type = self._get_content_type_icon(col["name"])
|
|
size_mb = col["size_mb"]
|
|
backend_info = col["backend"]
|
|
|
|
# Check if this represents a multi-backend ingestion result
|
|
if isinstance(backend_info, list):
|
|
if len(backend_info) > 1:
|
|
# Ensure all elements are strings for safe joining
|
|
backend_strings = [str(b) for b in backend_info if b is not None]
|
|
backend_list = " + ".join(backend_strings) if backend_strings else "unknown"
|
|
return f"{content_type} {col['name']}: {col['count']:,} docs ({size_mb:.1f} MB) → {backend_list}"
|
|
elif len(backend_info) == 1:
|
|
backend_name = str(backend_info[0]) if backend_info[0] is not None else "unknown"
|
|
return f"{content_type} {col['name']}: {col['count']:,} docs ({size_mb:.1f} MB) - {backend_name}"
|
|
else:
|
|
return f"{content_type} {col['name']}: {col['count']:,} docs ({size_mb:.1f} MB) - unknown"
|
|
else:
|
|
backend_display = str(backend_info) if backend_info is not None else "unknown"
|
|
return f"{content_type} {col['name']}: {col['count']:,} docs ({size_mb:.1f} MB) - {backend_display}"
|
|
|
|
def _get_content_type_icon(self, name: str) -> str:
|
|
"""Get appropriate icon for collection content type."""
|
|
name_lower = name.lower()
|
|
if "web" in name_lower:
|
|
return "🌐"
|
|
elif "doc" in name_lower:
|
|
return "📖"
|
|
elif "repo" in name_lower:
|
|
return "📦"
|
|
return "📄"
|
|
|
|
@work(exclusive=True)
|
|
async def refresh_collections(self) -> None:
|
|
"""Refresh collection data with enhanced multi-backend loading feedback."""
|
|
self.is_loading = True
|
|
loading_indicator = self.query_one("#loading")
|
|
status_text = self.query_one("#status_text", Static)
|
|
|
|
loading_indicator.display = True
|
|
status_text.update("🔄 Refreshing collections...")
|
|
|
|
try:
|
|
# Use storage manager for unified backend handling
|
|
if not self.storage_manager.is_initialized:
|
|
status_text.update("🔗 Initializing storage backends...")
|
|
backend_results = await self.storage_manager.initialize_all_backends()
|
|
|
|
# Report per-backend initialization status
|
|
success_count = sum(backend_results.values())
|
|
total_count = len(backend_results)
|
|
status_text.update(f"✅ Initialized {success_count}/{total_count} backends")
|
|
|
|
# Get collections from all backends via storage manager
|
|
status_text.update("📚 Loading collections from all backends...")
|
|
collections = await self.storage_manager.get_all_collections()
|
|
|
|
# Update metrics calculation for multi-backend support
|
|
self.active_backends = len(self.storage_manager.get_available_backends())
|
|
|
|
self.collections = collections
|
|
await self.update_collections_table()
|
|
self.update_metrics()
|
|
|
|
# Enhanced status reporting for multi-backend
|
|
backend_names = ", ".join(
|
|
backend.value for backend in self.storage_manager.get_available_backends()
|
|
)
|
|
status_text.update(f"✨ Ready - {len(collections)} collections from {backend_names}")
|
|
|
|
# Update connection status with multi-backend awareness
|
|
connection_status = self.query_one("#connection_status", StatusIndicator)
|
|
if collections and self.active_backends > 0:
|
|
connection_status.update_status(f"✓ {self.active_backends} Active")
|
|
else:
|
|
connection_status.update_status("No Data")
|
|
|
|
except Exception as e:
|
|
status_text.update(f"❌ Error: {e}")
|
|
self.notify(f"Failed to refresh: {e}", severity="error", markup=False)
|
|
finally:
|
|
self.is_loading = False
|
|
loading_indicator.display = False
|
|
|
|
async def update_collections_table(self) -> None:
|
|
"""Update the collections table with enhanced formatting."""
|
|
table = self.query_one("#collections_table", EnhancedDataTable)
|
|
table.clear(columns=True)
|
|
|
|
# Add enhanced columns with more metadata
|
|
table.add_columns("Collection", "Backend", "Documents", "Size", "Type", "Status", "Updated")
|
|
|
|
# Add rows with enhanced formatting
|
|
for collection in self.collections:
|
|
# Format size
|
|
size_str = f"{collection['size_mb']:.1f} MB"
|
|
if collection["size_mb"] > 1000:
|
|
size_str = f"{collection['size_mb'] / 1000:.1f} GB"
|
|
|
|
# Format document count
|
|
doc_count = f"{collection['count']:,}"
|
|
|
|
# Determine content type based on collection name or other metadata
|
|
content_type = "📄 Mixed"
|
|
if "web" in collection["name"].lower():
|
|
content_type = "🌐 Web"
|
|
elif "doc" in collection["name"].lower():
|
|
content_type = "📖 Docs"
|
|
elif "repo" in collection["name"].lower():
|
|
content_type = "📦 Code"
|
|
|
|
table.add_row(
|
|
collection["name"],
|
|
collection["backend"],
|
|
doc_count,
|
|
size_str,
|
|
content_type,
|
|
collection["status"],
|
|
collection["last_updated"],
|
|
)
|
|
|
|
if self.collections:
|
|
table.move_cursor(row=0)
|
|
|
|
self.get_selected_collection()
|
|
|
|
def update_search_controls(self, collection: CollectionInfo | None) -> None:
|
|
"""Enable or disable search controls based on backend support."""
|
|
try:
|
|
search_button = self.query_one("#search_btn", Button)
|
|
quick_search_button = self.query_one("#quick_search", Button)
|
|
except Exception:
|
|
return
|
|
|
|
is_weaviate = bool(collection and collection.get("type") == "weaviate")
|
|
search_button.disabled = not is_weaviate
|
|
quick_search_button.disabled = not is_weaviate
|
|
|
|
def get_selected_collection(self) -> CollectionInfo | None:
|
|
"""Get the currently selected collection."""
|
|
table = self.query_one("#collections_table", EnhancedDataTable)
|
|
try:
|
|
row_index = table.cursor_coordinate.row
|
|
except (AttributeError, IndexError):
|
|
self.selected_collection = None
|
|
self.update_search_controls(None)
|
|
return None
|
|
|
|
if 0 <= row_index < len(self.collections):
|
|
collection = self.collections[row_index]
|
|
self.selected_collection = collection
|
|
self.update_search_controls(collection)
|
|
return collection
|
|
|
|
self.selected_collection = None
|
|
self.update_search_controls(None)
|
|
return None
|
|
|
|
# Action methods
|
|
def action_refresh(self) -> None:
|
|
"""Refresh collections."""
|
|
self.refresh_collections()
|
|
|
|
def action_ingest(self) -> None:
|
|
"""Show enhanced ingestion dialog."""
|
|
if selected := self.get_selected_collection():
|
|
from .ingestion import IngestionScreen
|
|
|
|
self.app.push_screen(IngestionScreen(selected, self.storage_manager))
|
|
else:
|
|
self.notify("🔍 Please select a collection first", severity="warning")
|
|
|
|
def action_manage(self) -> None:
|
|
"""Manage documents in selected collection."""
|
|
if selected := self.get_selected_collection():
|
|
if storage_backend := self._get_storage_for_collection(selected):
|
|
from .documents import DocumentManagementScreen
|
|
|
|
self.app.push_screen(DocumentManagementScreen(selected, storage_backend))
|
|
else:
|
|
self.notify(
|
|
"🚧 No storage backend available for this collection", severity="warning"
|
|
)
|
|
else:
|
|
self.notify("🔍 Please select a collection first", severity="warning")
|
|
|
|
def _get_storage_for_collection(self, collection: CollectionInfo) -> BaseStorage | None:
|
|
"""Get the appropriate storage backend for a collection."""
|
|
collection_type = collection.get("type", "")
|
|
|
|
# Map collection types to storage backends (try direct instances first)
|
|
if collection_type == "weaviate" and self.weaviate:
|
|
return self.weaviate
|
|
elif collection_type == "openwebui" and self.openwebui:
|
|
return self.openwebui
|
|
elif collection_type == "r2r" and self.r2r:
|
|
return self.r2r
|
|
|
|
# Fall back to storage manager if direct instances not available
|
|
if collection_type == "weaviate":
|
|
return self.storage_manager.get_backend(StorageBackend.WEAVIATE)
|
|
elif collection_type == "openwebui":
|
|
return self.storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
|
|
elif collection_type == "r2r":
|
|
return self.storage_manager.get_backend(StorageBackend.R2R)
|
|
|
|
# Fall back to checking available backends by backend name
|
|
backend_name = collection.get("backend", "")
|
|
if isinstance(backend_name, str):
|
|
if "weaviate" in backend_name.lower():
|
|
return self.weaviate or self.storage_manager.get_backend(StorageBackend.WEAVIATE)
|
|
elif "openwebui" in backend_name.lower():
|
|
return self.openwebui or self.storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
|
|
elif "r2r" in backend_name.lower():
|
|
return self.r2r or self.storage_manager.get_backend(StorageBackend.R2R)
|
|
|
|
return None
|
|
|
|
def action_search(self) -> None:
|
|
"""Search in selected collection."""
|
|
if selected := self.get_selected_collection():
|
|
if selected["type"] != "weaviate":
|
|
self.notify(
|
|
"🔐 Search is currently available only for Weaviate collections",
|
|
severity="warning",
|
|
)
|
|
return
|
|
from .search import SearchScreen
|
|
|
|
self.app.push_screen(SearchScreen(selected, self.weaviate, self.openwebui))
|
|
else:
|
|
self.notify("🔍 Please select a collection first", severity="warning")
|
|
|
|
def action_delete(self) -> None:
|
|
"""Delete selected collection."""
|
|
if selected := self.get_selected_collection():
|
|
from .dialogs import ConfirmDeleteScreen
|
|
|
|
self.app.push_screen(ConfirmDeleteScreen(selected, self))
|
|
else:
|
|
self.notify("🔍 Please select a collection first", severity="warning")
|
|
|
|
def action_tab_dashboard(self) -> None:
|
|
"""Switch to dashboard tab."""
|
|
tabbed_content: TabbedContent = self.query_one(TabbedContent)
|
|
tabbed_content.active = "dashboard"
|
|
|
|
def action_tab_collections(self) -> None:
|
|
"""Switch to collections tab."""
|
|
tabbed_content: TabbedContent = self.query_one(TabbedContent)
|
|
tabbed_content.active = "collections"
|
|
|
|
def action_tab_analytics(self) -> None:
|
|
"""Switch to analytics tab."""
|
|
tabbed_content: TabbedContent = self.query_one(TabbedContent)
|
|
tabbed_content.active = "analytics"
|
|
|
|
def action_next_tab(self) -> None:
|
|
"""Switch to next tab."""
|
|
tabbed_content: TabbedContent = self.query_one(TabbedContent)
|
|
tab_ids = ["dashboard", "collections", "analytics"]
|
|
current = tabbed_content.active
|
|
try:
|
|
current_index = tab_ids.index(current)
|
|
next_index = (current_index + 1) % len(tab_ids)
|
|
tabbed_content.active = tab_ids[next_index]
|
|
except (ValueError, AttributeError):
|
|
tabbed_content.active = tab_ids[0]
|
|
|
|
def action_prev_tab(self) -> None:
|
|
"""Switch to previous tab."""
|
|
tabbed_content: TabbedContent = self.query_one(TabbedContent)
|
|
tab_ids = ["dashboard", "collections", "analytics"]
|
|
current = tabbed_content.active
|
|
try:
|
|
current_index = tab_ids.index(current)
|
|
prev_index = (current_index - 1) % len(tab_ids)
|
|
tabbed_content.active = tab_ids[prev_index]
|
|
except (ValueError, AttributeError):
|
|
tabbed_content.active = tab_ids[0]
|
|
|
|
def action_help(self) -> None:
|
|
"""Show help screen."""
|
|
from .help import HelpScreen
|
|
|
|
help_md = """
|
|
# 🚀 Modern Collection Management System
|
|
|
|
## Navigation
|
|
- **Tab** / **Shift+Tab**: Switch between tabs
|
|
- **Ctrl+1/2/3**: Direct tab access
|
|
- **Enter**: Activate selected item
|
|
- **Escape**: Go back/cancel
|
|
- **Arrow Keys**: Navigate within tables
|
|
- **Home/End**: Jump to first/last row
|
|
- **Page Up/Down**: Scroll by page
|
|
|
|
## Collections
|
|
- **R**: Refresh collections
|
|
- **I**: Start ingestion
|
|
- **M**: Manage documents
|
|
- **S**: Search collection
|
|
- **Ctrl+D**: Delete collection
|
|
|
|
## Table Navigation
|
|
- **Up/Down** or **J/K**: Navigate rows
|
|
- **Space**: Toggle selection
|
|
- **Ctrl+A**: Select all
|
|
- **Ctrl+Shift+A**: Clear selection
|
|
|
|
## General
|
|
- **Q** / **Ctrl+C**: Quit application
|
|
- **F1**: Show this help
|
|
|
|
Enjoy the enhanced interface! 🎉
|
|
"""
|
|
self.app.push_screen(HelpScreen(help_md))
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses with enhanced feedback."""
|
|
button_id = event.button.id
|
|
|
|
# Add visual feedback
|
|
event.button.add_class("pressed")
|
|
self.call_later(self.remove_pressed_class, event.button)
|
|
|
|
if getattr(event.button, "disabled", False):
|
|
self.notify(
|
|
"🔐 Search is currently limited to Weaviate collections",
|
|
severity="warning",
|
|
)
|
|
return
|
|
|
|
if button_id in ["refresh_btn", "quick_refresh"]:
|
|
self.action_refresh()
|
|
elif button_id in ["ingest_btn", "quick_ingest"]:
|
|
self.action_ingest()
|
|
elif button_id == "manage_btn":
|
|
self.action_manage()
|
|
elif button_id == "delete_btn":
|
|
self.action_delete()
|
|
elif button_id in ["search_btn", "quick_search"]:
|
|
self.action_search()
|
|
elif button_id == "quick_settings":
|
|
self.notify("⚙️ Settings panel coming soon!", severity="information")
|
|
|
|
def remove_pressed_class(self, button: Button) -> None:
|
|
"""Remove pressed visual feedback class."""
|
|
button.remove_class("pressed")
|
|
</file>
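
A sketch of how CollectionOverviewScreen above can be pushed onto the app (illustration only, not a file from the repository): the StorageManager and the optional storage instances are assumed to be created by the surrounding Textual app.

from ingest_pipeline.cli.tui.screens.dashboard import CollectionOverviewScreen

def show_dashboard(app, storage_manager, weaviate=None, openwebui=None, r2r=None):
    # Argument order matches CollectionOverviewScreen.__init__ above.
    app.push_screen(CollectionOverviewScreen(storage_manager, weaviate, openwebui, r2r))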
|
|
|
|
<file path="ingest_pipeline/cli/tui/screens/dialogs.py">
|
|
"""Dialog screens for confirmations and user interactions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
from textual.app import ComposeResult
|
|
from textual.binding import Binding
|
|
from textual.containers import Container, Horizontal
|
|
from textual.screen import ModalScreen, Screen
|
|
from textual.widgets import Button, Footer, Header, LoadingIndicator, RichLog, Static
|
|
from typing_extensions import override
|
|
|
|
from ..models import CollectionInfo
|
|
|
|
if TYPE_CHECKING:
|
|
from ..app import CollectionManagementApp
|
|
from .dashboard import CollectionOverviewScreen
|
|
from .documents import DocumentManagementScreen
|
|
|
|
|
|
class ConfirmDeleteScreen(Screen[None]):
|
|
"""Screen for confirming collection deletion."""
|
|
|
|
collection: CollectionInfo
|
|
parent_screen: CollectionOverviewScreen
|
|
|
|
@property
|
|
def app(self) -> CollectionManagementApp: # type: ignore[override]
|
|
"""Return the typed app instance."""
|
|
return super().app # type: ignore[return-value]
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "app.pop_screen", "Cancel"),
|
|
Binding("y", "confirm_delete", "Yes"),
|
|
Binding("n", "app.pop_screen", "No"),
|
|
Binding("enter", "confirm_delete", "Confirm"),
|
|
]
|
|
|
|
def __init__(self, collection: CollectionInfo, parent_screen: CollectionOverviewScreen):
|
|
super().__init__()
|
|
self.collection = collection
|
|
self.parent_screen = parent_screen
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
yield Header()
|
|
yield Container(
|
|
Static("⚠️ Confirm Deletion", classes="title warning"),
|
|
Static(f"Are you sure you want to delete collection '{self.collection['name']}'?"),
|
|
Static(f"Backend: {self.collection['backend']}"),
|
|
Static(f"Documents: {self.collection['count']:,}"),
|
|
Static("This action cannot be undone!", classes="warning"),
|
|
Static("Press Y to confirm, N or Escape to cancel", classes="subtitle"),
|
|
Horizontal(
|
|
Button("✅ Yes, Delete (Y)", id="yes_btn", variant="error"),
|
|
Button("❌ Cancel (N)", id="no_btn", variant="default"),
|
|
classes="action_buttons",
|
|
),
|
|
classes="main_container center",
|
|
)
|
|
yield Footer()
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize the screen with focus on cancel button for safety."""
|
|
self.query_one("#no_btn").focus()
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id == "yes_btn":
|
|
self.action_confirm_delete()
|
|
elif event.button.id == "no_btn":
|
|
self.app.pop_screen()
|
|
|
|
def action_confirm_delete(self) -> None:
|
|
"""Confirm deletion."""
|
|
self.run_worker(self.delete_collection())
|
|
|
|
async def delete_collection(self) -> None:
|
|
"""Delete the collection."""
|
|
try:
|
|
if self.collection["type"] == "weaviate" and self.parent_screen.weaviate:
|
|
# Delete Weaviate collection
|
|
if (
|
|
self.parent_screen.weaviate.client
|
|
and self.parent_screen.weaviate.client.collections
|
|
):
|
|
self.parent_screen.weaviate.client.collections.delete(self.collection["name"])
|
|
self.notify(
|
|
f"Deleted Weaviate collection: {self.collection['name']}",
|
|
severity="information",
|
|
)
|
|
else:
|
|
# Use the dashboard's method to get the appropriate storage backend
|
|
storage_backend = self.parent_screen._get_storage_for_collection(self.collection)
|
|
if not storage_backend:
|
|
self.notify(
|
|
f"❌ No storage backend available for {self.collection['type']} collection: {self.collection['name']}",
|
|
severity="error",
|
|
)
|
|
self.app.pop_screen()
|
|
return
|
|
|
|
# Check if the storage backend supports collection deletion
|
|
if not hasattr(storage_backend, "delete_collection"):
|
|
self.notify(
|
|
f"❌ Collection deletion not supported for {self.collection['type']} backend",
|
|
severity="error",
|
|
)
|
|
self.app.pop_screen()
|
|
return
|
|
|
|
# Delete the collection using the appropriate backend
|
|
# Ensure we use the exact collection name, not any default from storage config
|
|
collection_name = str(self.collection["name"])
|
|
collection_type = str(self.collection["type"])
|
|
|
|
self.notify(
|
|
f"Deleting {collection_type} collection: {collection_name}...",
|
|
severity="information",
|
|
)
|
|
|
|
# Use the standard delete_collection method for all backends
|
|
if hasattr(storage_backend, "delete_collection"):
|
|
success = await storage_backend.delete_collection(collection_name)
|
|
else:
|
|
self.notify("❌ Backend does not support collection deletion", severity="error")
|
|
self.app.pop_screen()
|
|
return
|
|
if success:
|
|
self.notify(
|
|
f"✅ Successfully deleted {self.collection['type']} collection: {self.collection['name']}",
|
|
severity="information",
|
|
timeout=3.0,
|
|
)
|
|
else:
|
|
self.notify(
|
|
f"❌ Failed to delete {self.collection['type']} collection: {self.collection['name']}",
|
|
severity="error",
|
|
)
|
|
# Don't refresh if deletion failed
|
|
self.app.pop_screen()
|
|
return
|
|
|
|
            # Refresh parent screen after a short delay so the deletion has been processed
            self.set_timer(0.5, self._refresh_parent_collections)  # 500ms delay
            self.app.pop_screen()
|
|
|
|
except Exception as e:
|
|
self.notify(f"Failed to delete collection: {e}", severity="error", markup=False)
|
|
|
|
def _refresh_parent_collections(self) -> None:
|
|
"""Helper method to refresh parent collections."""
|
|
self.parent_screen.refresh_collections()
|
|
|
|
|
|
class ConfirmDocumentDeleteScreen(Screen[None]):
|
|
"""Screen for confirming document deletion."""
|
|
|
|
doc_ids: list[str]
|
|
collection: CollectionInfo
|
|
parent_screen: DocumentManagementScreen
|
|
|
|
@property
|
|
def app(self) -> CollectionManagementApp: # type: ignore[override]
|
|
"""Return the typed app instance."""
|
|
return super().app # type: ignore[return-value]
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "app.pop_screen", "Cancel"),
|
|
Binding("y", "confirm_delete", "Yes"),
|
|
Binding("n", "app.pop_screen", "No"),
|
|
Binding("enter", "confirm_delete", "Confirm"),
|
|
]
|
|
|
|
def __init__(
|
|
self,
|
|
doc_ids: list[str],
|
|
collection: CollectionInfo,
|
|
parent_screen: DocumentManagementScreen,
|
|
):
|
|
super().__init__()
|
|
self.doc_ids = doc_ids
|
|
self.collection = collection
|
|
self.parent_screen = parent_screen
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
yield Header()
|
|
yield Container(
|
|
Static("⚠️ Confirm Document Deletion", classes="title warning"),
|
|
Static(
|
|
f"Are you sure you want to delete {len(self.doc_ids)} documents from '{self.collection['name']}'?"
|
|
),
|
|
Static("This action cannot be undone!", classes="warning"),
|
|
Static("Press Y to confirm, N or Escape to cancel", classes="subtitle"),
|
|
Horizontal(
|
|
Button("✅ Yes, Delete (Y)", id="yes_btn", variant="error"),
|
|
Button("❌ Cancel (N)", id="no_btn", variant="default"),
|
|
classes="action_buttons",
|
|
),
|
|
LoadingIndicator(id="loading"),
|
|
classes="main_container center",
|
|
)
|
|
yield Footer()
|
|
|
|
def on_mount(self) -> None:
|
|
"""Initialize the screen with focus on cancel button for safety."""
|
|
self.query_one("#loading").display = False
|
|
self.query_one("#no_btn").focus()
|
|
|
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
"""Handle button presses."""
|
|
if event.button.id == "yes_btn":
|
|
self.action_confirm_delete()
|
|
elif event.button.id == "no_btn":
|
|
self.app.pop_screen()
|
|
|
|
def action_confirm_delete(self) -> None:
|
|
"""Confirm deletion."""
|
|
self.run_worker(self.delete_documents())
|
|
|
|
async def delete_documents(self) -> None:
|
|
"""Delete the selected documents."""
|
|
loading = self.query_one("#loading")
|
|
loading.display = True
|
|
|
|
try:
|
|
results: dict[str, bool] = {}
|
|
if hasattr(self.parent_screen, "storage") and self.parent_screen.storage:
|
|
# Delete documents via storage
|
|
# The storage should have delete_documents method for weaviate
|
|
storage = self.parent_screen.storage
|
|
if hasattr(storage, "delete_documents"):
|
|
results = await storage.delete_documents(
|
|
self.doc_ids,
|
|
collection_name=self.collection["name"],
|
|
)
|
|
|
|
# Count successful deletions
|
|
successful = sum(bool(success) for success in results.values())
|
|
failed = len(results) - successful
|
|
|
|
if successful > 0:
|
|
self.notify(f"Deleted {successful} documents", severity="information")
|
|
if failed > 0:
|
|
self.notify(f"Failed to delete {failed} documents", severity="error")
|
|
|
|
# Clear selection and refresh parent screen
|
|
self.parent_screen.selected_docs.clear()
|
|
await self.parent_screen.load_documents()
|
|
self.app.pop_screen()
|
|
|
|
except Exception as e:
|
|
self.notify(f"Failed to delete documents: {e}", severity="error", markup=False)
|
|
finally:
|
|
loading.display = False
|
|
|
|
|
|
class LogViewerScreen(ModalScreen[None]):
|
|
"""Display live log output without disrupting the TUI."""
|
|
|
|
_log_widget: RichLog | None
|
|
_log_file: Path | None
|
|
|
|
@property
|
|
def app(self) -> CollectionManagementApp: # type: ignore[override]
|
|
"""Return the typed app instance."""
|
|
return super().app # type: ignore[return-value]
|
|
|
|
BINDINGS = [
|
|
Binding("escape", "close", "Close"),
|
|
Binding("ctrl+l", "close", "Close"),
|
|
Binding("s", "show_path", "Log File"),
|
|
]
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self._log_widget = None
|
|
self._log_file = None
|
|
|
|
@override
|
|
def compose(self) -> ComposeResult:
|
|
yield Header(show_clock=True)
|
|
yield Container(
|
|
Static("📜 Live Application Logs", classes="title"),
|
|
Static(
|
|
"Logs update in real time. Press S to reveal the log file path.", classes="subtitle"
|
|
),
|
|
RichLog(id="log_stream", classes="log-stream", wrap=True, highlight=False),
|
|
Static("", id="log_file_path", classes="subtitle"),
|
|
classes="main_container log-viewer-container",
|
|
)
|
|
yield Footer()
|
|
|
|
def on_mount(self) -> None:
|
|
"""Attach this viewer to the parent application once mounted."""
|
|
self._log_widget = self.query_one(RichLog)
|
|
|
|
if hasattr(self.app, "attach_log_viewer"):
|
|
self.app.attach_log_viewer(self) # type: ignore[arg-type]
|
|
|
|
def on_unmount(self) -> None:
|
|
"""Detach from the parent application when closed."""
|
|
|
|
if hasattr(self.app, "detach_log_viewer"):
|
|
self.app.detach_log_viewer(self) # type: ignore[arg-type]
|
|
|
|
def _get_log_widget(self) -> RichLog:
|
|
if self._log_widget is None:
|
|
self._log_widget = self.query_one(RichLog)
|
|
if self._log_widget is None:
|
|
raise RuntimeError("RichLog widget not found")
|
|
return self._log_widget
|
|
|
|
def replace_logs(self, lines: list[str]) -> None:
|
|
"""Replace rendered logs with the provided history."""
|
|
log_widget = self._get_log_widget()
|
|
log_widget.clear()
|
|
for line in lines:
|
|
log_widget.write(line)
|
|
log_widget.scroll_end(animate=False)
|
|
|
|
def append_logs(self, lines: list[str]) -> None:
|
|
"""Append new log lines to the viewer."""
|
|
log_widget = self._get_log_widget()
|
|
for line in lines:
|
|
log_widget.write(line)
|
|
log_widget.scroll_end(animate=False)
|
|
|
|
def update_log_file(self, log_file: Path | None) -> None:
|
|
"""Update the displayed log file path."""
|
|
self._log_file = log_file
|
|
label = self.query_one("#log_file_path", Static)
|
|
if log_file is None:
|
|
label.update("Logs are not currently being persisted to disk.")
|
|
else:
|
|
label.update(f"Log file: {log_file}")
|
|
|
|
def action_close(self) -> None:
|
|
"""Close the log viewer."""
|
|
self.app.pop_screen()
|
|
|
|
def action_show_path(self) -> None:
|
|
"""Reveal the log file location in a notification."""
|
|
if self._log_file is None:
|
|
self.notify("File logging is disabled for this session.", severity="warning")
|
|
else:
|
|
self.notify(
|
|
f"Log file available at: {self._log_file}", severity="information", markup=False
|
|
)
|
|
</file>
|
|
|
|
<file path="ingest_pipeline/config/settings.py">
|
|
"""Application settings and configuration."""
|
|
|
|
from functools import lru_cache
|
|
from typing import Annotated, ClassVar, Final, Literal
|
|
|
|
from prefect.variables import Variable
|
|
from pydantic import Field, HttpUrl, model_validator
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
|
|
class Settings(BaseSettings):
|
|
"""Application settings."""
|
|
|
|
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
|
env_file=".env",
|
|
env_file_encoding="utf-8",
|
|
case_sensitive=False,
|
|
extra="ignore", # Ignore extra environment variables
|
|
)
|
|
|
|
# API Keys
|
|
firecrawl_api_key: str | None = None
|
|
llm_api_key: str | None = None
|
|
openai_api_key: str | None = None
|
|
openwebui_api_key: str | None = None
|
|
weaviate_api_key: str | None = None
|
|
r2r_api_key: str | None = None
|
|
|
|
# Endpoints
|
|
llm_endpoint: HttpUrl = HttpUrl("http://llm.lab")
|
|
weaviate_endpoint: HttpUrl = HttpUrl("http://weaviate.yo")
|
|
openwebui_endpoint: HttpUrl = HttpUrl("http://chat.lab") # This will be the API URL
|
|
firecrawl_endpoint: HttpUrl = HttpUrl("http://crawl.lab:30002")
|
|
r2r_endpoint: HttpUrl | None = Field(default=None, alias="r2r_api_url")
|
|
|
|
# Model Configuration
|
|
embedding_model: str = "ollama/bge-m3:latest"
|
|
metadata_model: str = "fireworks/glm-4p5-air"
|
|
embedding_dimension: int = 1024
|
|
|
|
# Ingestion Settings
|
|
default_batch_size: Annotated[int, Field(gt=0, le=500)] = 50
|
|
max_file_size: int = 1_000_000
|
|
max_crawl_depth: Annotated[int, Field(ge=1, le=20)] = 5
|
|
max_crawl_pages: Annotated[int, Field(ge=1, le=1000)] = 100
|
|
|
|
# Storage Settings
|
|
default_storage_backend: Literal["weaviate", "open_webui", "r2r"] = "weaviate"
|
|
default_collection_prefix: str = "docs"
|
|
|
|
# Prefect Settings
|
|
prefect_api_url: HttpUrl | None = None
|
|
prefect_api_key: str | None = None
|
|
prefect_work_pool: str = "default"
|
|
|
|
# Scheduling Defaults
|
|
default_schedule_interval: Annotated[int, Field(ge=1, le=10080)] = 60 # Max 1 week
|
|
|
|
# Performance Settings
|
|
max_concurrent_tasks: Annotated[int, Field(ge=1, le=20)] = 5
|
|
request_timeout: Annotated[int, Field(ge=10, le=300)] = 60
|
|
|
|
# Logging
|
|
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
|
|
|
|
def get_storage_endpoint(self, backend: str) -> HttpUrl:
|
|
"""
|
|
Get endpoint for storage backend.
|
|
|
|
Args:
|
|
backend: Storage backend name
|
|
|
|
Returns:
|
|
Endpoint URL
|
|
|
|
Raises:
|
|
ValueError: If backend is unknown or R2R endpoint not configured
|
|
"""
|
|
endpoints = {
|
|
"weaviate": self.weaviate_endpoint,
|
|
"open_webui": self.openwebui_endpoint,
|
|
}
|
|
|
|
if backend in endpoints:
|
|
return endpoints[backend]
|
|
elif backend == "r2r":
|
|
if not self.r2r_endpoint:
|
|
raise ValueError(
|
|
"R2R_API_URL must be set in environment variables. "
|
|
"This should have been caught during settings validation."
|
|
)
|
|
return self.r2r_endpoint
|
|
else:
|
|
raise ValueError(f"Unknown backend: {backend}. Supported: weaviate, open_webui, r2r")
|
|
|
|
def get_api_key(self, service: str) -> str | None:
|
|
"""
|
|
Get API key for service.
|
|
|
|
Args:
|
|
service: Service name
|
|
|
|
Returns:
|
|
API key or None
|
|
"""
|
|
service_map: Final[dict[str, str | None]] = {
|
|
"firecrawl": self.firecrawl_api_key,
|
|
"openwebui": self.openwebui_api_key,
|
|
"weaviate": self.weaviate_api_key,
|
|
"r2r": self.r2r_api_key,
|
|
"llm": self.get_llm_api_key(),
|
|
"openai": self.openai_api_key,
|
|
}
|
|
return service_map.get(service)
|
|
|
|
    def get_llm_api_key(self) -> str | None:
        """Get API key for LLM services with OpenAI fallback."""
        return self.llm_api_key or self.openai_api_key or None
|
|
|
|
@model_validator(mode="after")
|
|
def validate_backend_configuration(self) -> "Settings":
|
|
"""Validate that required configuration is present for the default backend."""
|
|
backend = self.default_storage_backend
|
|
|
|
# Validate R2R backend configuration
|
|
if backend == "r2r" and not self.r2r_endpoint:
|
|
raise ValueError(
|
|
"R2R_API_URL must be set in environment variables when using R2R as default backend"
|
|
)
|
|
|
|
# Validate API key requirements (optional warning for missing keys)
|
|
required_keys = {
|
|
"weaviate": ("WEAVIATE_API_KEY", self.weaviate_api_key),
|
|
"open_webui": ("OPENWEBUI_API_KEY", self.openwebui_api_key),
|
|
"r2r": ("R2R_API_KEY", self.r2r_api_key),
|
|
}
|
|
|
|
if backend in required_keys:
|
|
key_name, key_value = required_keys[backend]
|
|
if not key_value:
|
|
import warnings
|
|
|
|
warnings.warn(
|
|
f"{key_name} not set - authentication may fail for {backend} backend",
|
|
UserWarning,
|
|
stacklevel=2,
|
|
)
|
|
|
|
return self
|
|
|
|
|
|
@lru_cache
|
|
def get_settings() -> Settings:
|
|
"""
|
|
Get cached settings instance.
|
|
|
|
Returns:
|
|
Settings instance
|
|
"""
|
|
return Settings()
|
|
|
|
|
|
class PrefectVariableConfig:
|
|
"""Helper class for managing Prefect variables with fallbacks to settings."""
|
|
|
|
def __init__(self) -> None:
|
|
self._settings: Settings = get_settings()
|
|
self._variable_names: list[str] = [
|
|
"default_batch_size",
|
|
"max_file_size",
|
|
"max_crawl_depth",
|
|
"max_crawl_pages",
|
|
"default_storage_backend",
|
|
"default_collection_prefix",
|
|
"max_concurrent_tasks",
|
|
"request_timeout",
|
|
"default_schedule_interval",
|
|
]
|
|
|
|
    def _get_fallback_value(self, name: str, default_value: object = None) -> object:
        """Return the explicit default when provided, otherwise the value from settings."""
        return default_value or getattr(self._settings, name, default_value)
|
|
|
|
def get_with_fallback(
|
|
self, name: str, default_value: str | int | float | None = None
|
|
) -> str | int | float | None:
|
|
"""Get variable value with fallback synchronously."""
|
|
fallback = self._get_fallback_value(name, default_value)
|
|
# Ensure fallback is a type that Variable expects
|
|
variable_fallback = str(fallback) if fallback is not None else None
|
|
try:
|
|
result = Variable.get(name, default=variable_fallback)
|
|
# Variable can return various types, convert to our expected types
|
|
if isinstance(result, (str, int, float)):
|
|
return result
|
|
elif result is None:
|
|
return None
|
|
else:
|
|
# Convert other types to string
|
|
return str(result)
|
|
except Exception:
|
|
# Return fallback with proper type
|
|
if isinstance(fallback, (str, int, float)) or fallback is None:
|
|
return fallback
|
|
return str(fallback) if fallback is not None else None
|
|
|
|
async def get_with_fallback_async(
|
|
self, name: str, default_value: str | int | float | None = None
|
|
) -> str | int | float | None:
|
|
"""Get variable value with fallback asynchronously."""
|
|
fallback = self._get_fallback_value(name, default_value)
|
|
variable_fallback = str(fallback) if fallback is not None else None
|
|
try:
|
|
result = await Variable.aget(name, default=variable_fallback)
|
|
# Variable can return various types, convert to our expected types
|
|
if isinstance(result, (str, int, float)):
|
|
return result
|
|
elif result is None:
|
|
return None
|
|
else:
|
|
# Convert other types to string
|
|
return str(result)
|
|
except Exception:
|
|
# Return fallback with proper type
|
|
if isinstance(fallback, (str, int, float)) or fallback is None:
|
|
return fallback
|
|
return str(fallback) if fallback is not None else None
|
|
|
|
def get_ingestion_config(self) -> dict[str, str | int | float | None]:
|
|
"""Get all ingestion-related configuration variables synchronously."""
|
|
return {name: self.get_with_fallback(name) for name in self._variable_names}
|
|
|
|
async def get_ingestion_config_async(self) -> dict[str, str | int | float | None]:
|
|
"""Get all ingestion-related configuration variables asynchronously."""
|
|
result: dict[str, str | int | float | None] = {}
|
|
for name in self._variable_names:
|
|
result[name] = await self.get_with_fallback_async(name)
|
|
return result
|
|
|
|
|
|
@lru_cache
|
|
def get_prefect_config() -> PrefectVariableConfig:
|
|
"""Get cached Prefect variable configuration helper."""
|
|
return PrefectVariableConfig()
|
|
</file>
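The settings module above is normally consumed through the cached get_settings() helper rather than by constructing Settings directly. A minimal usage sketch follows (illustrative only, not a file from this repository; it assumes the defaults declared in Settings and any overrides loaded from .env):

from ingest_pipeline.config.settings import get_settings

settings = get_settings()  # cached via functools.lru_cache

# Resolve the endpoint for the configured default backend.
# Raises ValueError for an unknown backend or a missing R2R_API_URL.
endpoint = settings.get_storage_endpoint(settings.default_storage_backend)

# API keys are looked up by service name and may be None when unset.
api_key = settings.get_api_key("weaviate")

print(endpoint, "api key configured:", api_key is not None)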
|
|
|
|
<file path="ingest_pipeline/core/models.py">
|
|
"""Core data models with strict typing."""
|
|
|
|
from datetime import UTC, datetime
|
|
from enum import Enum
|
|
from typing import Annotated, ClassVar, TypedDict
|
|
from uuid import UUID, uuid4
|
|
|
|
from prefect.blocks.core import Block
|
|
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
|
|
|
from ..config import get_settings
|
|
|
|
|
|
def _default_embedding_model() -> str:
|
|
return str(get_settings().embedding_model)
|
|
|
|
|
|
def _default_embedding_endpoint() -> HttpUrl:
|
|
endpoint = get_settings().llm_endpoint
|
|
return endpoint if isinstance(endpoint, HttpUrl) else HttpUrl(str(endpoint))
|
|
|
|
|
|
def _default_embedding_dimension() -> int:
|
|
return int(get_settings().embedding_dimension)
|
|
|
|
|
|
def _default_batch_size() -> int:
|
|
return int(get_settings().default_batch_size)
|
|
|
|
|
|
def _default_collection_name() -> str:
|
|
return str(get_settings().default_collection_prefix)
|
|
|
|
|
|
def _default_max_crawl_depth() -> int:
|
|
return int(get_settings().max_crawl_depth)
|
|
|
|
|
|
def _default_max_crawl_pages() -> int:
|
|
return int(get_settings().max_crawl_pages)
|
|
|
|
|
|
def _default_max_file_size() -> int:
|
|
return int(get_settings().max_file_size)
|
|
|
|
|
|
class IngestionStatus(str, Enum):
|
|
"""Status of an ingestion job."""
|
|
|
|
PENDING = "pending"
|
|
IN_PROGRESS = "in_progress"
|
|
COMPLETED = "completed"
|
|
PARTIAL = "partial" # Some documents succeeded, some failed
|
|
FAILED = "failed"
|
|
CANCELLED = "cancelled"
|
|
|
|
|
|
class StorageBackend(str, Enum):
|
|
"""Available storage backends."""
|
|
|
|
WEAVIATE = "weaviate"
|
|
OPEN_WEBUI = "open_webui"
|
|
R2R = "r2r"
|
|
|
|
|
|
class IngestionSource(str, Enum):
|
|
"""Types of ingestion sources."""
|
|
|
|
WEB = "web"
|
|
REPOSITORY = "repository"
|
|
DOCUMENTATION = "documentation"
|
|
|
|
|
|
class VectorConfig(BaseModel):
|
|
"""Configuration for vectorization."""
|
|
|
|
model: str = Field(default_factory=_default_embedding_model)
|
|
embedding_endpoint: HttpUrl = Field(default_factory=_default_embedding_endpoint)
|
|
dimension: int = Field(default_factory=_default_embedding_dimension)
|
|
batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size)
|
|
|
|
|
|
class StorageConfig(Block):
|
|
"""Configuration for storage backend."""
|
|
|
|
_block_type_name: ClassVar[str | None] = "Storage Configuration"
|
|
_block_type_slug: ClassVar[str | None] = "storage-config"
|
|
_description: ClassVar[str | None] = (
|
|
"Configures storage backend connections and settings for document ingestion"
|
|
)
|
|
|
|
backend: StorageBackend
|
|
endpoint: HttpUrl
|
|
api_key: SecretStr | None = Field(default=None)
|
|
collection_name: str = Field(default_factory=_default_collection_name)
|
|
batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size)
|
|
grpc_port: int | None = Field(default=None, description="gRPC port for Weaviate connections")
|
|
|
|
|
|
class FirecrawlConfig(Block):
|
|
"""Configuration for Firecrawl ingestion (operational parameters only)."""
|
|
|
|
_block_type_name: ClassVar[str | None] = "Firecrawl Configuration"
|
|
_block_type_slug: ClassVar[str | None] = "firecrawl-config"
|
|
_description: ClassVar[str | None] = "Configures Firecrawl web scraping and crawling parameters"
|
|
|
|
formats: list[str] = Field(default_factory=lambda: ["markdown", "html"])
|
|
max_depth: Annotated[int, Field(ge=1, le=20)] = Field(default_factory=_default_max_crawl_depth)
|
|
limit: Annotated[int, Field(ge=1, le=1000)] = Field(default_factory=_default_max_crawl_pages)
|
|
only_main_content: bool = Field(default=True)
|
|
include_subdomains: bool = Field(default=False)
|
|
|
|
|
|
class RepomixConfig(Block):
|
|
"""Configuration for Repomix ingestion."""
|
|
|
|
_block_type_name: ClassVar[str | None] = "Repomix Configuration"
|
|
_block_type_slug: ClassVar[str | None] = "repomix-config"
|
|
_description: ClassVar[str | None] = (
|
|
"Configures repository ingestion patterns and file processing settings"
|
|
)
|
|
|
|
include_patterns: list[str] = Field(
|
|
default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"]
|
|
)
|
|
exclude_patterns: list[str] = Field(
|
|
default_factory=lambda: ["**/node_modules/**", "**/__pycache__/**", "**/.git/**"]
|
|
)
|
|
max_file_size: int = Field(default_factory=_default_max_file_size) # 1MB
|
|
respect_gitignore: bool = Field(default=True)
|
|
|
|
|
|
class R2RConfig(Block):
|
|
"""Configuration for R2R ingestion."""
|
|
|
|
_block_type_name: ClassVar[str | None] = "R2R Configuration"
|
|
_block_type_slug: ClassVar[str | None] = "r2r-config"
|
|
_description: ClassVar[str | None] = (
|
|
"Configures R2R-specific ingestion settings including chunking and graph enrichment"
|
|
)
|
|
|
|
chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000
|
|
chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200
|
|
enable_graph_enrichment: bool = Field(default=False)
|
|
graph_creation_settings: dict[str, object] | None = Field(default=None)
|
|
|
|
|
|
class DocumentMetadataRequired(TypedDict):
|
|
"""Required metadata fields for a document."""
|
|
|
|
source_url: str
|
|
timestamp: datetime
|
|
content_type: str
|
|
word_count: int
|
|
char_count: int
|
|
|
|
|
|
class DocumentMetadata(DocumentMetadataRequired, total=False):
|
|
"""Rich metadata for a document with R2R-compatible fields."""
|
|
|
|
# Basic optional fields
|
|
title: str | None
|
|
description: str | None
|
|
|
|
# Content categorization
|
|
tags: list[str]
|
|
category: str
|
|
section: str
|
|
language: str
|
|
|
|
# Authorship and source info
|
|
author: str
|
|
domain: str
|
|
site_name: str
|
|
|
|
# Document structure
|
|
heading_hierarchy: list[str]
|
|
section_depth: int
|
|
has_code_blocks: bool
|
|
has_images: bool
|
|
has_links: bool
|
|
|
|
# Processing metadata
|
|
extraction_method: str
|
|
crawl_depth: int
|
|
last_modified: datetime | None
|
|
|
|
# Content quality indicators
|
|
readability_score: float | None
|
|
completeness_score: float | None
|
|
|
|
# Repository-specific fields
|
|
file_path: str | None
|
|
repository_name: str | None
|
|
branch_name: str | None
|
|
commit_hash: str | None
|
|
programming_language: str | None
|
|
|
|
# Custom business metadata
|
|
importance_score: float | None
|
|
review_status: str | None
|
|
assigned_team: str | None
|
|
|
|
|
|
class Document(BaseModel):
|
|
"""Represents a single document."""
|
|
|
|
id: UUID = Field(default_factory=uuid4)
|
|
content: str
|
|
metadata: DocumentMetadata
|
|
vector: list[float] | None = Field(default=None)
|
|
score: float | None = Field(default=None)
|
|
source: IngestionSource
|
|
collection: str = Field(default_factory=_default_collection_name)
|
|
|
|
|
|
class IngestionJob(BaseModel):
|
|
"""Represents an ingestion job."""
|
|
|
|
id: UUID = Field(default_factory=uuid4)
|
|
source_type: IngestionSource
|
|
source_url: HttpUrl | str
|
|
status: IngestionStatus = Field(default=IngestionStatus.PENDING)
|
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
|
completed_at: datetime | None = Field(default=None)
|
|
error_message: str | None = Field(default=None)
|
|
document_count: int = Field(default=0)
|
|
storage_backend: StorageBackend
|
|
|
|
|
|
class IngestionResult(BaseModel):
|
|
"""Result of an ingestion operation."""
|
|
|
|
job_id: UUID
|
|
status: IngestionStatus
|
|
documents_processed: int
|
|
documents_failed: int
|
|
duration_seconds: float
|
|
error_messages: list[str] = Field(default_factory=list)
|
|
</file>
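As a quick illustration of how the models above compose (a sketch, not a file from this repository), a Document pairs a DocumentMetadata TypedDict with an IngestionSource, and its collection name defaults to the configured prefix:

from datetime import UTC, datetime

from ingest_pipeline.core.models import Document, DocumentMetadata, IngestionSource

metadata: DocumentMetadata = {
    "source_url": "https://example.com/guide",
    "timestamp": datetime.now(UTC),
    "content_type": "text/markdown",
    "word_count": 120,
    "char_count": 700,
    "title": "Example guide",  # optional field from the total=False extension
}

document = Document(
    content="# Example guide\n\nSome markdown content.",
    metadata=metadata,
    source=IngestionSource.WEB,
)
print(document.id, document.collection)  # id is auto-generated; collection uses the default prefix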
|
|
|
|
<file path="ingest_pipeline/storage/r2r/storage.py">
|
|
"""R2R storage implementation using the official R2R SDK."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import logging
|
|
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
|
from datetime import UTC, datetime
|
|
from typing import Final, Self, TypeVar, cast
|
|
from uuid import UUID, uuid4
|
|
|
|
# Direct imports for runtime and type checking
|
|
from httpx import AsyncClient, HTTPStatusError # type: ignore
|
|
from r2r import R2RAsyncClient, R2RException # type: ignore
|
|
from typing_extensions import override
|
|
|
|
from ...core.exceptions import StorageError
|
|
from ...core.models import Document, DocumentMetadata, IngestionSource, StorageConfig
|
|
from ..base import BaseStorage
|
|
from ..types import DocumentInfo
|
|
|
|
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
def _as_mapping(value: object) -> dict[str, object]:
    if isinstance(value, Mapping):
        return dict(cast(Mapping[str, object], value))
    if hasattr(value, "__dict__"):
        return dict(cast(Mapping[str, object], value.__dict__))
    return {}


def _as_sequence(value: object) -> tuple[object, ...]:
    """Convert value to a tuple of objects."""
    if isinstance(value, Sequence):
        return tuple(value)
    return tuple(value) if isinstance(value, Iterable) else ()


def _extract_id(source: object, fallback: str) -> str:
    mapping = _as_mapping(source)
    identifier = mapping.get("id") if mapping else None
    if identifier is None and hasattr(source, "id"):
        identifier = getattr(source, "id", None)
    return fallback if identifier is None else str(identifier)


def _as_datetime(value: object) -> datetime:
    if isinstance(value, datetime):
        return value
    if isinstance(value, str):
        with contextlib.suppress(ValueError):
            return datetime.fromisoformat(value)
    return datetime.now(UTC)


def _as_int(value: object, default: int = 0) -> int:
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value)
    if isinstance(value, str):
        try:
            return int(float(value)) if "." in value else int(value)
        except ValueError:
            return default
    return default
|
|
|
|
|
|
class R2RStorage(BaseStorage):
|
|
"""R2R storage implementation using the official R2R SDK."""
|
|
|
|
def __init__(self, config: StorageConfig) -> None:
|
|
"""Initialize R2R storage with SDK client."""
|
|
super().__init__(config)
|
|
self.endpoint: str = str(config.endpoint).rstrip("/")
|
|
self.client: R2RAsyncClient = R2RAsyncClient(self.endpoint)
|
|
self.default_collection_id: str | None = None
|
|
|
|
def _get_http_client_headers(self) -> dict[str, str]:
|
|
"""Get consistent HTTP headers for direct API calls."""
|
|
headers = {"Content-Type": "application/json"}
|
|
|
|
# Add authentication headers if available
|
|
# Note: R2R SDK may handle auth internally, so we extract it if possible
|
|
if hasattr(self.client, "_get_headers"):
|
|
with contextlib.suppress(Exception):
|
|
sdk_headers = self.client._get_headers() # type: ignore[attr-defined]
|
|
if isinstance(sdk_headers, dict):
|
|
headers |= sdk_headers
|
|
return headers
|
|
|
|
def _create_http_client(self) -> AsyncClient:
|
|
"""Create a properly configured HTTP client for direct API calls."""
|
|
headers = self._get_http_client_headers()
|
|
return AsyncClient(headers=headers, timeout=30.0)
|
|
|
|
@override
|
|
async def initialize(self) -> None:
|
|
"""Initialize R2R connection and ensure default collection exists."""
|
|
try:
|
|
# Ensure we have an event loop
|
|
            try:
                _ = asyncio.get_running_loop()
            except RuntimeError:
                # No event loop running; this should not happen in an async context,
                # but be defensive and log it via the module logger.
                LOGGER.warning("No event loop found during R2R initialization")
|
|
|
|
# Test connection using direct HTTP call to v3 API
|
|
endpoint = self.endpoint
|
|
client = self._create_http_client()
|
|
try:
|
|
response = await client.get(f"{endpoint}/v3/collections")
|
|
response.raise_for_status()
|
|
finally:
|
|
await client.aclose()
|
|
_ = await self._ensure_collection(self.config.collection_name)
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to initialize R2R: {e}") from e
|
|
|
|
async def _ensure_collection(self, collection_name: str) -> str:
|
|
"""Get or create collection by name."""
|
|
endpoint = self.endpoint
|
|
client = self._create_http_client()
|
|
try:
|
|
# List collections and find by name
|
|
response = await client.get(f"{endpoint}/v3/collections")
|
|
response.raise_for_status()
|
|
data: dict[str, object] = response.json()
|
|
|
|
results = cast(list[dict[str, object]], data.get("results", []))
|
|
for collection in results:
|
|
if collection.get("name") == collection_name:
|
|
collection_id_raw = collection.get("id")
|
|
if collection_id_raw is None:
|
|
raise StorageError(f"Collection '{collection_name}' exists but has no ID")
|
|
collection_id = str(collection_id_raw)
|
|
if collection_name == self.config.collection_name:
|
|
self.default_collection_id = collection_id
|
|
return collection_id
|
|
|
|
# Create if not found
|
|
create_response = await client.post(
|
|
f"{endpoint}/v3/collections",
|
|
json={
|
|
"name": collection_name,
|
|
"description": f"Auto-created collection: {collection_name}",
|
|
},
|
|
)
|
|
create_response.raise_for_status()
|
|
created: dict[str, object] = create_response.json()
|
|
created_results = cast(dict[str, object], created.get("results", {}))
|
|
collection_id_raw = created_results.get("id")
|
|
if collection_id_raw is None:
|
|
raise StorageError("Failed to get collection ID from creation response")
|
|
collection_id = str(collection_id_raw)
|
|
|
|
if collection_name == self.config.collection_name:
|
|
self.default_collection_id = collection_id
|
|
|
|
return collection_id
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to ensure collection '{collection_name}': {e}") from e
|
|
finally:
|
|
await client.aclose()
|
|
|
|
# This should never be reached, but satisfies static analyzer
|
|
raise StorageError(f"Unexpected code path in _ensure_collection for '{collection_name}'")
|
|
|
|
@override
|
|
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
|
|
"""Store a single document."""
|
|
return (await self.store_batch([document], collection_name=collection_name))[0]
|
|
|
|
@override
|
|
async def store_batch(
|
|
self, documents: list[Document], *, collection_name: str | None = None
|
|
) -> list[str]:
|
|
"""Store multiple documents efficiently with connection reuse."""
|
|
collection_id = await self._resolve_collection_id(collection_name)
|
|
LOGGER.info(
|
|
"Using collection ID: %s for collection: %s",
|
|
collection_id,
|
|
collection_name or self.config.collection_name,
|
|
)
|
|
|
|
# Filter valid documents upfront
|
|
valid_documents = [doc for doc in documents if self._is_document_valid(doc)]
|
|
if not valid_documents:
|
|
return []
|
|
|
|
stored_ids: list[str] = []
|
|
|
|
# Use a single HTTP client for all requests
|
|
http_client = AsyncClient()
|
|
async with http_client: # type: ignore
|
|
            # Process documents with controlled concurrency (asyncio is imported at module level)
            semaphore = asyncio.Semaphore(5)  # Limit concurrent uploads
|
|
|
|
async def store_single_with_client(document: Document) -> str | None:
|
|
async with semaphore:
|
|
return await self._store_single_document_with_client(
|
|
document, collection_id, http_client
|
|
)
|
|
|
|
# Execute all uploads concurrently
|
|
results = await asyncio.gather(
|
|
*[store_single_with_client(doc) for doc in valid_documents], return_exceptions=True
|
|
)
|
|
|
|
# Collect successful IDs
|
|
for result in results:
|
|
if isinstance(result, str):
|
|
stored_ids.append(result)
|
|
elif isinstance(result, Exception):
|
|
LOGGER.error("Document upload failed: %s", result)
|
|
|
|
return stored_ids
|
|
|
|
async def _resolve_collection_id(self, collection_name: str | None) -> str:
|
|
"""Resolve collection ID from name or use default."""
|
|
if collection_name:
|
|
return await self._ensure_collection(collection_name)
|
|
|
|
if self.default_collection_id:
|
|
return self.default_collection_id
|
|
|
|
collection_id = await self._ensure_collection(self.config.collection_name)
|
|
self.default_collection_id = collection_id
|
|
return collection_id
|
|
|
|
def _is_document_valid(self, document: Document) -> bool:
|
|
"""Validate document content and size."""
|
|
requested_id = str(document.id)
|
|
|
|
if not document.content or not document.content.strip():
|
|
LOGGER.warning("Skipping document %s: empty content", requested_id)
|
|
return False
|
|
|
|
if len(document.content) > 1_000_000: # 1MB limit
|
|
LOGGER.warning(
|
|
"Skipping document %s: content too large (%d chars)",
|
|
requested_id,
|
|
len(document.content),
|
|
)
|
|
return False
|
|
|
|
return True
|
|
|
|
async def _store_single_document(self, document: Document, collection_id: str) -> str | None:
|
|
"""Store a single document with retry logic."""
|
|
http_client = AsyncClient()
|
|
async with http_client: # type: ignore
|
|
return await self._store_single_document_with_client(
|
|
document, collection_id, http_client
|
|
)
|
|
|
|
async def _store_single_document_with_client(
|
|
self, document: Document, collection_id: str, http_client: AsyncClient
|
|
) -> str | None:
|
|
"""Store a single document with retry logic using provided HTTP client."""
|
|
requested_id = str(document.id)
|
|
LOGGER.debug("Creating document with ID: %s", requested_id)
|
|
|
|
max_retries = 3
|
|
retry_delay = 1.0
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
doc_response = await self._attempt_document_creation_with_client(
|
|
document, collection_id, http_client
|
|
)
|
|
if doc_response:
|
|
return self._process_document_response(
|
|
doc_response, requested_id, collection_id
|
|
)
|
|
except (TimeoutError, OSError) as e:
|
|
if not await self._should_retry_timeout(
|
|
e, attempt, max_retries, requested_id, retry_delay
|
|
):
|
|
break
|
|
retry_delay *= 2
|
|
except HTTPStatusError as e:
|
|
if not await self._should_retry_http_error(
|
|
e, attempt, max_retries, requested_id, retry_delay
|
|
):
|
|
break
|
|
retry_delay *= 2
|
|
except Exception as exc:
|
|
self._log_document_error(document.id, exc)
|
|
break
|
|
|
|
return None
|
|
|
|
async def _attempt_document_creation(
|
|
self, document: Document, collection_id: str
|
|
) -> dict[str, object] | None:
|
|
"""Attempt to create a document via HTTP API."""
|
|
http_client = AsyncClient()
|
|
async with http_client: # type: ignore
|
|
return await self._attempt_document_creation_with_client(
|
|
document, collection_id, http_client
|
|
)
|
|
|
|
async def _attempt_document_creation_with_client(
|
|
self, document: Document, collection_id: str, http_client: AsyncClient
|
|
) -> dict[str, object] | None:
|
|
"""Attempt to create a document via HTTP API using provided client."""
|
|
import json
|
|
|
|
requested_id = str(document.id)
|
|
metadata = self._build_metadata(document)
|
|
LOGGER.debug("Built metadata for document %s: %s", requested_id, metadata)
|
|
|
|
files = {
|
|
"raw_text": (None, document.content),
|
|
"metadata": (None, json.dumps(metadata)),
|
|
"id": (None, requested_id),
|
|
"ingestion_mode": (None, "hi-res"),
|
|
}
|
|
|
|
if collection_id:
|
|
files["collection_ids"] = (None, json.dumps([collection_id]))
|
|
LOGGER.debug(
|
|
"Creating document %s with collection_ids: [%s]", requested_id, collection_id
|
|
)
|
|
|
|
LOGGER.debug("Sending to R2R - files keys: %s", list(files.keys()))
|
|
LOGGER.debug("Metadata JSON: %s", files["metadata"][1])
|
|
|
|
response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) # type: ignore[call-arg]
|
|
|
|
if response.status_code == 422:
|
|
self._handle_validation_error(response, requested_id, metadata)
|
|
return None
|
|
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
def _handle_validation_error(
|
|
self, response: object, requested_id: str, metadata: dict[str, object]
|
|
) -> None:
|
|
"""Handle validation errors from R2R API."""
|
|
try:
|
|
error_detail = (
|
|
getattr(response, "json", lambda: {})() if hasattr(response, "json") else {}
|
|
)
|
|
LOGGER.error("R2R validation error for document %s: %s", requested_id, error_detail)
|
|
LOGGER.error("Document metadata sent: %s", metadata)
|
|
LOGGER.error("Response status: %s", getattr(response, "status_code", "unknown"))
|
|
LOGGER.error("Response headers: %s", dict(getattr(response, "headers", {})))
|
|
except Exception:
|
|
LOGGER.error(
|
|
"R2R validation error for document %s: %s",
|
|
requested_id,
|
|
getattr(response, "text", "unknown error"),
|
|
)
|
|
LOGGER.error("Document metadata sent: %s", metadata)
|
|
|
|
def _process_document_response(
|
|
self, doc_response: dict[str, object], requested_id: str, collection_id: str
|
|
) -> str:
|
|
"""Process successful document creation response."""
|
|
response_payload = doc_response.get("results", doc_response)
|
|
doc_id = _extract_id(response_payload, requested_id)
|
|
|
|
LOGGER.info("R2R returned document ID: %s", doc_id)
|
|
|
|
if doc_id != requested_id:
|
|
LOGGER.warning("Requested ID %s but got %s", requested_id, doc_id)
|
|
|
|
if collection_id:
|
|
LOGGER.info(
|
|
"Document %s should be assigned to collection %s via creation API",
|
|
doc_id,
|
|
collection_id,
|
|
)
|
|
|
|
return doc_id
|
|
|
|
async def _should_retry_timeout(
|
|
self,
|
|
error: Exception,
|
|
attempt: int,
|
|
max_retries: int,
|
|
requested_id: str,
|
|
retry_delay: float,
|
|
) -> bool:
|
|
"""Determine if timeout error should be retried."""
|
|
if attempt >= max_retries - 1:
|
|
return False
|
|
|
|
LOGGER.warning("Timeout for document %s, retrying in %ss...", requested_id, retry_delay)
|
|
await asyncio.sleep(retry_delay)
|
|
return True
|
|
|
|
async def _should_retry_http_error(
|
|
self,
|
|
error: HTTPStatusError,
|
|
attempt: int,
|
|
max_retries: int,
|
|
requested_id: str,
|
|
retry_delay: float,
|
|
) -> bool:
|
|
"""Determine if HTTP error should be retried."""
|
|
status_code = error.response.status_code
|
|
if status_code < 500 or attempt >= max_retries - 1:
|
|
return False
|
|
|
|
LOGGER.warning(
|
|
"Server error %s for document %s, retrying in %ss...",
|
|
status_code,
|
|
requested_id,
|
|
retry_delay,
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
return True
|
|
|
|
def _log_document_error(self, document_id: object, exc: Exception) -> None:
|
|
"""Log document storage errors with specific categorization."""
|
|
LOGGER.error("Failed to store document %s: %s", document_id, exc)
|
|
|
|
exc_str = str(exc)
|
|
if "422" in exc_str:
|
|
LOGGER.error(" → Data validation issue - check document content and metadata format")
|
|
elif "timeout" in exc_str.lower():
|
|
LOGGER.error(" → Network timeout - R2R may be overloaded")
|
|
elif "500" in exc_str:
|
|
LOGGER.error(" → Server error - R2R internal issue")
|
|
else:
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
|
|
def _build_metadata(self, document: Document) -> dict[str, object]:
|
|
"""Convert document metadata to enriched R2R format."""
|
|
metadata = document.metadata
|
|
|
|
# Core required fields
|
|
result: dict[str, object] = {
|
|
"source_url": metadata["source_url"],
|
|
"content_type": metadata["content_type"],
|
|
"word_count": metadata["word_count"],
|
|
"char_count": metadata["char_count"],
|
|
"timestamp": metadata["timestamp"].isoformat(),
|
|
"ingestion_source": document.source.value,
|
|
}
|
|
|
|
# Basic optional fields
|
|
if title := metadata.get("title"):
|
|
result["title"] = title
|
|
if description := metadata.get("description"):
|
|
result["description"] = description
|
|
|
|
# Content categorization
|
|
if tags := metadata.get("tags"):
|
|
result["tags"] = tags
|
|
if category := metadata.get("category"):
|
|
result["category"] = category
|
|
if section := metadata.get("section"):
|
|
result["section"] = section
|
|
if language := metadata.get("language"):
|
|
result["language"] = language
|
|
|
|
# Authorship and source info
|
|
if author := metadata.get("author"):
|
|
result["author"] = author
|
|
if domain := metadata.get("domain"):
|
|
result["domain"] = domain
|
|
if site_name := metadata.get("site_name"):
|
|
result["site_name"] = site_name
|
|
|
|
# Document structure
|
|
if heading_hierarchy := metadata.get("heading_hierarchy"):
|
|
result["heading_hierarchy"] = heading_hierarchy
|
|
if section_depth := metadata.get("section_depth"):
|
|
result["section_depth"] = section_depth
|
|
if has_code_blocks := metadata.get("has_code_blocks"):
|
|
result["has_code_blocks"] = has_code_blocks
|
|
if has_images := metadata.get("has_images"):
|
|
result["has_images"] = has_images
|
|
if has_links := metadata.get("has_links"):
|
|
result["has_links"] = has_links
|
|
|
|
# Processing metadata
|
|
if extraction_method := metadata.get("extraction_method"):
|
|
result["extraction_method"] = extraction_method
|
|
if crawl_depth := metadata.get("crawl_depth"):
|
|
result["crawl_depth"] = crawl_depth
|
|
if last_modified := metadata.get("last_modified"):
|
|
result["last_modified"] = last_modified.isoformat() if last_modified else None
|
|
|
|
# Content quality indicators
|
|
if readability_score := metadata.get("readability_score"):
|
|
result["readability_score"] = readability_score
|
|
if completeness_score := metadata.get("completeness_score"):
|
|
result["completeness_score"] = completeness_score
|
|
|
|
# Repository-specific fields
|
|
if file_path := metadata.get("file_path"):
|
|
result["file_path"] = file_path
|
|
if repository_name := metadata.get("repository_name"):
|
|
result["repository_name"] = repository_name
|
|
if branch_name := metadata.get("branch_name"):
|
|
result["branch_name"] = branch_name
|
|
if commit_hash := metadata.get("commit_hash"):
|
|
result["commit_hash"] = commit_hash
|
|
if programming_language := metadata.get("programming_language"):
|
|
result["programming_language"] = programming_language
|
|
|
|
# Custom business metadata
|
|
if importance_score := metadata.get("importance_score"):
|
|
result["importance_score"] = importance_score
|
|
if review_status := metadata.get("review_status"):
|
|
result["review_status"] = review_status
|
|
if assigned_team := metadata.get("assigned_team"):
|
|
result["assigned_team"] = assigned_team
|
|
|
|
return result
|
|
|
|
@override
|
|
async def retrieve(
|
|
self, document_id: str, *, collection_name: str | None = None
|
|
) -> Document | None:
|
|
"""Retrieve a document by ID."""
|
|
        try:
            response = await self.client.documents.retrieve(document_id)
        except R2RException as exc:
            status_code = getattr(exc, "status_code", None)
            if status_code == 404:
                return None
            LOGGER.warning("Unexpected error retrieving document %s: %s", document_id, exc)
            return None
        except Exception as error:
            LOGGER.warning("Unexpected error retrieving document %s: %s", document_id, error)
            return None
|
|
payload = getattr(response, "results", response)
|
|
return self._convert_to_document(payload, collection_name)
|
|
|
|
def _convert_to_document(self, r2r_doc: object, collection_name: str | None = None) -> Document:
|
|
"""Convert R2R document payload to our Document model."""
|
|
doc_map = _as_mapping(r2r_doc)
|
|
metadata_map = _as_mapping(doc_map.get("metadata", {}))
|
|
|
|
doc_uuid = self._extract_document_uuid(r2r_doc)
|
|
timestamp = _as_datetime(doc_map.get("created_at", metadata_map.get("timestamp")))
|
|
|
|
metadata = self._build_core_metadata(metadata_map, timestamp)
|
|
self._add_optional_metadata_fields(metadata, doc_map, metadata_map)
|
|
|
|
source_enum = self._extract_ingestion_source(metadata_map)
|
|
content_value = doc_map.get("content", getattr(r2r_doc, "content", ""))
|
|
|
|
return Document(
|
|
id=doc_uuid,
|
|
content=str(content_value),
|
|
metadata=metadata,
|
|
source=source_enum,
|
|
collection=collection_name or self.config.collection_name,
|
|
)
|
|
|
|
def _extract_document_uuid(self, r2r_doc: object) -> UUID:
|
|
"""Extract and validate document UUID."""
|
|
doc_id_str = _extract_id(r2r_doc, str(uuid4()))
|
|
try:
|
|
return UUID(doc_id_str)
|
|
except ValueError:
|
|
return uuid4()
|
|
|
|
def _build_core_metadata(
|
|
self, metadata_map: dict[str, object], timestamp: datetime
|
|
) -> DocumentMetadata:
|
|
"""Build core required metadata fields."""
|
|
return {
|
|
"source_url": str(metadata_map.get("source_url", "")),
|
|
"timestamp": timestamp,
|
|
"content_type": str(metadata_map.get("content_type", "text/plain")),
|
|
"word_count": _as_int(metadata_map.get("word_count")),
|
|
"char_count": _as_int(metadata_map.get("char_count")),
|
|
}
|
|
|
|
def _add_optional_metadata_fields(
|
|
self,
|
|
metadata: DocumentMetadata,
|
|
doc_map: dict[str, object],
|
|
metadata_map: dict[str, object],
|
|
) -> None:
|
|
"""Add optional metadata fields if present."""
|
|
self._add_title_and_description(metadata, doc_map, metadata_map)
|
|
self._add_content_categorization(metadata, metadata_map)
|
|
self._add_authorship_fields(metadata, metadata_map)
|
|
self._add_structure_fields(metadata, metadata_map)
|
|
self._add_processing_fields(metadata, metadata_map)
|
|
self._add_quality_scores(metadata, metadata_map)
|
|
|
|
def _add_title_and_description(
|
|
self,
|
|
metadata: DocumentMetadata,
|
|
doc_map: dict[str, object],
|
|
metadata_map: dict[str, object],
|
|
) -> None:
|
|
"""Add title and description fields."""
|
|
if title := (doc_map.get("title") or metadata_map.get("title")):
|
|
metadata["title"] = cast(str | None, title)
|
|
|
|
if summary := (doc_map.get("summary") or metadata_map.get("summary")):
|
|
metadata["description"] = cast(str | None, summary)
|
|
elif description := metadata_map.get("description"):
|
|
metadata["description"] = cast(str | None, description)
|
|
|
|
def _add_content_categorization(
|
|
self, metadata: DocumentMetadata, metadata_map: dict[str, object]
|
|
) -> None:
|
|
"""Add content categorization fields."""
|
|
if tags := metadata_map.get("tags"):
|
|
metadata["tags"] = [str(tag) for tag in tags] if isinstance(tags, list) else []
|
|
if category := metadata_map.get("category"):
|
|
metadata["category"] = str(category)
|
|
if section := metadata_map.get("section"):
|
|
metadata["section"] = str(section)
|
|
if language := metadata_map.get("language"):
|
|
metadata["language"] = str(language)
|
|
|
|
def _add_authorship_fields(
|
|
self, metadata: DocumentMetadata, metadata_map: dict[str, object]
|
|
) -> None:
|
|
"""Add authorship and source information fields."""
|
|
if author := metadata_map.get("author"):
|
|
metadata["author"] = str(author)
|
|
if domain := metadata_map.get("domain"):
|
|
metadata["domain"] = str(domain)
|
|
if site_name := metadata_map.get("site_name"):
|
|
metadata["site_name"] = str(site_name)
|
|
|
|
def _add_structure_fields(
|
|
self, metadata: DocumentMetadata, metadata_map: dict[str, object]
|
|
) -> None:
|
|
"""Add document structure fields."""
|
|
if heading_hierarchy := metadata_map.get("heading_hierarchy"):
|
|
metadata["heading_hierarchy"] = (
|
|
list(heading_hierarchy) if isinstance(heading_hierarchy, list) else []
|
|
)
|
|
if section_depth := metadata_map.get("section_depth"):
|
|
metadata["section_depth"] = _as_int(section_depth)
|
|
if has_code_blocks := metadata_map.get("has_code_blocks"):
|
|
metadata["has_code_blocks"] = bool(has_code_blocks)
|
|
if has_images := metadata_map.get("has_images"):
|
|
metadata["has_images"] = bool(has_images)
|
|
if has_links := metadata_map.get("has_links"):
|
|
metadata["has_links"] = bool(has_links)
|
|
|
|
def _add_processing_fields(
|
|
self, metadata: DocumentMetadata, metadata_map: dict[str, object]
|
|
) -> None:
|
|
"""Add processing-related metadata fields."""
|
|
if extraction_method := metadata_map.get("extraction_method"):
|
|
metadata["extraction_method"] = str(extraction_method)
|
|
if crawl_depth := metadata_map.get("crawl_depth"):
|
|
metadata["crawl_depth"] = _as_int(crawl_depth)
|
|
if last_modified := metadata_map.get("last_modified"):
|
|
metadata["last_modified"] = _as_datetime(last_modified)
|
|
|
|
def _add_quality_scores(
|
|
self, metadata: DocumentMetadata, metadata_map: dict[str, object]
|
|
) -> None:
|
|
"""Add quality score fields with safe float conversion."""
|
|
if readability_score := metadata_map.get("readability_score"):
|
|
try:
|
|
metadata["readability_score"] = float(str(readability_score))
|
|
except (ValueError, TypeError):
|
|
metadata["readability_score"] = None
|
|
if completeness_score := metadata_map.get("completeness_score"):
|
|
try:
|
|
metadata["completeness_score"] = float(str(completeness_score))
|
|
except (ValueError, TypeError):
|
|
metadata["completeness_score"] = None
|
|
|
|
def _extract_ingestion_source(self, metadata_map: dict[str, object]) -> IngestionSource:
|
|
"""Extract and validate ingestion source."""
|
|
source_value = str(metadata_map.get("ingestion_source", IngestionSource.WEB.value))
|
|
try:
|
|
return IngestionSource(source_value)
|
|
except ValueError:
|
|
return IngestionSource.WEB
|
|
|
|
@override
|
|
async def search(
|
|
self,
|
|
query: str,
|
|
limit: int = 10,
|
|
threshold: float = 0.7,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> AsyncGenerator[Document, None]:
|
|
"""Search documents using R2R."""
|
|
try:
|
|
search_settings: dict[str, object] = {
|
|
"limit": limit,
|
|
"similarity_threshold": threshold,
|
|
}
|
|
|
|
if collection_name:
|
|
collection_id = await self._ensure_collection(collection_name)
|
|
search_settings["collection_ids"] = [collection_id]
|
|
|
|
search_response = await self.client.retrieval.search(
|
|
query=query,
|
|
search_settings=search_settings,
|
|
)
|
|
|
|
for result in _as_sequence(getattr(search_response, "results", ())):
|
|
result_map = _as_mapping(result)
|
|
document_id_value = result_map.get(
|
|
"document_id", getattr(result, "document_id", None)
|
|
)
|
|
if document_id_value is None:
|
|
continue
|
|
document_id = str(document_id_value)
|
|
|
|
try:
|
|
doc_response = await self.client.documents.retrieve(document_id)
|
|
                except R2RException as exc:
                    LOGGER.warning(
                        "Failed to retrieve document %s during search: %s", document_id, exc
                    )
                    continue
|
|
|
|
document_payload = getattr(doc_response, "results", doc_response)
|
|
document = self._convert_to_document(document_payload, collection_name)
|
|
|
|
score_value = result_map.get("score", getattr(result, "score", None))
|
|
if score_value is not None:
|
|
try:
|
|
# Handle various score value types safely
|
|
if isinstance(score_value, (int, float, str)):
|
|
document.score = float(score_value)
|
|
else:
|
|
# For unknown types, try string conversion first
|
|
document.score = float(str(score_value))
|
|
                    except (TypeError, ValueError) as e:
                        LOGGER.debug(
                            "Invalid score value %s for document %s: %s",
                            score_value,
                            document_id,
                            e,
                        )
                        document.score = None
|
|
|
|
yield document
|
|
|
|
except R2RException as exc:
|
|
raise StorageError(f"Search failed: {exc}") from exc
|
|
|
|
@override
|
|
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
|
|
"""Delete a document."""
|
|
try:
|
|
_ = await self.client.documents.delete(document_id)
|
|
return True
|
|
except R2RException:
|
|
return False
|
|
|
|
@override
|
|
async def count(self, *, collection_name: str | None = None) -> int:
|
|
"""Get document count in collection."""
|
|
endpoint = self.endpoint
|
|
client = self._create_http_client()
|
|
try:
|
|
# Get collections and find the count for the specific collection
|
|
response = await client.get(f"{endpoint}/v3/collections")
|
|
response.raise_for_status()
|
|
data: dict[str, object] = response.json()
|
|
|
|
target_collection = collection_name or self.config.collection_name
|
|
results = cast(list[dict[str, object]], data.get("results", []))
|
|
for collection in results:
|
|
if collection.get("name") == target_collection:
|
|
doc_count = collection.get("document_count", 0)
|
|
return _as_int(doc_count)
|
|
|
|
# Collection not found
|
|
return 0
|
|
except Exception:
|
|
return 0
|
|
finally:
|
|
await client.aclose()
|
|
|
|
# This should never be reached, but satisfies static analyzer
|
|
return 0
|
|
|
|
@override
|
|
async def close(self) -> None:
|
|
"""Close R2R client."""
|
|
        try:
            await self.client.close()
        except Exception as e:
            LOGGER.warning("Error closing R2R client: %s", e)
|
|
|
|
async def __aenter__(self) -> Self:
|
|
"""Async context manager entry."""
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: object | None,
|
|
) -> None:
|
|
"""Async context manager exit with proper cleanup."""
|
|
await self.close()
|
|
|
|
# Additional R2R-specific comprehensive management methods
|
|
|
|
async def create_collection(self, name: str, description: str | None = None) -> str:
|
|
"""Create a new collection."""
|
|
try:
|
|
response = await self.client.collections.create(name=name, description=description)
|
|
created = _as_mapping(getattr(response, "results", {}))
|
|
return str(created.get("id", name))
|
|
except R2RException as exc:
|
|
raise StorageError(f"Failed to create collection {name}: {exc}") from exc
|
|
|
|
async def delete_collection(self, collection_name: str) -> bool:
|
|
"""Delete a collection."""
|
|
try:
|
|
collection_id = await self._ensure_collection(collection_name)
|
|
_ = await self.client.collections.delete(collection_id)
|
|
return True
|
|
except R2RException:
|
|
return False
|
|
|
|
@override
|
|
async def list_collections(self) -> list[str]:
|
|
"""List all available collections."""
|
|
endpoint = self.endpoint
|
|
client = self._create_http_client()
|
|
try:
|
|
response = await client.get(f"{endpoint}/v3/collections")
|
|
response.raise_for_status()
|
|
data: dict[str, object] = response.json()
|
|
|
|
collection_names: list[str] = []
|
|
results = cast(list[dict[str, object]], data.get("results", []))
|
|
for entry in results:
|
|
if name := entry.get("name"):
|
|
collection_names.append(str(name))
|
|
return collection_names
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to list collections: {e}") from e
|
|
finally:
|
|
await client.aclose()
|
|
|
|
# This should never be reached, but satisfies static analyzer
|
|
return []
|
|
|
|
async def list_collections_detailed(self) -> list[dict[str, object]]:
|
|
"""List all available collections with detailed information."""
|
|
try:
|
|
response = await self.client.collections.list()
|
|
collections: list[dict[str, object]] = []
|
|
for entry in _as_sequence(getattr(response, "results", ())):
|
|
entry_map = _as_mapping(entry)
|
|
collections.append(
|
|
{
|
|
"id": str(entry_map.get("id", "")),
|
|
"name": str(entry_map.get("name", "")),
|
|
"description": entry_map.get("description"),
|
|
}
|
|
)
|
|
return collections
|
|
except R2RException as exc:
|
|
raise StorageError(f"Failed to list collections: {exc}") from exc
|
|
|
|
async def get_document_chunks(self, document_id: str) -> list[dict[str, object]]:
|
|
"""Get all chunks for a specific document."""
|
|
try:
|
|
response = await self.client.chunks.list(filters={"document_id": document_id})
|
|
return [
|
|
dict(_as_mapping(chunk)) for chunk in _as_sequence(getattr(response, "results", ()))
|
|
]
|
|
except R2RException as exc:
|
|
raise StorageError(f"Failed to get chunks for document {document_id}: {exc}") from exc
|
|
|
|
async def extract_entities(self, document_id: str) -> dict[str, object]:
|
|
"""Extract entities and relationships from a document."""
|
|
try:
|
|
response = await self.client.documents.extract(id=document_id)
|
|
return dict(_as_mapping(getattr(response, "results", {})))
|
|
except R2RException as exc:
|
|
raise StorageError(
|
|
f"Failed to extract entities from document {document_id}: {exc}"
|
|
) from exc
|
|
|
|
async def get_document_overview(self, document_id: str) -> dict[str, object]:
|
|
"""Get comprehensive document overview and statistics."""
|
|
try:
|
|
doc_response = await self.client.documents.retrieve(document_id)
|
|
chunks_response = await self.client.chunks.list(filters={"document_id": document_id})
|
|
document_payload = dict(_as_mapping(getattr(doc_response, "results", {})))
|
|
chunk_payload = [
|
|
dict(_as_mapping(chunk))
|
|
for chunk in _as_sequence(getattr(chunks_response, "results", ()))
|
|
]
|
|
return {
|
|
"document": document_payload,
|
|
"chunk_count": len(chunk_payload),
|
|
"chunks": chunk_payload,
|
|
}
|
|
except R2RException as exc:
|
|
raise StorageError(f"Failed to get overview for document {document_id}: {exc}") from exc
|
|
|
|
@override
|
|
async def list_documents(
|
|
self,
|
|
limit: int = 100,
|
|
offset: int = 0,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> list[DocumentInfo]:
|
|
"""
|
|
List documents in R2R with pagination.
|
|
|
|
Args:
|
|
limit: Maximum number of documents to return
|
|
offset: Number of documents to skip
|
|
collection_name: Collection name (optional)
|
|
|
|
Returns:
|
|
List of document dictionaries with metadata
|
|
"""
|
|
try:
|
|
documents: list[DocumentInfo] = []
|
|
|
|
if collection_name:
|
|
# Get collection ID first
|
|
collection_id = await self._ensure_collection(collection_name)
|
|
# Use the collections API to list documents in a specific collection
|
|
endpoint = self.endpoint
|
|
client = self._create_http_client()
|
|
try:
|
|
params = {"offset": offset, "limit": limit}
|
|
response = await client.get(
|
|
f"{endpoint}/v3/collections/{collection_id}/documents", params=params
|
|
)
|
|
response.raise_for_status()
|
|
data: dict[str, object] = response.json()
|
|
finally:
|
|
await client.aclose()
|
|
|
|
doc_sequence = _as_sequence(data.get("results", []))
|
|
else:
|
|
# List all documents
|
|
r2r_response = await self.client.documents.list(offset=offset, limit=limit)
|
|
documents_data: list[object] | dict[str, object] = getattr(
|
|
r2r_response, "results", []
|
|
)
|
|
|
|
doc_sequence = _as_sequence(
|
|
documents_data.get("results", [])
|
|
if isinstance(documents_data, dict)
|
|
else documents_data
|
|
)
|
|
|
|
for doc_data in doc_sequence:
|
|
doc_map = _as_mapping(doc_data)
|
|
|
|
# Extract standard document fields
|
|
doc_id = str(doc_map.get("id", ""))
|
|
title = str(doc_map.get("title", "Untitled"))
|
|
metadata = _as_mapping(doc_map.get("metadata", {}))
|
|
|
|
document_info: DocumentInfo = {
|
|
"id": doc_id,
|
|
"title": title,
|
|
"source_url": str(metadata.get("source_url", "")),
|
|
"description": str(metadata.get("description", "")),
|
|
"content_type": str(metadata.get("content_type", "text/plain")),
|
|
"content_preview": str(doc_map.get("content", ""))[:200] + "..."
|
|
if doc_map.get("content")
|
|
else "",
|
|
"word_count": _as_int(metadata.get("word_count", 0)),
|
|
"timestamp": str(doc_map.get("created_at", "")),
|
|
}
|
|
documents.append(document_info)
|
|
|
|
return documents
|
|
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to list documents: {e}") from e
|
|
</file>
|
|
|
|
<file path="ingest_pipeline/storage/base.py">
|
|
"""Base storage interface."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import random
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncGenerator
|
|
from types import TracebackType
|
|
from typing import Final
|
|
|
|
import httpx
|
|
from pydantic import SecretStr
|
|
|
|
from ..core.exceptions import StorageError
|
|
from ..core.models import Document, StorageConfig
|
|
from .types import CollectionSummary, DocumentInfo
|
|
|
|
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
|
|
|
|
|
class TypedHttpClient:
|
|
"""
|
|
A properly typed HTTP client wrapper for HTTPX.
|
|
|
|
Provides consistent exception handling and type annotations
|
|
for storage adapters that use HTTP APIs.
|
|
|
|
Note: Some type checkers (Pylance) may report warnings about HTTPX types
|
|
due to library compatibility issues. The code functions correctly at runtime.
|
|
"""
|
|
|
|
client: httpx.AsyncClient
|
|
_base_url: str
|
|
|
|
def __init__(
|
|
self,
|
|
base_url: str,
|
|
*,
|
|
api_key: SecretStr | None = None,
|
|
timeout: float = 30.0,
|
|
headers: dict[str, str] | None = None,
|
|
max_connections: int = 100,
|
|
max_keepalive_connections: int = 20,
|
|
):
|
|
"""
|
|
Initialize the typed HTTP client.
|
|
|
|
Args:
|
|
base_url: Base URL for all requests
|
|
api_key: Optional API key for authentication
|
|
timeout: Request timeout in seconds
|
|
headers: Additional headers to include with requests
|
|
max_connections: Maximum total connections in pool
|
|
max_keepalive_connections: Maximum keepalive connections
|
|
"""
|
|
self._base_url = base_url
|
|
|
|
# Build headers with optional authentication
|
|
client_headers: dict[str, str] = headers or {}
|
|
if api_key:
|
|
client_headers["Authorization"] = f"Bearer {api_key.get_secret_value()}"
|
|
|
|
# Create typed client configuration with connection pooling
|
|
limits = httpx.Limits(
|
|
max_connections=max_connections, max_keepalive_connections=max_keepalive_connections
|
|
)
|
|
timeout_config = httpx.Timeout(connect=5.0, read=timeout, write=30.0, pool=10.0)
|
|
self.client = httpx.AsyncClient(
|
|
base_url=base_url, headers=client_headers, timeout=timeout_config, limits=limits
|
|
)
|
|
|
|
async def request(
|
|
self,
|
|
method: str,
|
|
path: str,
|
|
*,
|
|
allow_404: bool = False,
|
|
json: dict[str, object] | None = None,
|
|
data: dict[str, object] | None = None,
|
|
files: dict[str, tuple[str, bytes, str]] | None = None,
|
|
params: dict[str, str | bool] | None = None,
|
|
max_retries: int = 3,
|
|
retry_delay: float = 1.0,
|
|
) -> httpx.Response | None:
|
|
"""
|
|
Perform an HTTP request with consistent error handling and retries.
|
|
|
|
Args:
|
|
method: HTTP method (GET, POST, DELETE, etc.)
|
|
path: URL path relative to base_url
|
|
allow_404: If True, return None for 404 responses instead of raising
|
|
json: JSON data to send
|
|
data: Form data to send
|
|
files: Files to upload
|
|
params: Query parameters
|
|
max_retries: Maximum number of retry attempts
|
|
retry_delay: Base delay between retries in seconds
|
|
|
|
Returns:
|
|
HTTP response object, or None if allow_404=True and status is 404
|
|
|
|
Raises:
|
|
StorageError: If request fails after retries
|
|
"""
|
|
last_exception: Exception | None = None
|
|
|
|
for attempt in range(max_retries + 1):
|
|
try:
|
|
response = await self.client.request(
|
|
method, path, json=json, data=data, files=files, params=params
|
|
)
|
|
response.raise_for_status()
|
|
return response
|
|
except httpx.HTTPStatusError as e:
|
|
# Handle 404 as special case if requested
|
|
if allow_404 and e.response.status_code == 404:
|
|
LOGGER.debug("Resource not found (404): %s %s", method, path)
|
|
return None
|
|
|
|
# Don't retry client errors (4xx except for specific cases)
|
|
if 400 <= e.response.status_code < 500 and e.response.status_code not in [429, 408]:
|
|
raise StorageError(
|
|
f"HTTP {e.response.status_code} error from {self._base_url}: {e}"
|
|
) from e
|
|
|
|
last_exception = e
|
|
if attempt < max_retries:
|
|
# Exponential backoff with jitter for retryable errors
|
|
delay = retry_delay * (2**attempt) + random.uniform(0, 1)
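                    # Worked example: with retry_delay=1.0 the waits grow roughly as
                    # 1-2s (attempt 0), 2-3s (attempt 1), 4-5s (attempt 2); the extra
                    # 0-1s comes from the random jitter term.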
|
|
LOGGER.warning(
|
|
"HTTP %d error on attempt %d/%d, retrying in %.2fs: %s",
|
|
e.response.status_code,
|
|
attempt + 1,
|
|
max_retries + 1,
|
|
delay,
|
|
e,
|
|
)
|
|
await asyncio.sleep(delay)
|
|
|
|
except httpx.HTTPError as e:
|
|
last_exception = e
|
|
if attempt < max_retries:
|
|
# Retry transport errors with backoff
|
|
delay = retry_delay * (2**attempt) + random.uniform(0, 1)
|
|
LOGGER.warning(
|
|
"HTTP transport error on attempt %d/%d, retrying in %.2fs: %s",
|
|
attempt + 1,
|
|
max_retries + 1,
|
|
delay,
|
|
e,
|
|
)
|
|
await asyncio.sleep(delay)
|
|
|
|
# All retries exhausted - last_exception should always be set if we reach here
|
|
if last_exception is None:
|
|
raise StorageError(
|
|
f"Request to {self._base_url} failed after {max_retries + 1} attempts with unknown error"
|
|
)
|
|
|
|
if isinstance(last_exception, httpx.HTTPStatusError):
|
|
raise StorageError(
|
|
f"HTTP {last_exception.response.status_code} error from {self._base_url} after {max_retries + 1} attempts: {last_exception}"
|
|
) from last_exception
|
|
else:
|
|
raise StorageError(
|
|
f"HTTP transport error to {self._base_url} after {max_retries + 1} attempts: {last_exception}"
|
|
) from last_exception
|
|
|
|
async def close(self) -> None:
|
|
"""Close the HTTP client and cleanup resources."""
|
|
try:
|
|
await self.client.aclose()
|
|
except Exception as e:
|
|
LOGGER.warning("Error closing HTTP client: %s", e)
|
|
|
|
async def __aenter__(self) -> "TypedHttpClient":
|
|
"""Async context manager entry."""
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> None:
|
|
"""Async context manager exit."""
|
|
await self.close()
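

# Hedged usage sketch (illustrative addition, not part of the original API surface):
# shows how TypedHttpClient is meant to be driven as an async context manager with
# retries. The base URL and path below are placeholders, not real services.
async def _example_typed_http_client() -> None:
    async with TypedHttpClient("https://storage.example.invalid", timeout=10.0) as http:
        # allow_404=True turns a missing resource into None instead of a StorageError.
        response = await http.request("GET", "/health", allow_404=True, max_retries=2)
        if response is not None:
            LOGGER.debug("health payload: %s", response.json())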
|
|
|
|
|
|
class BaseStorage(ABC):
|
|
"""Abstract base class for storage adapters."""
|
|
|
|
config: StorageConfig
|
|
|
|
def __init__(self, config: StorageConfig):
|
|
"""
|
|
Initialize storage adapter.
|
|
|
|
Args:
|
|
config: Storage configuration
|
|
"""
|
|
self.config = config
|
|
|
|
@property
|
|
def display_name(self) -> str:
|
|
"""Human-readable name for UI display."""
|
|
return self.__class__.__name__.replace("Storage", "")
|
|
|
|
@abstractmethod
|
|
async def initialize(self) -> None:
|
|
"""Initialize the storage backend and create collections if needed."""
|
|
pass # pragma: no cover
|
|
|
|
@abstractmethod
|
|
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
|
|
"""
|
|
Store a single document.
|
|
|
|
Args:
|
|
document: Document to store
|
|
|
|
Returns:
|
|
Document ID
|
|
"""
|
|
pass # pragma: no cover
|
|
|
|
@abstractmethod
|
|
async def store_batch(
|
|
self, documents: list[Document], *, collection_name: str | None = None
|
|
) -> list[str]:
|
|
"""
|
|
Store multiple documents in batch.
|
|
|
|
Args:
|
|
documents: List of documents to store
|
|
|
|
Returns:
|
|
List of document IDs
|
|
"""
|
|
pass # pragma: no cover
|
|
|
|
async def retrieve(
|
|
self, document_id: str, *, collection_name: str | None = None
|
|
) -> Document | None:
|
|
"""
|
|
Retrieve a document by ID (if supported by backend).
|
|
|
|
Args:
|
|
document_id: Document ID
|
|
|
|
Returns:
|
|
Document or None if not found
|
|
|
|
Raises:
|
|
NotImplementedError: If backend doesn't support retrieval
|
|
"""
|
|
raise NotImplementedError(f"{self.__class__.__name__} doesn't support document retrieval")
|
|
|
|
async def check_exists(
|
|
self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
|
|
) -> bool:
|
|
"""
|
|
Check if a document exists and is not stale.
|
|
|
|
Args:
|
|
document_id: Document ID to check
|
|
collection_name: Collection to check in
|
|
stale_after_days: Consider document stale after this many days
|
|
|
|
Returns:
|
|
True if document exists and is not stale, False otherwise
|
|
"""
|
|
try:
|
|
document = await self.retrieve(document_id, collection_name=collection_name)
|
|
if document is None:
|
|
return False
|
|
|
|
# Check staleness if timestamp is available
|
|
if "timestamp" in document.metadata:
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
timestamp_obj = document.metadata["timestamp"]
|
|
|
|
# Handle both datetime objects and ISO strings
|
|
if isinstance(timestamp_obj, datetime):
|
|
timestamp = timestamp_obj
|
|
# Ensure timezone awareness
|
|
if timestamp.tzinfo is None:
|
|
timestamp = timestamp.replace(tzinfo=UTC)
|
|
elif isinstance(timestamp_obj, str):
|
|
try:
|
|
timestamp = datetime.fromisoformat(timestamp_obj)
|
|
# Ensure timezone awareness
|
|
if timestamp.tzinfo is None:
|
|
timestamp = timestamp.replace(tzinfo=UTC)
|
|
except ValueError:
|
|
# If parsing fails, assume document is stale
|
|
return False
|
|
else:
|
|
# Unknown timestamp format, assume stale
|
|
return False
|
|
|
|
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
|
|
return timestamp >= cutoff
|
|
|
|
# If no timestamp, assume it exists and is valid
|
|
return True
|
|
except Exception:
|
|
# Backend doesn't support retrieval, assume doesn't exist
|
|
return False
|
|
|
|
def search(
|
|
self,
|
|
query: str,
|
|
limit: int = 10,
|
|
threshold: float = 0.7,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> AsyncGenerator[Document, None]:
|
|
"""
|
|
Search for documents (if supported by backend).
|
|
|
|
Args:
|
|
query: Search query
|
|
limit: Maximum number of results
|
|
threshold: Similarity threshold
|
|
|
|
Yields:
|
|
Matching documents
|
|
|
|
Raises:
|
|
NotImplementedError: If backend doesn't support search
|
|
"""
|
|
raise NotImplementedError(f"{self.__class__.__name__} doesn't support search")
|
|
|
|
@abstractmethod
|
|
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
|
|
"""
|
|
Delete a document.
|
|
|
|
Args:
|
|
document_id: Document ID
|
|
|
|
Returns:
|
|
True if deleted successfully
|
|
"""
|
|
pass # pragma: no cover
|
|
|
|
async def count(self, *, collection_name: str | None = None) -> int:
|
|
"""
|
|
Get total document count (if supported by backend).
|
|
|
|
Returns:
|
|
Number of documents, 0 if not supported
|
|
"""
|
|
return 0
|
|
|
|
async def list_collections(self) -> list[str]:
|
|
"""
|
|
List available collections (if supported by backend).
|
|
|
|
Returns:
|
|
List of collection names, empty list if not supported
|
|
"""
|
|
return []
|
|
|
|
async def describe_collections(self) -> list[CollectionSummary]:
|
|
"""
|
|
Describe available collections with metadata (if supported by backend).
|
|
|
|
Returns:
|
|
List of collection metadata, empty list if not supported
|
|
"""
|
|
return []
|
|
|
|
async def delete_collection(self, collection_name: str) -> bool:
|
|
"""
|
|
Delete a collection (if supported by backend).
|
|
|
|
Args:
|
|
collection_name: Name of collection to delete
|
|
|
|
Returns:
|
|
True if deleted successfully, False if not supported
|
|
"""
|
|
return False
|
|
|
|
async def delete_documents(
|
|
self, document_ids: list[str], *, collection_name: str | None = None
|
|
) -> dict[str, bool]:
|
|
"""
|
|
Delete documents by IDs (if supported by backend).
|
|
|
|
Args:
|
|
document_ids: List of document IDs to delete
|
|
collection_name: Collection to delete from
|
|
|
|
Returns:
|
|
Dict mapping document IDs to success status, empty if not supported
|
|
"""
|
|
return {}
|
|
|
|
async def list_documents(
|
|
self,
|
|
limit: int = 100,
|
|
offset: int = 0,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> list[DocumentInfo]:
|
|
"""
|
|
List documents in the storage backend (if supported).
|
|
|
|
Args:
|
|
limit: Maximum number of documents to return
|
|
offset: Number of documents to skip
|
|
collection_name: Collection to list documents from
|
|
|
|
Returns:
|
|
List of document information with metadata
|
|
|
|
Raises:
|
|
NotImplementedError: If backend doesn't support document listing
|
|
"""
|
|
raise NotImplementedError(f"{self.__class__.__name__} doesn't support document listing")
|
|
|
|
async def close(self) -> None:
|
|
"""
|
|
Close storage connections and cleanup resources.
|
|
|
|
Default implementation does nothing.
|
|
"""
|
|
# Default implementation - storage backends can override to cleanup connections
|
|
return None
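

# Hedged sketch (illustrative addition): the smallest useful BaseStorage subclass, an
# in-memory adapter that only fills in the abstract hooks. Everything here is
# hypothetical and exists only to show which methods a real backend must provide.
class _InMemoryStorage(BaseStorage):
    """Toy adapter that keeps documents in a dict; not used by the pipeline."""

    def __init__(self, config: StorageConfig):
        super().__init__(config)
        self._docs: dict[str, Document] = {}

    async def initialize(self) -> None:
        self._docs = {}

    async def store(self, document: Document, *, collection_name: str | None = None) -> str:
        self._docs[str(document.id)] = document
        return str(document.id)

    async def store_batch(
        self, documents: list[Document], *, collection_name: str | None = None
    ) -> list[str]:
        return [await self.store(doc, collection_name=collection_name) for doc in documents]

    async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
        return self._docs.pop(document_id, None) is not None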
|
|
</file>
|
|
|
|
<file path="ingest_pipeline/storage/openwebui.py">
|
|
"""Open WebUI storage adapter."""
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import logging
|
|
import time
|
|
from typing import Final, NamedTuple, TypedDict
|
|
|
|
from typing_extensions import override
|
|
|
|
from ..core.exceptions import StorageError
|
|
from ..core.models import Document, StorageConfig
|
|
from .base import BaseStorage, TypedHttpClient
|
|
from .types import CollectionSummary, DocumentInfo
|
|
|
|
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
|
|
|
|
|
class OpenWebUIFileResponse(TypedDict, total=False):
|
|
"""OpenWebUI API file response structure."""
|
|
|
|
id: str
|
|
filename: str
|
|
name: str
|
|
content_type: str
|
|
size: int
|
|
created_at: str
|
|
meta: dict[str, str | int]
|
|
|
|
|
|
class OpenWebUIKnowledgeBase(TypedDict, total=False):
|
|
"""OpenWebUI knowledge base response structure."""
|
|
|
|
id: str
|
|
name: str
|
|
description: str
|
|
files: list[OpenWebUIFileResponse]
|
|
data: dict[str, str]
|
|
created_at: str
|
|
updated_at: str
|
|
|
|
|
|
class CacheEntry(NamedTuple):
|
|
"""Cache entry with value and expiration time."""
|
|
|
|
value: str
|
|
expires_at: float
|
|
|
|
|
|
class OpenWebUIStorage(BaseStorage):
|
|
"""Storage adapter for Open WebUI knowledge endpoints."""
|
|
|
|
http_client: TypedHttpClient
|
|
_knowledge_cache: dict[str, CacheEntry]
|
|
_cache_ttl: float
|
|
|
|
def __init__(self, config: StorageConfig):
|
|
"""
|
|
Initialize Open WebUI storage.
|
|
|
|
Args:
|
|
config: Storage configuration
|
|
"""
|
|
super().__init__(config)
|
|
|
|
self.http_client = TypedHttpClient(
|
|
base_url=str(config.endpoint),
|
|
api_key=config.api_key,
|
|
timeout=30.0,
|
|
)
|
|
self._knowledge_cache = {}
|
|
self._cache_ttl = 300.0 # 5 minutes TTL
|
|
|
|
@override
|
|
async def initialize(self) -> None:
|
|
"""Initialize Open WebUI connection."""
|
|
try:
|
|
if self.config.collection_name:
|
|
await self._get_knowledge_id(
|
|
self.config.collection_name,
|
|
create=True,
|
|
)
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to initialize Open WebUI: {e}") from e
|
|
|
|
async def _create_collection(self, name: str) -> str:
|
|
"""Create knowledge base in Open WebUI."""
|
|
response = await self.http_client.request(
|
|
"POST",
|
|
"/api/v1/knowledge/create",
|
|
json={
|
|
"name": name,
|
|
"description": "Documents ingested from various sources",
|
|
"data": {},
|
|
"access_control": None,
|
|
},
|
|
)
|
|
if response is None:
|
|
raise StorageError("Unexpected None response from knowledge base creation")
|
|
result = response.json()
|
|
knowledge_id = result.get("id")
|
|
|
|
if not knowledge_id or not isinstance(knowledge_id, str):
|
|
raise StorageError("Knowledge base creation failed: no ID returned")
|
|
|
|
return str(knowledge_id)
|
|
|
|
async def _fetch_knowledge_bases(self) -> list[OpenWebUIKnowledgeBase]:
|
|
"""Return the list of knowledge bases from the API."""
|
|
response = await self.http_client.request("GET", "/api/v1/knowledge/list")
|
|
if response is None:
|
|
return []
|
|
data = response.json()
|
|
if not isinstance(data, list):
|
|
return []
|
|
normalized: list[OpenWebUIKnowledgeBase] = []
|
|
for item in data:
|
|
if (
|
|
isinstance(item, dict)
|
|
and "id" in item
|
|
and "name" in item
|
|
and isinstance(item["id"], str)
|
|
and isinstance(item["name"], str)
|
|
):
|
|
# Create a new dict with known structure
|
|
kb_item: OpenWebUIKnowledgeBase = {
|
|
"id": item["id"],
|
|
"name": item["name"],
|
|
"description": item.get("description", ""),
|
|
"created_at": item.get("created_at", ""),
|
|
"updated_at": item.get("updated_at", ""),
|
|
}
|
|
if "files" in item and isinstance(item["files"], list):
|
|
kb_item["files"] = item["files"]
|
|
if "data" in item and isinstance(item["data"], dict):
|
|
kb_item["data"] = item["data"]
|
|
normalized.append(kb_item)
|
|
return normalized
|
|
|
|
async def _get_knowledge_id(
|
|
self,
|
|
name: str | None,
|
|
*,
|
|
create: bool,
|
|
) -> str | None:
|
|
"""Retrieve (and optionally create) a knowledge base identifier."""
|
|
target_raw = name or self.config.collection_name
|
|
target = str(target_raw) if target_raw else ""
|
|
if not target:
|
|
raise StorageError("Knowledge base name is required")
|
|
|
|
# Check cache with TTL
|
|
if cached_entry := self._knowledge_cache.get(target):
|
|
if time.time() < cached_entry.expires_at:
|
|
return cached_entry.value
|
|
else:
|
|
# Entry expired, remove it
|
|
del self._knowledge_cache[target]
|
|
|
|
knowledge_bases = await self._fetch_knowledge_bases()
|
|
for kb in knowledge_bases:
|
|
if kb.get("name") == target:
|
|
kb_id = kb.get("id")
|
|
if isinstance(kb_id, str):
|
|
expires_at = time.time() + self._cache_ttl
|
|
self._knowledge_cache[target] = CacheEntry(kb_id, expires_at)
|
|
return kb_id
|
|
|
|
if not create:
|
|
return None
|
|
|
|
knowledge_id = await self._create_collection(target)
|
|
expires_at = time.time() + self._cache_ttl
|
|
self._knowledge_cache[target] = CacheEntry(knowledge_id, expires_at)
|
|
return knowledge_id
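
    # Example of the cache behaviour above: with _cache_ttl = 300.0, an ID resolved at
    # t=0 is answered from _knowledge_cache until t=300s; after that the expired entry
    # is evicted and the next lookup hits /api/v1/knowledge/list again.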
|
|
|
|
@override
|
|
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
|
|
"""
|
|
Store a document in Open WebUI as a file.
|
|
|
|
Args:
|
|
document: Document to store
|
|
|
|
Returns:
|
|
File ID
|
|
"""
|
|
try:
|
|
knowledge_id = await self._get_knowledge_id(
|
|
collection_name,
|
|
create=True,
|
|
)
|
|
if not knowledge_id:
|
|
raise StorageError("Knowledge base not initialized")
|
|
|
|
# Step 1: Upload document as file
|
|
# Use document title from metadata if available, otherwise fall back to ID
|
|
filename = document.metadata.get("title") or f"doc_{document.id}"
|
|
# Ensure filename has proper extension
|
|
if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
|
|
filename = f"{filename}.txt"
|
|
files = {"file": (filename, document.content.encode(), "text/plain")}
|
|
response = await self.http_client.request(
|
|
"POST",
|
|
"/api/v1/files/",
|
|
files=files,
|
|
params={"process": True, "process_in_background": False},
|
|
)
|
|
if response is None:
|
|
raise StorageError("Unexpected None response from file upload")
|
|
|
|
file_data = response.json()
|
|
file_id = file_data.get("id")
|
|
|
|
if not file_id or not isinstance(file_id, str):
|
|
raise StorageError("File upload failed: no file ID returned")
|
|
|
|
# Step 2: Add file to knowledge base
|
|
response = await self.http_client.request(
|
|
"POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
|
|
)
|
|
|
|
return str(file_id)
|
|
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to store document: {e}") from e
|
|
|
|
@override
|
|
async def store_batch(
|
|
self, documents: list[Document], *, collection_name: str | None = None
|
|
) -> list[str]:
|
|
"""
|
|
Store multiple documents as files in batch.
|
|
|
|
Args:
|
|
documents: List of documents
|
|
|
|
Returns:
|
|
List of file IDs
|
|
"""
|
|
try:
|
|
knowledge_id = await self._get_knowledge_id(
|
|
collection_name,
|
|
create=True,
|
|
)
|
|
if not knowledge_id:
|
|
raise StorageError("Knowledge base not initialized")
|
|
|
|
async def upload_and_attach(doc: Document) -> str:
|
|
# Use document title from metadata if available, otherwise fall back to ID
|
|
filename = doc.metadata.get("title") or f"doc_{doc.id}"
|
|
# Ensure filename has proper extension
|
|
if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
|
|
filename = f"{filename}.txt"
|
|
files = {"file": (filename, doc.content.encode(), "text/plain")}
|
|
upload_response = await self.http_client.request(
|
|
"POST",
|
|
"/api/v1/files/",
|
|
files=files,
|
|
params={"process": True, "process_in_background": False},
|
|
)
|
|
if upload_response is None:
|
|
raise StorageError(
|
|
f"Unexpected None response from file upload for document {doc.id}"
|
|
)
|
|
|
|
file_data = upload_response.json()
|
|
file_id = file_data.get("id")
|
|
|
|
if not file_id or not isinstance(file_id, str):
|
|
raise StorageError(
|
|
f"File upload failed for document {doc.id}: no file ID returned"
|
|
)
|
|
|
|
await self.http_client.request(
|
|
"POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
|
|
)
|
|
|
|
return str(file_id)
|
|
|
|
tasks = [upload_and_attach(doc) for doc in documents]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
file_ids: list[str] = []
|
|
failures: list[str] = []
|
|
|
|
for index, result in enumerate(results):
|
|
doc = documents[index]
|
|
if isinstance(result, Exception):
|
|
failures.append(f"{doc.id}: {result}")
|
|
else:
|
|
if isinstance(result, str):
|
|
file_ids.append(result)
|
|
|
|
if failures:
|
|
LOGGER.warning(
|
|
"OpenWebUI partial batch failure for knowledge base %s: %s",
|
|
self.config.collection_name,
|
|
", ".join(failures),
|
|
)
|
|
|
|
return file_ids
|
|
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to store batch: {e}") from e
|
|
|
|
@override
|
|
async def retrieve(
|
|
self, document_id: str, *, collection_name: str | None = None
|
|
) -> Document | None:
|
|
"""
|
|
OpenWebUI doesn't support document retrieval by ID.
|
|
|
|
Args:
|
|
document_id: File ID (not supported)
|
|
collection_name: Collection name (not used)
|
|
|
|
        Raises:
            NotImplementedError: Document retrieval by ID is not supported.
|
|
"""
|
|
_ = document_id, collection_name # Mark as used
|
|
# OpenWebUI uses file-based storage without direct document retrieval
|
|
raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID")
|
|
|
|
@override
|
|
async def check_exists(
|
|
self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
|
|
) -> bool:
|
|
"""
|
|
Check if a document exists in OpenWebUI knowledge base by searching files.
|
|
|
|
Args:
|
|
document_id: Document ID to check (usually based on source URL)
|
|
collection_name: Knowledge base name
|
|
stale_after_days: Consider document stale after this many days
|
|
|
|
Returns:
|
|
True if document exists and is not stale, False otherwise
|
|
"""
|
|
try:
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
# Get knowledge base
|
|
knowledge_id = await self._get_knowledge_id(collection_name, create=False)
|
|
if not knowledge_id:
|
|
return False
|
|
|
|
# Get detailed knowledge base info to access files
|
|
response = await self.http_client.request("GET", f"/api/v1/knowledge/{knowledge_id}")
|
|
if response is None:
|
|
return False
|
|
|
|
kb_data = response.json()
|
|
files = kb_data.get("files", [])
|
|
|
|
# Look for file with matching document ID or source URL in metadata
|
|
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
|
|
|
|
def _parse_openwebui_timestamp(timestamp_str: str) -> datetime | None:
|
|
"""Parse OpenWebUI timestamp with proper timezone handling."""
|
|
try:
|
|
# Handle both 'Z' suffix and explicit timezone
|
|
normalized = timestamp_str.replace("Z", "+00:00")
|
|
parsed = datetime.fromisoformat(normalized)
|
|
# Ensure timezone awareness
|
|
if parsed.tzinfo is None:
|
|
parsed = parsed.replace(tzinfo=UTC)
|
|
return parsed
|
|
except (ValueError, AttributeError):
|
|
return None
|
|
|
|
def _check_file_freshness(file_info: dict[str, object]) -> bool:
|
|
"""Check if file is fresh enough based on creation date."""
|
|
created_at = file_info.get("created_at")
|
|
if not isinstance(created_at, str):
|
|
# No date info available, consider stale to be safe
|
|
return False
|
|
|
|
file_date = _parse_openwebui_timestamp(created_at)
|
|
return file_date is not None and file_date >= cutoff
|
|
|
|
for file_info in files:
|
|
if not isinstance(file_info, dict):
|
|
continue
|
|
|
|
file_id = file_info.get("id")
|
|
if str(file_id) == document_id:
|
|
return _check_file_freshness(file_info)
|
|
|
|
# Also check meta.source_url if available for URL-based document IDs
|
|
meta = file_info.get("meta", {})
|
|
if isinstance(meta, dict):
|
|
source_url = meta.get("source_url")
|
|
if source_url and document_id in str(source_url):
|
|
return _check_file_freshness(file_info)
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
LOGGER.debug("Error checking document existence in OpenWebUI: %s", e)
|
|
return False
|
|
|
|
@override
|
|
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
|
|
"""
|
|
Remove a file from Open WebUI knowledge base.
|
|
|
|
Args:
|
|
document_id: File ID to remove
|
|
|
|
Returns:
|
|
True if removed successfully
|
|
"""
|
|
try:
|
|
knowledge_id = await self._get_knowledge_id(
|
|
collection_name,
|
|
create=False,
|
|
)
|
|
if not knowledge_id:
|
|
return False
|
|
|
|
# Remove file from knowledge base
|
|
await self.http_client.request(
|
|
"POST",
|
|
f"/api/v1/knowledge/{knowledge_id}/file/remove",
|
|
json={"file_id": document_id},
|
|
)
|
|
|
|
await self.http_client.request("DELETE", f"/api/v1/files/{document_id}", allow_404=True)
|
|
return True
|
|
except Exception as exc:
|
|
LOGGER.error("Error deleting file %s from OpenWebUI", document_id, exc_info=exc)
|
|
return False
|
|
|
|
async def list_collections(self) -> list[str]:
|
|
"""
|
|
List all available knowledge bases.
|
|
|
|
Returns:
|
|
List of knowledge base names
|
|
"""
|
|
try:
|
|
knowledge_bases = await self._fetch_knowledge_bases()
|
|
|
|
# Extract names from knowledge bases
|
|
return [
|
|
str(kb.get("name", f"knowledge_{kb.get('id', 'unknown')}") or "")
|
|
for kb in knowledge_bases
|
|
]
|
|
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to list knowledge bases: {e}") from e
|
|
|
|
async def delete_collection(self, collection_name: str) -> bool:
|
|
"""
|
|
Delete a knowledge base by name.
|
|
|
|
Args:
|
|
collection_name: Name of the knowledge base to delete
|
|
|
|
Returns:
|
|
True if deleted successfully, False otherwise
|
|
"""
|
|
try:
|
|
knowledge_id = await self._get_knowledge_id(collection_name, create=False)
|
|
if not knowledge_id:
|
|
# Collection doesn't exist, consider it already deleted
|
|
return True
|
|
|
|
# Delete the knowledge base using the OpenWebUI API
|
|
await self.http_client.request(
|
|
"DELETE", f"/api/v1/knowledge/{knowledge_id}/delete", allow_404=True
|
|
)
|
|
|
|
# Remove from cache if it exists
|
|
if collection_name in self._knowledge_cache:
|
|
del self._knowledge_cache[collection_name]
|
|
|
|
LOGGER.info("Successfully deleted knowledge base: %s", collection_name)
|
|
return True
|
|
|
|
except Exception as e:
|
|
if hasattr(e, "response"):
|
|
response_attr = getattr(e, "response", None)
|
|
if response_attr is not None and hasattr(response_attr, "status_code"):
|
|
with contextlib.suppress(Exception):
|
|
status_code = response_attr.status_code
|
|
if status_code == 404:
|
|
LOGGER.info(
|
|
"Knowledge base %s was already deleted or not found",
|
|
collection_name,
|
|
)
|
|
return True
|
|
LOGGER.error(
|
|
"Error deleting knowledge base %s from OpenWebUI",
|
|
collection_name,
|
|
exc_info=e,
|
|
)
|
|
return False
|
|
|
|
async def _get_knowledge_base_count(self, kb: OpenWebUIKnowledgeBase) -> int:
|
|
"""Get the file count for a knowledge base."""
|
|
kb_id = kb.get("id")
|
|
name = kb.get("name", "Unknown")
|
|
|
|
if not kb_id:
|
|
return self._count_files_from_basic_info(kb)
|
|
|
|
return await self._count_files_from_detailed_info(str(kb_id), str(name), kb)
|
|
|
|
def _count_files_from_basic_info(self, kb: OpenWebUIKnowledgeBase) -> int:
|
|
"""Count files from basic knowledge base info."""
|
|
files = kb.get("files", [])
|
|
return len(files) if isinstance(files, list) and files is not None else 0
|
|
|
|
async def _count_files_from_detailed_info(
|
|
self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase
|
|
) -> int:
|
|
"""Count files by fetching detailed knowledge base info."""
|
|
try:
|
|
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
|
|
detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
|
|
if detail_response is None:
|
|
LOGGER.warning(f"Knowledge base '{name}' (ID: {kb_id}) not found")
|
|
return self._count_files_from_basic_info(kb)
|
|
detailed_kb = detail_response.json()
|
|
|
|
files = detailed_kb.get("files", [])
|
|
count = len(files) if isinstance(files, list) and files is not None else 0
|
|
|
|
LOGGER.info(f"Knowledge base '{name}' (ID: {kb_id}): found {count} files")
|
|
return count
|
|
|
|
except Exception as e:
|
|
LOGGER.warning(f"Failed to get detailed info for KB '{name}' (ID: {kb_id}): {e}")
|
|
return self._count_files_from_basic_info(kb)
|
|
|
|
async def describe_collections(self) -> list[CollectionSummary]:
|
|
"""Return metadata about each knowledge base."""
|
|
try:
|
|
knowledge_bases = await self._fetch_knowledge_bases()
|
|
collections: list[CollectionSummary] = []
|
|
|
|
for kb in knowledge_bases:
|
|
count = await self._get_knowledge_base_count(kb)
|
|
name = kb.get("name", "Unknown")
|
|
size_mb = count * 0.5 # rough heuristic
|
|
|
|
summary: CollectionSummary = {
|
|
"name": str(name),
|
|
"count": count,
|
|
"size_mb": float(size_mb),
|
|
}
|
|
collections.append(summary)
|
|
|
|
return collections
|
|
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to describe knowledge bases: {e}") from e
|
|
|
|
async def count(self, *, collection_name: str | None = None) -> int:
|
|
"""
|
|
Get document count for a specific collection (knowledge base).
|
|
|
|
Args:
|
|
collection_name: Name of the knowledge base to count documents for
|
|
|
|
Returns:
|
|
Number of documents in the collection, 0 if collection not found
|
|
"""
|
|
if not collection_name:
|
|
# If no collection name provided, return total across all collections
|
|
try:
|
|
collections = await self.describe_collections()
|
|
return sum(
|
|
int(collection["count"]) if isinstance(collection["count"], (int, str)) else 0
|
|
for collection in collections
|
|
)
|
|
except Exception:
|
|
return 0
|
|
|
|
try:
|
|
# Get knowledge base by name and return its file count
|
|
kb = await self.get_knowledge_by_name(collection_name)
|
|
if not kb:
|
|
return 0
|
|
|
|
kb_id = kb.get("id")
|
|
if not kb_id:
|
|
return 0
|
|
|
|
# Get detailed knowledge base information to get accurate file count
|
|
detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
|
|
if detail_response is None:
|
|
LOGGER.warning(f"Knowledge base '{collection_name}' (ID: {kb_id}) not found")
|
|
return self._count_files_from_basic_info(kb)
|
|
detailed_kb = detail_response.json()
|
|
|
|
files = detailed_kb.get("files", [])
|
|
count = len(files) if isinstance(files, list) else 0
|
|
|
|
LOGGER.debug(f"Count for collection '{collection_name}': {count} files")
|
|
return count
|
|
|
|
except Exception as e:
|
|
LOGGER.warning(f"Failed to get count for collection '{collection_name}': {e}")
|
|
return 0
|
|
|
|
async def get_knowledge_by_name(self, name: str) -> OpenWebUIKnowledgeBase | None:
|
|
"""
|
|
Get knowledge base details by name.
|
|
|
|
Args:
|
|
name: Knowledge base name
|
|
|
|
Returns:
|
|
Knowledge base details or None if not found
|
|
"""
|
|
try:
|
|
response = await self.http_client.request("GET", "/api/v1/knowledge/list")
|
|
if response is None:
|
|
return None
|
|
knowledge_bases = response.json()
|
|
|
|
# Find and properly type the knowledge base
|
|
for kb in knowledge_bases:
|
|
if (
|
|
isinstance(kb, dict)
|
|
and kb.get("name") == name
|
|
and "id" in kb
|
|
and isinstance(kb["id"], str)
|
|
):
|
|
# Create properly typed response
|
|
result: OpenWebUIKnowledgeBase = {
|
|
"id": kb["id"],
|
|
"name": str(kb["name"]),
|
|
"description": kb.get("description", ""),
|
|
"created_at": kb.get("created_at", ""),
|
|
"updated_at": kb.get("updated_at", ""),
|
|
}
|
|
if "files" in kb and isinstance(kb["files"], list):
|
|
result["files"] = kb["files"]
|
|
if "data" in kb and isinstance(kb["data"], dict):
|
|
result["data"] = kb["data"]
|
|
return result
|
|
return None
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to get knowledge base by name: {e}") from e
|
|
|
|
async def __aenter__(self) -> "OpenWebUIStorage":
|
|
"""Async context manager entry."""
|
|
await self.initialize()
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: object | None,
|
|
) -> None:
|
|
"""Async context manager exit."""
|
|
_ = exc_type, exc_val, exc_tb # Mark as used
|
|
await self.close()
|
|
|
|
async def list_documents(
|
|
self,
|
|
limit: int = 100,
|
|
offset: int = 0,
|
|
*,
|
|
collection_name: str | None = None,
|
|
) -> list[DocumentInfo]:
|
|
"""
|
|
List documents (files) in a knowledge base.
|
|
|
|
NOTE: This is a basic implementation that attempts to extract file information
|
|
from OpenWebUI knowledge bases. The actual file listing capabilities depend
|
|
on the OpenWebUI API version and may not include detailed file metadata.
|
|
|
|
Args:
|
|
limit: Maximum number of documents to return
|
|
offset: Number of documents to skip
|
|
collection_name: Knowledge base name
|
|
|
|
Returns:
|
|
List of document dictionaries with available metadata
|
|
"""
|
|
try:
|
|
# Use the knowledge base name or fall back to default
|
|
kb_name = collection_name or self.config.collection_name or "default"
|
|
|
|
# Try to get knowledge base details
|
|
knowledge_base = await self.get_knowledge_by_name(kb_name)
|
|
if not knowledge_base:
|
|
# If specific KB not found, return empty list with a note
|
|
return []
|
|
|
|
# Extract files if available (API structure may vary)
|
|
files = knowledge_base.get("files", [])
|
|
|
|
# Handle different possible API response structures
|
|
if not isinstance(files, list):
|
|
# Some API versions might structure this differently
|
|
# Try to handle gracefully
|
|
return [
|
|
{
|
|
"id": "unknown",
|
|
"title": f"Knowledge Base: {kb_name}",
|
|
"source_url": "",
|
|
"description": "OpenWebUI knowledge base (file details not available)",
|
|
"content_type": "text/plain",
|
|
"content_preview": "Document listing not fully supported for OpenWebUI",
|
|
"word_count": 0,
|
|
"timestamp": "",
|
|
}
|
|
]
|
|
|
|
# Apply pagination
|
|
paginated_files = files[offset : offset + limit]
|
|
|
|
# Convert to document format with safe field access
|
|
documents: list[DocumentInfo] = []
|
|
for i, file_info in enumerate(paginated_files):
|
|
# Safely extract fields with fallbacks
|
|
doc_id = str(file_info.get("id", f"file_{i}"))
|
|
|
|
# Try multiple ways to get filename from OpenWebUI API response
|
|
filename = None
|
|
# Check direct filename field
|
|
if "filename" in file_info:
|
|
filename = file_info["filename"]
|
|
# Check name field
|
|
elif "name" in file_info:
|
|
filename = file_info["name"]
|
|
# Check meta.name (from FileModelResponse schema)
|
|
elif isinstance(file_info.get("meta"), dict):
|
|
meta = file_info.get("meta")
|
|
if isinstance(meta, dict):
|
|
filename_value = meta.get("name")
|
|
if isinstance(filename_value, str):
|
|
filename = filename_value
|
|
|
|
# Final fallback
|
|
if not filename:
|
|
filename = f"file_{i}"
|
|
|
|
filename = str(filename)
|
|
|
|
# Extract size from meta if available
|
|
size = 0
|
|
meta = file_info.get("meta")
|
|
if isinstance(meta, dict):
|
|
size_value = meta.get("size", 0)
|
|
size = int(size_value) if isinstance(size_value, (int, float)) else 0
|
|
else:
|
|
size_value = file_info.get("size", 0)
|
|
size = int(size_value) if isinstance(size_value, (int, float)) else 0
|
|
|
|
# Estimate word count from file size (very rough approximation)
|
|
word_count = max(1, int(size / 6)) if isinstance(size, (int, float)) else 0
|
|
|
|
doc_info: DocumentInfo = {
|
|
"id": doc_id,
|
|
"title": filename,
|
|
"source_url": "", # OpenWebUI files don't typically have source URLs
|
|
"description": f"File: {filename}",
|
|
"content_type": str(file_info.get("content_type", "text/plain")),
|
|
"content_preview": f"File uploaded to OpenWebUI: {filename}",
|
|
"word_count": word_count,
|
|
"timestamp": str(file_info.get("created_at") or file_info.get("timestamp", "")),
|
|
}
|
|
documents.append(doc_info)
|
|
|
|
return documents
|
|
|
|
except Exception as e:
|
|
# Since OpenWebUI file listing API structure is not guaranteed,
|
|
# we gracefully fall back rather than raise an error
|
|
            LOGGER.warning("OpenWebUI document listing failed: %s", e)
|
|
|
|
# Return a placeholder entry indicating limited support
|
|
return [
|
|
{
|
|
"id": "api_error",
|
|
"title": f"Knowledge Base: {collection_name or 'default'}",
|
|
"source_url": "",
|
|
"description": "Document listing encountered an error - API compatibility issue",
|
|
"content_type": "text/plain",
|
|
"content_preview": f"Error: {str(e)[:100]}...",
|
|
"word_count": 0,
|
|
"timestamp": "",
|
|
}
|
|
]
|
|
|
|
async def close(self) -> None:
|
|
"""Close client connection."""
|
|
if hasattr(self, "http_client"):
|
|
await self.http_client.close()
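

# Hedged usage sketch (illustrative addition, not part of the original module): enumerate
# knowledge bases and their rough file counts. The config is taken as a parameter because
# the concrete StorageConfig fields required for construction are not shown here.
async def _example_openwebui_overview(config: StorageConfig) -> None:
    async with OpenWebUIStorage(config) as storage:
        for summary in await storage.describe_collections():
            LOGGER.info(
                "knowledge base %s: %s files (~%.1f MB)",
                summary["name"],
                summary["count"],
                summary["size_mb"],
            )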
|
|
</file>
|
|
|
|
<file path="ingest_pipeline/storage/weaviate.py">
|
|
"""Weaviate storage adapter."""
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
|
|
from datetime import UTC, datetime
|
|
from typing import Literal, Self, TypeAlias, TypeVar, cast, overload
|
|
from uuid import UUID
|
|
|
|
import weaviate
|
|
from typing_extensions import override
|
|
from weaviate.classes.config import Configure, DataType, Property
|
|
from weaviate.classes.data import DataObject
|
|
from weaviate.classes.query import Filter
|
|
from weaviate.collections import Collection
|
|
from weaviate.exceptions import (
|
|
WeaviateBatchError,
|
|
WeaviateConnectionError,
|
|
WeaviateQueryError,
|
|
)
|
|
|
|
from ..core.exceptions import StorageError
|
|
from ..core.models import Document, DocumentMetadata, IngestionSource, StorageConfig
|
|
from ..utils.vectorizer import Vectorizer
|
|
from .base import BaseStorage
|
|
from .types import CollectionSummary, DocumentInfo
|
|
|
|
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
|
|
T = TypeVar("T")
|
|
|
|
|
|
class WeaviateStorage(BaseStorage):
|
|
"""Storage adapter for Weaviate."""
|
|
|
|
client: weaviate.WeaviateClient | None
|
|
vectorizer: Vectorizer
|
|
_default_collection: str
|
|
|
|
def __init__(self, config: StorageConfig):
|
|
"""
|
|
Initialize Weaviate storage.
|
|
|
|
Args:
|
|
config: Storage configuration
|
|
"""
|
|
super().__init__(config)
|
|
self.client = None
|
|
self.vectorizer = Vectorizer(config)
|
|
self._default_collection = self._normalize_collection_name(config.collection_name)
|
|
|
|
async def _run_sync(self, func: Callable[..., T], *args: object, **kwargs: object) -> T:
|
|
"""
|
|
Run synchronous Weaviate operations in thread pool to avoid blocking event loop.
|
|
|
|
Args:
|
|
func: Synchronous function to run
|
|
*args: Positional arguments for the function
|
|
**kwargs: Keyword arguments for the function
|
|
|
|
Returns:
|
|
Result of the function call
|
|
|
|
Raises:
|
|
StorageError: If the operation fails
|
|
"""
|
|
try:
|
|
return await asyncio.to_thread(func, *args, **kwargs)
|
|
except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
|
|
raise StorageError(f"Weaviate operation failed: {e}") from e
|
|
except Exception as e:
|
|
raise StorageError(f"Unexpected error in Weaviate operation: {e}") from e
|
|
|
|
@override
|
|
async def initialize(self) -> None:
|
|
"""Initialize Weaviate client and create collection if needed."""
|
|
try:
|
|
# Let Weaviate client handle URL parsing
|
|
self.client = weaviate.WeaviateClient(
|
|
connection_params=weaviate.connect.ConnectionParams.from_url(
|
|
url=str(self.config.endpoint),
|
|
grpc_port=self.config.grpc_port or 50051,
|
|
),
|
|
additional_config=weaviate.classes.init.AdditionalConfig(
|
|
timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120),
|
|
),
|
|
)
|
|
|
|
# Connect to the client
|
|
await self._run_sync(self.client.connect)
|
|
|
|
# Ensure the default collection exists
|
|
await self._ensure_collection(self._default_collection)
|
|
|
|
except WeaviateConnectionError as e:
|
|
raise StorageError(f"Failed to connect to Weaviate: {e}") from e
|
|
except Exception as e:
|
|
raise StorageError(f"Failed to initialize Weaviate: {e}") from e
|
|
|
|
async def _create_collection(self, collection_name: str) -> None:
|
|
"""Create Weaviate collection with schema."""
|
|
if not self.client:
|
|
raise StorageError("Weaviate client not initialized")
|
|
try:
|
|
await self._run_sync(
|
|
self.client.collections.create,
|
|
name=collection_name,
|
|
properties=[
|
|
Property(
|
|
name="content", data_type=DataType.TEXT, description="Document content"
|
|
),
|
|
Property(name="source_url", data_type=DataType.TEXT, description="Source URL"),
|
|
Property(name="title", data_type=DataType.TEXT, description="Document title"),
|
|
Property(
|
|
name="description",
|
|
data_type=DataType.TEXT,
|
|
description="Document description",
|
|
),
|
|
Property(
|
|
name="timestamp", data_type=DataType.DATE, description="Ingestion timestamp"
|
|
),
|
|
Property(
|
|
name="content_type", data_type=DataType.TEXT, description="Content type"
|
|
),
|
|
Property(name="word_count", data_type=DataType.INT, description="Word count"),
|
|
Property(
|
|
name="char_count", data_type=DataType.INT, description="Character count"
|
|
),
|
|
Property(
|
|
name="source", data_type=DataType.TEXT, description="Ingestion source"
|
|
),
|
|
],
|
|
vectorizer_config=Configure.Vectorizer.none(),
|
|
)
|
|
except (WeaviateConnectionError, WeaviateBatchError) as e:
|
|
raise StorageError(f"Failed to create collection: {e}") from e
|
|
|
|
@staticmethod
|
|
def _extract_vector(vector_raw: VectorContainer) -> list[float] | None:
|
|
"""Normalize vector payloads returned by Weaviate into a float list."""
|
|
if isinstance(vector_raw, Mapping):
|
|
default_vector = vector_raw.get("default")
|
|
return WeaviateStorage._extract_vector(cast(VectorContainer, default_vector))
|
|
|
|
if not isinstance(vector_raw, Sequence) or isinstance(vector_raw, (str, bytes, bytearray)):
|
|
return None
|
|
|
|
items = list(vector_raw)
|
|
if not items:
|
|
return None
|
|
|
|
first_item = items[0]
|
|
if isinstance(first_item, (int, float)):
|
|
numeric_items = cast(list[int | float], items)
|
|
try:
|
|
return [float(value) for value in numeric_items]
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
if isinstance(first_item, Sequence) and not isinstance(first_item, (str, bytes, bytearray)):
|
|
inner_items = list(first_item)
|
|
if all(isinstance(item, (int, float)) for item in inner_items):
|
|
try:
|
|
numeric_inner = cast(list[int | float], inner_items)
|
|
return [float(item) for item in numeric_inner]
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
return None
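
    # Illustration: both the flat payload [0.1, 0.2] and the named-vector payload
    # {"default": [0.1, 0.2]} normalize to [0.1, 0.2]; strings, empty sequences and
    # non-numeric content all fall through to None.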
|
|
|
|
@staticmethod
|
|
def _parse_source(source_raw: object) -> IngestionSource:
|
|
"""Safely normalize persistence source values into enum instances."""
|
|
if isinstance(source_raw, IngestionSource):
|
|
return source_raw
|
|
|
|
if isinstance(source_raw, str):
|
|
try:
|
|
return IngestionSource(source_raw)
|
|
except ValueError:
|
|
return IngestionSource.WEB
|
|
|
|
return IngestionSource.WEB
|
|
|
|
@staticmethod
|
|
@overload
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
) -> Mapping[str, object]: ...
|
|
|
|
@staticmethod
|
|
@overload
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
allow_missing: Literal[False],
|
|
) -> Mapping[str, object]: ...
|
|
|
|
@staticmethod
|
|
@overload
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
allow_missing: Literal[True],
|
|
) -> Mapping[str, object] | None: ...
|
|
|
|
@staticmethod
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
allow_missing: bool = False,
|
|
) -> Mapping[str, object] | None:
|
|
"""Ensure Weaviate properties payloads are mappings."""
|
|
if properties is None:
|
|
if allow_missing:
|
|
return None
|
|
raise StorageError(f"{context} returned object without properties")
|
|
|
|
if not isinstance(properties, Mapping):
|
|
raise StorageError(
|
|
f"{context} returned invalid properties payload of type {type(properties)!r}"
|
|
)
|
|
|
|
return cast(Mapping[str, object], properties)
|
|
|
|
@staticmethod
|
|
def _build_document_properties(doc: Document) -> dict[str, object]:
|
|
"""
|
|
Build Weaviate properties dict from document.
|
|
|
|
Args:
|
|
doc: Document to build properties for
|
|
|
|
Returns:
|
|
Properties dict suitable for Weaviate
|
|
"""
|
|
return {
|
|
"content": doc.content,
|
|
"source_url": doc.metadata["source_url"],
|
|
"title": doc.metadata.get("title", ""),
|
|
"description": doc.metadata.get("description", ""),
|
|
"timestamp": doc.metadata["timestamp"].isoformat(),
|
|
"content_type": doc.metadata["content_type"],
|
|
"word_count": doc.metadata["word_count"],
|
|
"char_count": doc.metadata["char_count"],
|
|
"source": doc.source.value,
|
|
}
|
|
|
|
def _normalize_collection_name(self, collection_name: str | None) -> str:
|
|
"""Return a canonicalized collection name, defaulting to configured value."""
|
|
candidate = collection_name or self.config.collection_name
|
|
if not candidate:
|
|
raise StorageError("Collection name is required")
|
|
|
|
if normalized := candidate.strip():
|
|
return normalized[0].upper() + normalized[1:]
|
|
else:
|
|
raise StorageError("Collection name cannot be empty")
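
    # Example: _normalize_collection_name("docs") returns "Docs"; Weaviate expects
    # collection (class) names to begin with an uppercase letter.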
|
|
|
|
async def _ensure_collection(self, collection_name: str) -> None:
|
|
"""Create the collection if missing."""
|
|
if not self.client:
|
|
raise StorageError("Weaviate client not initialized")
|
|
|
|
client = self.client
|
|
existing = client.collections.list_all()
|
|
if collection_name not in existing:
|
|
await self._create_collection(collection_name)
|
|
|
|
async def _prepare_collection(
|
|
self,
|
|
collection_name: str | None,
|
|
*,
|
|
ensure_exists: bool,
|
|
) -> tuple[Collection, str]:
|
|
"""Return a ready collection handle and normalized name."""
|
|
normalized = self._normalize_collection_name(collection_name)
|
|
|
|
if not self.client:
|
|
raise StorageError("Weaviate client not initialized")
|
|
|
|
if ensure_exists:
|
|
await self._ensure_collection(normalized)
|
|
|
|
client = self.client
|
|
return client.collections.get(normalized), normalized
|
|
|
|
@override
|
|
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
|
|
"""
|
|
Store a document in Weaviate.
|
|
|
|
Args:
|
|
document: Document to store
|
|
|
|
Returns:
|
|
Document ID
|
|
"""
|
|
try:
|
|
# Vectorize content if no vector provided
|
|
if document.vector is None:
|
|
document.vector = await self.vectorizer.vectorize(document.content)
|
|
|
|
collection, resolved_name = await self._prepare_collection(
|
|
collection_name, ensure_exists=True
|
|
)
|
|
|
|
# Prepare properties
|
|
properties = self._build_document_properties(document)
|
|
|
|
# Insert with vector
|
|
result = await self._run_sync(
|
|
collection.data.insert,
|
|
properties=properties,
|
|
vector=document.vector,
|
|
uuid=str(document.id),
|
|
)
|
|
|
|
return str(result)
|
|
|
|
except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
|
|
raise StorageError(f"Failed to store document: {e}") from e
|
|
|
|
@override
|
|
async def store_batch(
|
|
self, documents: list[Document], *, collection_name: str | None = None
|
|
) -> list[str]:
|
|
"""
|
|
Store multiple documents using proper batch operations.
|
|
|
|
Args:
|
|
documents: List of documents
|
|
|
|
Returns:
|
|
List of successfully stored document IDs
|
|
"""
|
|
try:
|
|
collection, resolved_name = await self._prepare_collection(
|
|
collection_name, ensure_exists=True
|
|
)
|
|
|
|
# Vectorize documents without vectors using batch processing
|
|
to_vectorize = [(i, doc) for i, doc in enumerate(documents) if doc.vector is None]
|
|
if to_vectorize:
|
|
contents = [doc.content for _, doc in to_vectorize]
|
|
vectors = await self.vectorizer.vectorize_batch(contents)
|
|
for (idx, _), vector in zip(to_vectorize, vectors, strict=False):
|
|
documents[idx].vector = vector
|
|
|
|
# Prepare batch data for insert_many
|
|
batch_objects = []
|
|
for doc in documents:
|
|
properties = self._build_document_properties(doc)
|
|
batch_objects.append(
|
|
DataObject(properties=properties, vector=doc.vector, uuid=str(doc.id))
|
|
)
|
|
|
|
# Insert batch using insert_many
|
|
response = await self._run_sync(collection.data.insert_many, batch_objects)
|
|
|
|
successful_ids: list[str] = []
|
|
error_indices = set(response.errors.keys()) if response else set()
|
|
|
|
for index, doc in enumerate(documents):
|
|
if index in error_indices:
|
|
continue
|
|
|
|
uuid_value = response.uuids.get(index) if response else None
|
|
successful_ids.append(str(uuid_value) if uuid_value is not None else str(doc.id))
|
|
|
|
if error_indices:
|
|
error_messages = ", ".join(
|
|
f"{documents[i].id}: {response.errors[i].message}"
|
|
for i in error_indices
|
|
if hasattr(response.errors[i], "message")
|
|
)
|
|
                import logging

                logging.warning(
                    "Weaviate partial batch failure for collection %s: %s",
                    resolved_name,
                    error_messages,
                )
|
|
|
|
return successful_ids
|
|
|
|
except (WeaviateBatchError, WeaviateConnectionError, WeaviateQueryError) as e:
|
|
raise StorageError(f"Failed to store batch: {e}") from e
|
|
|
|
@override
|
|
async def retrieve(
|
|
self, document_id: str, *, collection_name: str | None = None
|
|
) -> Document | None:
|
|
"""
|
|
Retrieve a document from Weaviate.
|
|
|
|
Args:
|
|
document_id: Document ID
|
|
|
|
Returns:
|
|
Document or None
|
|
"""
|
|
try:
|
|
collection, resolved_name = await self._prepare_collection(
|
|
collection_name, ensure_exists=False
|
|
)
|
|
result = await self._run_sync(collection.query.fetch_object_by_id, document_id)
|
|
|
|
if not result:
|
|
return None
|
|
|
|
# Reconstruct document
|
|
props = self._coerce_properties(
|
|
result.properties,
|
|
context="fetch_object_by_id",
|
|
)
|
|
# Parse timestamp to datetime for consistent metadata format
|
|
from datetime import UTC, datetime
|
|
|
|
timestamp_raw = props.get("timestamp")
|
|
timestamp_parsed: datetime
|
|
try:
|
|
if isinstance(timestamp_raw, str):
|
|
timestamp_parsed = datetime.fromisoformat(timestamp_raw)
|
|
if timestamp_parsed.tzinfo is None:
|
|
timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
|
|
elif isinstance(timestamp_raw, datetime):
|
|
timestamp_parsed = timestamp_raw
|
|
if timestamp_parsed.tzinfo is None:
|
|
timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
|
|
else:
|
|
timestamp_parsed = datetime.now(UTC)
|
|
except (ValueError, TypeError):
|
|
timestamp_parsed = datetime.now(UTC)
|
|
|
|
metadata_dict = {
|
|
"source_url": str(props["source_url"]),
|
|
"title": str(props.get("title")) if props.get("title") else None,
|
|
"description": str(props.get("description")) if props.get("description") else None,
|
|
"timestamp": timestamp_parsed,
|
|
"content_type": str(props["content_type"]),
|
|
"word_count": int(str(props["word_count"])),
|
|
"char_count": int(str(props["char_count"])),
|
|
}
|
|
metadata = cast(DocumentMetadata, cast(object, metadata_dict))
|
|
|
|
vector = self._extract_vector(cast(VectorContainer, result.vector))
|
|
|
|
return Document(
|
|
id=UUID(document_id),
|
|
content=str(props["content"]),
|
|
metadata=metadata,
|
|
vector=vector,
|
|
source=self._parse_source(props.get("source")),
|
|
collection=resolved_name,
|
|
)
|
|
|
|
except WeaviateQueryError as e:
|
|
raise StorageError(f"Query failed: {e}") from e
|
|
except WeaviateConnectionError as e:
|
|
# Connection issues should be logged and return None
|
|
import logging
|
|
|
|
logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}")
|
|
return None
|
|
except Exception as e:
|
|
# Log unexpected errors for debugging
|
|
import logging
|
|
|
|
logging.warning(f"Unexpected error retrieving document {document_id}: {e}")
|
|
return None
|
|
|
|
def _build_search_metadata(self, props: Mapping[str, object]) -> DocumentMetadata:
|
|
"""Build metadata dictionary from Weaviate properties."""
|
|
metadata_dict = {
|
|
"source_url": str(props["source_url"]),
|
|
"title": str(props.get("title")) if props.get("title") else None,
|
|
"description": str(props.get("description")) if props.get("description") else None,
|
|
"timestamp": str(props["timestamp"]),
|
|
"content_type": str(props["content_type"]),
|
|
"word_count": int(str(props["word_count"])),
|
|
"char_count": int(str(props["char_count"])),
|
|
}
|
|
return cast(DocumentMetadata, cast(object, metadata_dict))
|
|
|
|
    def _extract_search_score(self, result: object) -> float | None:
        """Extract a similarity score from result metadata (score or distance)."""
        metadata_obj = getattr(result, "metadata", None)
        if metadata_obj is None:
            return None

        # Hybrid/BM25 queries report a relevance score directly.
        raw_score = getattr(metadata_obj, "score", None)
        if isinstance(raw_score, (int, float)):
            return float(raw_score)

        # Vector queries report a distance instead; convert it to a similarity score.
        raw_distance = getattr(metadata_obj, "distance", None)
        if raw_distance is None:
            return None

        try:
            distance_value = float(raw_distance)
            return max(0.0, 1.0 - distance_value)
        except (TypeError, ValueError) as e:
            import logging

            logging.debug(f"Invalid distance value {raw_distance}: {e}")
            return None
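
    # Example: a hybrid/BM25 hit carrying metadata.score == 0.81 is reported as 0.81,
    # while a vector hit carrying only distance == 0.25 maps to 1 - 0.25 = 0.75.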
|
|
|
|
def _build_search_document(
|
|
self,
|
|
result: object,
|
|
resolved_name: str,
|
|
) -> Document:
|
|
"""Build Document from Weaviate search result."""
|
|
props = self._coerce_properties(
|
|
getattr(result, "properties", None),
|
|
context="search result",
|
|
)
|
|
metadata = self._build_search_metadata(props)
|
|
|
|
vector_attr = getattr(result, "vector", None)
|
|
vector = self._extract_vector(cast(VectorContainer, vector_attr))
|
|
score_value = self._extract_search_score(result)
|
|
|
|
uuid_raw = getattr(result, "uuid", None)
|
|
if uuid_raw is None:
|
|
raise StorageError("Weaviate search result missing uuid")
|
|
uuid_value = uuid_raw if isinstance(uuid_raw, UUID) else UUID(str(uuid_raw))
|
|
|
|
return Document(
|
|
id=uuid_value,
|
|
content=str(props["content"]),
|
|
metadata=metadata,
|
|
vector=vector,
|
|
source=self._parse_source(props.get("source")),
|
|
collection=resolved_name,
|
|
score=score_value,
|
|
)
|
|
|
|
    @override
    async def search(
        self,
        query: str,
        limit: int = 10,
        threshold: float = 0.7,
        *,
        collection_name: str | None = None,
    ) -> AsyncGenerator[Document, None]:
        """
        Search for documents in Weaviate using hybrid search.

        Args:
            query: Search query
            limit: Maximum results
            threshold: Similarity threshold (not used in hybrid search)
            collection_name: Optional target collection; defaults to the configured collection

        Yields:
            Matching documents
        """
        try:
            if not self.client:
                raise StorageError("Weaviate client not initialized")

            collection, resolved_name = await self._prepare_collection(
                collection_name, ensure_exists=False
            )

            # Try hybrid search first, fall back to BM25 keyword search
            try:
                response = await self._run_sync(
                    collection.query.hybrid, query=query, limit=limit, return_metadata=["score"]
                )
            except (WeaviateQueryError, StorageError):
                # Fall back to BM25 if hybrid search is not supported or fails
                response = await self._run_sync(
                    collection.query.bm25, query=query, limit=limit, return_metadata=["score"]
                )

            for obj in response.objects:
                yield self._build_document_from_search(obj, resolved_name)

        except (WeaviateQueryError, WeaviateConnectionError) as e:
            raise StorageError(f"Search failed: {e}") from e

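    # --- Editor's note: illustrative usage sketch (not part of the original module). ---
    # A minimal example of consuming the async ``search`` generator above, assuming
    # ``storage`` is an already-initialized instance of this Weaviate storage class
    # (the instance name is hypothetical). Kept commented out so module behavior is unchanged.
    #
    #   import asyncio
    #
    #   async def run_search(storage) -> None:
    #       # Hybrid search with BM25 fallback, capped at 5 results.
    #       async for doc in storage.search("vector databases", limit=5):
    #           print(doc.id, doc.score, doc.metadata["title"])
    #
    #   # asyncio.run(run_search(storage))
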
    @override
    async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
        """
        Delete a document from Weaviate.

        Args:
            document_id: Document ID
            collection_name: Optional target collection; defaults to the configured collection

        Returns:
            True if deleted
        """
        try:
            collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
            await self._run_sync(collection.data.delete_by_id, document_id)
            return True
        except WeaviateQueryError as e:
            raise StorageError(f"Delete operation failed: {e}") from e
        except Exception:
            return False

    @override
    async def count(self, *, collection_name: str | None = None) -> int:
        """
        Get document count in collection.

        Returns:
            Number of documents
        """
        try:
            if not self.client:
                return 0
            collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
            result = collection.aggregate.over_all(total_count=True)
            return result.total_count or 0
        except WeaviateQueryError as e:
            raise StorageError(f"Count query failed: {e}") from e
        except Exception:
            return 0

    async def list_collections(self) -> list[str]:
        """
        List all available collections.

        Returns:
            List of collection names
        """
        try:
            if not self.client:
                raise StorageError("Weaviate client not initialized")

            client = self.client
            return list(client.collections.list_all())

        except WeaviateConnectionError as e:
            raise StorageError(f"Failed to list collections: {e}") from e

    async def describe_collections(self) -> list[CollectionSummary]:
        """Return metadata for each Weaviate collection."""
        if not self.client:
            raise StorageError("Weaviate client not initialized")

        try:
            client = self.client
            collections: list[CollectionSummary] = []
            for name in client.collections.list_all():
                collection_obj = client.collections.get(name)
                if not collection_obj:
                    continue

                count = collection_obj.aggregate.over_all(total_count=True).total_count or 0
                size_mb = count * 0.01
                collection_summary: CollectionSummary = {
                    "name": name,
                    "count": count,
                    "size_mb": size_mb,
                }
                collections.append(collection_summary)

            return collections
        except Exception as e:
            raise StorageError(f"Failed to describe collections: {e}") from e

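    # --- Editor's note: illustrative usage sketch (not part of the original module). ---
    # Listing collections and their summaries, assuming ``storage`` is an initialized
    # instance of this storage class (name is hypothetical). Note that ``size_mb`` above
    # is a rough heuristic (count * 0.01), not an exact on-disk size.
    #
    #   import asyncio
    #
    #   async def show_collections(storage) -> None:
    #       names = await storage.list_collections()
    #       print("collections:", names)
    #       for summary in await storage.describe_collections():
    #           print(summary["name"], summary["count"], f'{summary["size_mb"]:.2f} MB')
    #
    #   # asyncio.run(show_collections(storage))
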
    async def sample_documents(
        self, limit: int = 5, *, collection_name: str | None = None
    ) -> list[Document]:
        """
        Get sample documents from the collection.

        Args:
            limit: Maximum number of documents to return
            collection_name: Optional target collection; defaults to the configured collection

        Returns:
            List of sample documents
        """
        try:
            collection, resolved_name = await self._prepare_collection(
                collection_name, ensure_exists=False
            )

            # Query for sample documents
            response = await self._run_sync(collection.query.fetch_objects, limit=limit)

            documents = []
            for obj in response.objects:
                # Convert back to Document format
                props = self._coerce_properties(
                    getattr(obj, "properties", None),
                    context="sample_documents",
                    allow_missing=True,
                )
                if props is None:
                    continue
                uuid_raw = getattr(obj, "uuid", None)
                if uuid_raw is None:
                    continue
                document_id = uuid_raw if isinstance(uuid_raw, UUID) else UUID(str(uuid_raw))
                # Safely convert WeaviateField values
                word_count_val = props.get("word_count")
                if isinstance(word_count_val, (int, float)):
                    word_count = int(word_count_val)
                elif word_count_val:
                    word_count = int(str(word_count_val))
                else:
                    word_count = 0

                char_count_val = props.get("char_count")
                if isinstance(char_count_val, (int, float)):
                    char_count = int(char_count_val)
                elif char_count_val:
                    char_count = int(str(char_count_val))
                else:
                    char_count = 0

                doc = Document(
                    id=document_id,
                    content=str(props.get("content", "")),
                    source=self._parse_source(props.get("source")),
                    metadata={
                        "source_url": str(props.get("source_url", "")),
                        "title": str(props.get("title", "")) if props.get("title") else None,
                        "description": str(props.get("description", ""))
                        if props.get("description")
                        else None,
                        "timestamp": datetime.fromisoformat(
                            str(props.get("timestamp", datetime.now(UTC).isoformat()))
                        ),
                        "content_type": str(props.get("content_type", "text/plain")),
                        "word_count": word_count,
                        "char_count": char_count,
                    },
                    collection=resolved_name,
                )
                documents.append(doc)

            return documents

        except Exception as e:
            raise StorageError(f"Failed to sample documents: {e}") from e

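    # --- Editor's note: illustrative usage sketch (not part of the original module). ---
    # Quick spot-check of a collection's contents via ``sample_documents``, assuming
    # ``storage`` is an initialized instance of this storage class (name is hypothetical).
    #
    #   import asyncio
    #
    #   async def peek(storage) -> None:
    #       docs = await storage.sample_documents(limit=3)
    #       for doc in docs:
    #           print(doc.metadata["title"], doc.metadata["word_count"], "words")
    #
    #   # asyncio.run(peek(storage))
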
    def _safe_convert_count(self, value: object) -> int:
        """Safely convert a value to integer count."""
        if isinstance(value, (int, float)):
            return int(value)
        elif value:
            return int(str(value))
        else:
            return 0

    def _build_document_metadata(self, props: Mapping[str, object]) -> DocumentMetadata:
        """Build metadata from search document properties."""
        return {
            "source_url": str(props.get("source_url", "")),
            "title": str(props.get("title", "")) if props.get("title") else None,
            "description": str(props.get("description", "")) if props.get("description") else None,
            "timestamp": datetime.fromisoformat(
                str(props.get("timestamp", datetime.now(UTC).isoformat()))
            ),
            "content_type": str(props.get("content_type", "text/plain")),
            "word_count": self._safe_convert_count(props.get("word_count")),
            "char_count": self._safe_convert_count(props.get("char_count")),
        }

    def _extract_document_score(self, obj: object) -> float | None:
        """Extract score from document search result."""
        metadata_obj = getattr(obj, "metadata", None)
        if metadata_obj is None:
            return None

        raw_score = getattr(metadata_obj, "score", None)
        if raw_score is None:
            return None

        try:
            return float(raw_score)
        except (TypeError, ValueError) as e:
            import logging

            logging.debug(f"Invalid score value {raw_score}: {e}")
            return None

    def _build_document_from_search(
        self,
        obj: object,
        resolved_name: str,
    ) -> Document:
        """Build Document from search document result."""
        props = self._coerce_properties(
            getattr(obj, "properties", None),
            context="document search result",
        )
        metadata = self._build_document_metadata(props)
        score_value = self._extract_document_score(obj)

        uuid_raw = getattr(obj, "uuid", None)
        if uuid_raw is None:
            raise StorageError("Weaviate search document result missing uuid")
        uuid_value = uuid_raw if isinstance(uuid_raw, UUID) else UUID(str(uuid_raw))

        return Document(
            id=uuid_value,
            content=str(props.get("content", "")),
            source=self._parse_source(props.get("source")),
            metadata=metadata,
            collection=resolved_name,
            score=score_value,
        )

    async def search_documents(
        self, query: str, limit: int = 10, *, collection_name: str | None = None
    ) -> list[Document]:
        """
        Search documents in the collection.

        Args:
            query: Search query
            limit: Maximum number of results
            collection_name: Optional target collection; defaults to the configured collection

        Returns:
            List of matching documents
        """
        # Delegate to the unified search method
        results = []
        async for document in self.search(query, limit=limit, collection_name=collection_name):
            results.append(document)
        return results

    async def list_documents(
        self,
        limit: int = 100,
        offset: int = 0,
        *,
        collection_name: str | None = None,
    ) -> list[DocumentInfo]:
        """
        List documents in the collection with pagination.

        Args:
            limit: Maximum number of documents to return
            offset: Number of documents to skip
            collection_name: Optional target collection; defaults to the configured collection

        Returns:
            List of document dictionaries with id, title, source_url, and content preview
        """
        try:
            if not self.client:
                raise StorageError("Weaviate client not initialized")

            collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)

            # Query documents with pagination
            response = await self._run_sync(
                collection.query.fetch_objects,
                limit=limit,
                offset=offset,
                return_metadata=["creation_time"],
            )

            documents: list[DocumentInfo] = []
            for obj in response.objects:
                props = self._coerce_properties(
                    obj.properties,
                    context="list_documents",
                    allow_missing=True,
                )
                if props is None:
                    continue
                content = str(props.get("content", ""))
                word_count_value = props.get("word_count", 0)
                # Convert WeaviateField to int
                if isinstance(word_count_value, (int, float)):
                    word_count = int(word_count_value)
                elif word_count_value:
                    word_count = int(str(word_count_value))
                else:
                    word_count = 0

                doc_info: DocumentInfo = {
                    "id": str(obj.uuid),
                    "title": str(props.get("title", "Untitled")),
                    "source_url": str(props.get("source_url", "")),
                    "description": str(props.get("description", "")),
                    "content_type": str(props.get("content_type", "text/plain")),
                    "content_preview": (f"{content[:200]}..." if len(content) > 200 else content),
                    "word_count": word_count,
                    "timestamp": str(props.get("timestamp", "")),
                }
                documents.append(doc_info)

            return documents

        except Exception as e:
            raise StorageError(f"Failed to list documents: {e}") from e

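    # --- Editor's note: illustrative usage sketch (not part of the original module). ---
    # Paging through a collection with ``list_documents`` (limit/offset), assuming
    # ``storage`` is an initialized instance of this storage class (name is hypothetical).
    #
    #   import asyncio
    #
    #   async def dump_titles(storage, page_size: int = 50) -> None:
    #       offset = 0
    #       while True:
    #           page = await storage.list_documents(limit=page_size, offset=offset)
    #           if not page:
    #               break
    #           for info in page:
    #               print(info["id"], info["title"])
    #           offset += page_size
    #
    #   # asyncio.run(dump_titles(storage))
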
    async def delete_documents(
        self, document_ids: list[str], *, collection_name: str | None = None
    ) -> dict[str, bool]:
        """
        Delete multiple documents from Weaviate.

        Args:
            document_ids: List of document IDs to delete
            collection_name: Optional target collection; defaults to the configured collection

        Returns:
            Dictionary mapping document IDs to deletion success status
        """
        results: dict[str, bool] = {}

        try:
            if not self.client:
                raise StorageError("Weaviate client not initialized")

            if not document_ids:
                return results

            collection, resolved_name = await self._prepare_collection(
                collection_name, ensure_exists=False
            )

            delete_filter = Filter.by_id().contains_any(document_ids)
            response = await self._run_sync(
                collection.data.delete_many, where=delete_filter, verbose=True
            )

            if objects := getattr(response, "objects", None):
                for result_obj in objects:
                    if doc_uuid := str(getattr(result_obj, "uuid", "")):
                        results[doc_uuid] = bool(getattr(result_obj, "successful", False))

            if len(results) < len(document_ids):
                default_success = getattr(response, "failed", 0) == 0
                for doc_id in document_ids:
                    _ = results.setdefault(doc_id, default_success)

            return results

        except Exception as e:
            raise StorageError(f"Failed to delete documents: {e}") from e

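    # --- Editor's note: illustrative usage sketch (not part of the original module). ---
    # Batch deletion with per-ID outcomes from ``delete_documents``, assuming ``storage``
    # is an initialized instance of this storage class and ``ids`` holds Weaviate object
    # UUID strings (both names are hypothetical).
    #
    #   import asyncio
    #
    #   async def purge(storage, ids: list[str]) -> None:
    #       outcome = await storage.delete_documents(ids)
    #       failed = [doc_id for doc_id, ok in outcome.items() if not ok]
    #       print(f"deleted {len(outcome) - len(failed)}/{len(outcome)}, failed: {failed}")
    #
    #   # asyncio.run(purge(storage, ids))
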
    async def delete_by_filter(
        self, filter_dict: dict[str, str], *, collection_name: str | None = None
    ) -> int:
        """
        Delete documents matching a filter.

        Args:
            filter_dict: Filter criteria (e.g., {"source_url": "example.com"})
            collection_name: Optional target collection; defaults to the configured collection

        Returns:
            Number of documents deleted
        """
        try:
            if not self.client:
                raise StorageError("Weaviate client not initialized")

            collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)

            # Build where filter
            where_filter = None
            if "source_url" in filter_dict:
                where_filter = Filter.by_property("source_url").equal(filter_dict["source_url"])

            # Get documents matching filter
            if where_filter:
                response = await self._run_sync(
                    collection.query.fetch_objects,
                    filters=where_filter,
                    limit=1000,  # Max batch size
                )
            else:
                response = await self._run_sync(
                    collection.query.fetch_objects,
                    limit=1000,  # Max batch size
                )

            # Delete matching documents
            deleted_count = 0
            for obj in response.objects:
                try:
                    await self._run_sync(collection.data.delete_by_id, obj.uuid)
                    deleted_count += 1
                except Exception:
                    continue

            return deleted_count

        except Exception as e:
            raise StorageError(f"Failed to delete by filter: {e}") from e

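    # --- Editor's note: illustrative usage sketch (not part of the original module). ---
    # Filtered deletion via ``delete_by_filter``; as implemented above, only the
    # "source_url" key is translated into a Weaviate filter, and at most 1000 matches
    # are fetched (and thus deleted) per call. ``storage`` is a hypothetical initialized instance.
    #
    #   import asyncio
    #
    #   async def drop_source(storage) -> None:
    #       removed = await storage.delete_by_filter({"source_url": "https://example.com/docs"})
    #       print(f"removed {removed} documents")
    #
    #   # asyncio.run(drop_source(storage))
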
    async def delete_collection(self, collection_name: str | None = None) -> bool:
        """
        Delete the entire collection.

        Returns:
            True if successful
        """
        try:
            if not self.client:
                raise StorageError("Weaviate client not initialized")

            target = self._normalize_collection_name(collection_name)

            # Delete the collection using the client's collections API
            client = self.client
            client.collections.delete(target)

            return True

        except Exception as e:
            raise StorageError(f"Failed to delete collection: {e}") from e

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object | None,
    ) -> None:
        """Async context manager exit with proper cleanup."""
        await self.close()

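    # --- Editor's note: illustrative usage sketch (not part of the original module). ---
    # Using the storage as an async context manager so ``close`` (below) runs on exit and
    # releases both the Weaviate client and the vectorizer's HTTP client. Constructor
    # arguments are omitted here because they are defined elsewhere in this module;
    # ``storage`` is a hypothetical initialized instance.
    #
    #   import asyncio
    #
    #   async def run(storage) -> None:
    #       async with storage:
    #           total = await storage.count()
    #           print("documents:", total)
    #
    #   # asyncio.run(run(storage))
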
    async def close(self) -> None:
        """Close client connection and vectorizer HTTP client."""
        if self.client:
            try:
                client = self.client
                client.close()
            except (WeaviateConnectionError, AttributeError) as e:
                import logging

                logging.warning(f"Error closing Weaviate client: {e}")

        # Close vectorizer HTTP client to prevent resource leaks
        try:
            await self.vectorizer.close()
        except (AttributeError, OSError) as e:
            import logging

            logging.warning(f"Error closing vectorizer client: {e}")

    def __del__(self) -> None:
        """Clean up client connection as fallback."""
        if self.client:
            try:
                client = self.client
                client.close()
            except Exception:
                pass  # Ignore errors in destructor
</file>
</files>