diff --git a/.claude/settings.local.json b/.claude/settings.local.json
index a3a1353..a4bbc33 100644
--- a/.claude/settings.local.json
+++ b/.claude/settings.local.json
@@ -8,7 +8,8 @@
"Bash(cat:*)",
"mcp__firecrawl__firecrawl_search",
"mcp__firecrawl__firecrawl_scrape",
- "Bash(python:*)"
+ "Bash(python:*)",
+ "mcp__coder__coder_report_task"
],
"deny": [],
"ask": []
diff --git a/docs/feeds.md b/docs/feeds.md
index b8c7f69..97f96fb 100644
--- a/docs/feeds.md
+++ b/docs/feeds.md
@@ -1,263 +1,255 @@
-## Codebase Analysis Report: RAG Manager Ingestion Pipeline
+## TL;DR / Highest-impact fixes (do these first)
-**Status:** Validated against current codebase implementation
-**Target:** Enhanced implementation guidance for efficient agent execution
+Event-loop blocking in the TUI (user-visible stutter):
+time.sleep(2) is called at the end of an ingestion run on the IngestionScreen. Even though this runs in a thread worker, it blocks that worker and delays the screen transition; prefer a scheduled UI callback instead.
-This analysis has been validated against the actual codebase structure and provides implementation-specific details for executing recommended improvements. The codebase demonstrates solid architecture with clear separation of concerns between ingestion flows, storage adapters, and TUI components.
+
-### Architecture Overview
-- **Storage Backends**: Weaviate, OpenWebUI, R2R with unified `BaseStorage` interface
-- **TUI Framework**: Textual-based with reactive components and async worker patterns
-- **Orchestration**: Prefect flows with retry logic and progress callbacks
-- **Configuration**: Pydantic-based settings with environment variable support
+Blocking Weaviate client calls inside async methods:
+Several async methods in WeaviateStorage call the synchronous Weaviate client directly (connect(), collections.create, queries, inserts, deletes). Wrap those in asyncio.to_thread(...) (or equivalent) to avoid freezing the loop when these calls take time.
+
-### Validated Implementation Analysis
+Embedding at scale: use batch vectorization:
+store_batch vectorizes each document one by one; you already have Vectorizer.vectorize_batch. Switching to batch reduces HTTP round trips and improves throughput under backpressure.
+
-### 1. Bug Fixes & Potential Issues
+Connection lifecycle: close the vectorizer client:
+WeaviateStorage.close() closes the Weaviate client but not the httpx.AsyncClient inside the Vectorizer. Await the vectorizer client's aclose() as well to prevent leaked connections/sockets under heavy usage.
+
-These are areas where the code may not function as intended or could lead to errors.
+Broadened exception handling in UI utilities:
+Multiple places catch Exception broadly, making failures ambiguous and harder to surface to users (e.g., storage manager and list builders). Narrow these to expected exceptions and fail fast with user-friendly messages where appropriate.
-*
-
- HIGH PRIORITY: `R2RStorage.store_batch` inefficient looping (Lines 161-179)
-
+
- * **File:** `ingest_pipeline/storage/r2r/storage.py:161-179`
- * **Issue:** CONFIRMED - Method loops through documents calling `_store_single_document` individually
- * **Impact:** ~5-10x performance degradation for batch operations
- * **Implementation:** Check R2R v3 API for bulk endpoints; current implementation uses `/v3/documents` per document
- * **Effort:** Medium (API research + refactor)
- * **Priority:** High - affects all R2R ingestion workflows
-
+## Correctness / Bugs
-*
-
- MEDIUM PRIORITY: Mixed HTTP client usage in `R2RStorage` (Lines 80, 99, 258)
-
+Blocking sleep in TUI:
+time.sleep(2) after posting notifications and before app.pop_screen() blocks the worker thread; use Textual timers instead.
- * **File:** `ingest_pipeline/storage/r2r/storage.py:80,99,258`
- * **Issue:** VALIDATED - Mixes `R2RAsyncClient` (line 80) with direct `httpx.AsyncClient` (lines 99, 258)
- * **Specific Methods:** `initialize()`, `_ensure_collection()`, `_attempt_document_creation()`
- * **Impact:** Inconsistent auth/header handling, connection pooling inefficiency
- * **Implementation:** Extend `R2RAsyncClient` or create adapter pattern for missing endpoints
- * **Test Coverage:** Check if affected methods have unit tests before refactoring
- * **Effort:** Medium (requires SDK analysis)
-
+
-*
-
- MEDIUM PRIORITY: TUI blocking during storage init (Line 91)
-
+Synchronous SDK in async contexts (Weaviate):
- * **File:** `ingest_pipeline/cli/tui/utils/runners.py:91`
- * **Issue:** CONFIRMED - `await storage_manager.initialize_all_backends()` blocks TUI startup
- * **Current Implementation:** 30s timeout per backend in `StorageManager.initialize_all_backends()`
- * **User Impact:** Frozen terminal for up to 90s if all backends timeout
- * **Solution:** Move to `CollectionOverviewScreen.on_mount()` as `@work` task
- * **Dependencies:** `dashboard.py:304` already has worker pattern for `refresh_collections`
- * **Implementation:** Use existing loading indicators and status updates (lines 308-312)
- * **Effort:** Low (pattern exists, needs relocation)
-
+initialize() calls self.client.connect() (sync). Wrap with asyncio.to_thread(self.client.connect).
-*
-
- LOW PRIORITY: Weak URL validation in `IngestionScreen` (Lines 240-260)
-
+
- * **File:** `ingest_pipeline/cli/tui/screens/ingestion.py:240-260`
- * **Issue:** CONFIRMED - Method accepts `foo/bar` as valid (line 258)
- * **Security Risk:** Medium - malicious URLs could be passed to ingestors
- * **Current Logic:** Basic prefix checks only (http/https/file://)
- * **Enhancement:** Add `pathlib.Path.exists()` for file:// paths, `.git` directory check for repos
- * **Dependencies:** Import `pathlib` and add proper regex validation
- * **Alternative:** Use `validators` library (not currently imported)
- * **Effort:** Low (validation logic only)
-
+Object operations such as collection.data.insert(...), collection.query.fetch_objects(...), collection.data.delete_many(...), and client.collections.delete(...) are sync calls invoked from async methods. These can stall the event loop under latency; use asyncio.to_thread(...) (see snippet below).
+
-### 2. Code Redundancy & Refactoring Opportunities
+HTTP client lifecycle (Vectorizer):
+Vectorizer owns an httpx.AsyncClient but WeaviateStorage.close() doesn’t close it; add a close call to avoid resource leaks.
+
-These suggestions aim to make the code more concise, maintainable, and reusable (D.R.Y. - Don't Repeat Yourself).
+Heuristic “word_count” for OpenWebUI file listing:
+word_count is estimated via size/6. That’s fine as a placeholder but can mislead paging and UI logic downstream—consider a sentinel or a clearer label (e.g., estimated_word_count).
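+
+If you keep the heuristic, a small relabeling makes its nature explicit (the payload key and dict names here are assumptions, not the project's code):
+
+size_bytes = int(file_info.get("size", 0))          # byte size reported by the OpenWebUI files API (assumed key)
+doc_info["estimated_word_count"] = size_bytes // 6  # same size/6 heuristic, labeled as an estimate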
-*
-
- HIGH IMPACT: Redundant collection logic in dashboard (Lines 356-424)
-
+
- * **File:** `ingest_pipeline/cli/tui/screens/dashboard.py:356-424`
- * **Issue:** CONFIRMED - `list_weaviate_collections()` and `list_openwebui_collections()` duplicate `StorageManager.get_all_collections()`
- * **Code Duplication:** ~70 lines of redundant collection listing logic
- * **Architecture Violation:** UI layer coupled to specific storage implementations
- * **Current Usage:** `refresh_collections()` calls `get_all_collections()` (line 327), making methods obsolete
- * **Action:** DELETE methods `list_weaviate_collections` and `list_openwebui_collections`
- * **Impact:** Code reduction ~70 lines, improved maintainability
- * **Risk:** Low - methods appear unused in current flow
- * **Effort:** Low (deletion only)
-
+Wide except Exception:
+Broad catches appear in places like the storage manager and TUI screen update closures; they hide the real failure. Prefer catching StorageError, IngestionError, or specific SDK exceptions, and surface toasts with actionable details (a sketch follows below).
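+
+A minimal sketch of the narrowed pattern, assuming a screen that holds a storage_manager reference (the method name and exact exception set are illustrative):
+
+import asyncio
+from typing import cast
+
+from ingest_pipeline.core.exceptions import StorageError
+
+async def _load_collections(self) -> None:
+    try:
+        self.collections = await self.storage_manager.get_all_collections()
+    except (StorageError, asyncio.TimeoutError) as exc:
+        # Expected failure modes: surface them to the user instead of swallowing everything
+        cast("CollectionManagementApp", self.app).safe_notify(
+            f"Failed to load collections: {exc}", severity="error"
+        )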
-*
-
- MEDIUM IMPACT: Repetitive backend init pattern (Lines 255-291)
-
+
- * **File:** `ingest_pipeline/cli/tui/utils/storage_manager.py:255-291`
- * **Issue:** CONFIRMED - Pattern repeated 3x for each backend type
- * **Code Structure:** Check settings → Create config → Add task (12 lines × 3 backends)
- * **Current Backends:** Weaviate (258-267), OpenWebUI (270-279), R2R (282-291)
- * **Refactor Pattern:** Create `BackendConfig` dataclass with `(backend_type, endpoint_setting, api_key_setting, storage_class)`
- * **Implementation:** Loop over config list, reducing ~36 lines to ~15 lines
- * **Extensibility:** Adding new backend becomes one-line config addition
- * **Testing:** Ensure `asyncio.gather()` behavior unchanged (line 296)
- * **Effort:** Medium (requires dataclass design + testing)
-
+## Performance & Scalability
-*
-
- MEDIUM IMPACT: Repeated Prefect block loading pattern (Lines 266-311)
-
+Batch embeddings:
+In WeaviateStorage.store_batch you conditionally vectorize each doc inline. Use your existing vectorize_batch and map the results back to documents. This cuts request overhead and enables controlled concurrency (you already have AsyncTaskManager to help).
+
- * **File:** `ingest_pipeline/flows/ingestion.py:266-311`
- * **Issue:** CONFIRMED - Pattern in `_create_ingestor()` and `_create_storage()` methods
- * **Duplication:** `Block.aload()` + fallback logic repeated 4x across both methods
- * **Variable Resolution:** Batch size logic (lines 244-255) also needs abstraction
- * **Helper Functions Needed:**
- - `load_block_with_fallback(block_slug: str, default_config: T) -> T`
- - `resolve_prefect_variable(var_name: str, default: T, type_cast: Type[T]) -> T`
- * **Impact:** Cleaner flow logic, better error handling, type safety
- * **Lines Reduced:** ~20 lines of repetitive code
- * **Effort:** Medium (requires generic typing)
-
+Async-friendly Weaviate calls:
+Offload sync SDK operations to a thread so your Textual UI remains responsive while bulk inserting/querying/deleting. (See patch below.)
+
-### 3. User Experience (UX) Enhancements
+Retry & backoff are solid in your HTTP base client:
+Your TypedHttpClient.request adds exponential backoff + jitter; this is great and worth keeping consistent across adapters.
-These are suggestions to make your TUI more powerful, intuitive, and enjoyable for the user.
+
-*
-
- HIGH IMPACT: Document content viewer modal (Add to documents.py)
-
+## UX notes (Textual TUI)
- * **Target File:** `ingest_pipeline/cli/tui/screens/documents.py`
- * **Current State:** READY - `DocumentManagementScreen` has table selection (line 212)
- * **Implementation:**
- - Add `Binding("v", "view_document", "View")` to BINDINGS (line 27)
- - Create `DocumentContentModal(ModalScreen)` with `ScrollableContainer` + `Markdown`
- - Use existing `get_current_document()` method (line 212)
- - Fetch full content via `storage.retrieve(document_id)`
- * **Dependencies:** Import `ModalScreen`, `ScrollableContainer`, `Markdown` from textual
- * **User Value:** HIGH - essential for content inspection workflow
- * **Effort:** Low-Medium (~50 lines of modal code)
- * **Pattern:** Follow existing modal patterns in codebase
-
+Notifications are good; make them consistent:
+app.safe_notify(...) already exists; use it everywhere instead of self.notify(...) to normalize markup handling and safety.
-*
-
- HIGH IMPACT: Analytics tab visualization (Lines 164-189)
-
+
- * **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py:164-189`
- * **Current State:** PLACEHOLDER - Static widgets with dummy content
- * **Data Source:** Use existing `self.collections` (line 65) populated by `refresh_collections()`
- * **Implementation Options:**
- 1. **Simple Text Chart:** ASCII bar chart using existing collections data
- 2. **textual-plotext:** Add dependency + bar chart widget
- 3. **Custom Widget:** Simple bar visualization with Static widgets
- * **Metrics to Show:**
- - Documents per collection (data available)
- - Storage usage per backend (calculated in `_calculate_metrics()`)
- - Ingestion timeline (requires timestamp tracking)
- * **Effort:** Low-Medium (depends on visualization complexity)
- * **Dependencies:** Consider `textual-plotext` or pure ASCII approach
-
+Avoid visible “freeze” scenarios:
+Replace the final time.sleep(2) with a timer and transition immediately; users see their actions complete without lag. (Patch below.)
-*
-
- MEDIUM IMPACT: Global search implementation (Button exists, needs screen)
-
+
- * **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py`
- * **Current State:** READY - "Search All" button exists (line 122), handler stubbed
- * **Backend Support:** `StorageManager.search_across_backends()` method exists (line 413-441)
- * **Implementation:**
- - Create `GlobalSearchScreen(ModalScreen)` with search input + results table
- - Use existing `search_across_backends()` method for data
- - Add "Backend" column to results table showing data source
- - Handle async search with loading indicators
- * **Current Limitation:** Search only works for Weaviate (line 563), need to extend
- * **Data Flow:** Input → `storage_manager.search_across_backends()` → Results display
- * **Effort:** Medium (~100 lines for new screen + search logic)
-
+Search table quick find:
+You have EnhancedDataTable with quick-search messages; wire a shortcut (e.g., /) to focus a search input and filter rows live.
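+
+A small sketch of the wiring, shown as a standalone fragment (in practice extend the existing screen's BINDINGS; the #quick_search input id is an assumption):
+
+from textual.binding import Binding
+from textual.screen import Screen
+from textual.widgets import Input
+
+class CollectionOverviewScreen(Screen[None]):
+    BINDINGS = [
+        Binding("/", "focus_search", "Quick search"),
+    ]
+
+    def action_focus_search(self) -> None:
+        # Jump focus to the filter input; EnhancedDataTable quick-search
+        # messages can then filter rows as the user types.
+        self.query_one("#quick_search", Input).focus()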
-*
-
- MEDIUM IMPACT: R2R advanced features integration (Widgets ready)
-
+
- * **Target File:** `ingest_pipeline/cli/tui/screens/documents.py`
- * **Available Widgets:** CONFIRMED - `ChunkViewer`, `EntityGraph`, `CollectionStats`, `DocumentOverview` in `r2r_widgets.py`
- * **Current Implementation:** Basic document table only, R2R-specific features unused
- * **Integration Points:**
- - Add "R2R Details" button when `collection["type"] == "r2r"` (conditional UI)
- - Create `R2RDocumentDetailsScreen` using existing widgets
- - Use `StorageManager.get_r2r_storage()` method (exists at line 442)
- * **R2R Methods Available:**
- - `get_document_chunks()`, `extract_entities()`, `get_document_overview()`
- * **User Value:** Medium-High for R2R users, showcases advanced features
- * **Effort:** Low-Medium (widgets exist, need screen integration)
-
+Theme file size & maintainability:
+The theming system is thorough; consider splitting styles.py into smaller modules or generating CSS at build time to keep the Python file lean. (The responsive CSS generators are all consolidated in styles.py today.)
-*
-
- LOW IMPACT: Create collection dialog (Backend methods exist)
-
+
- * **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py`
- * **Backend Support:** CONFIRMED - `create_collection()` method exists for R2R storage (line 690)
- * **Current State:** No "Create Collection" button in existing UI
- * **Implementation:**
- - Add "New Collection" button to dashboard action buttons
- - Create `CreateCollectionModal` with name input + backend checkboxes
- - Iterate over `storage_manager.get_available_backends()` for backend selection
- - Call `storage.create_collection()` on selected backends
- * **Backend Compatibility:** Check which storage backends support collection creation
- * **User Value:** Low-Medium (manual workflow, not critical)
- * **Effort:** Low-Medium (~75 lines for modal + integration)
-
+## Modularity / Redundancy
-## Implementation Priority Matrix
+Converge repeated property mapping:
+WeaviateStorage.store and store_batch build the same properties dict; factor this into a small helper to keep schemas in one place (less drift, easier to extend).
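+
+A possible shape for the helper (the field names are illustrative; mirror whatever the current schema writes):
+
+def _build_properties(self, document: Document) -> dict[str, object]:
+    """Single source of truth for the Weaviate object properties."""
+    return {
+        "content": document.content,
+        # Remaining keys are assumptions about the existing schema:
+        "metadata": dict(document.metadata or {}),
+    }
+
+Both store and store_batch would then call properties = self._build_properties(doc) before inserting.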
-### Quick Wins (High Impact, Low Effort)
-1. **Delete redundant collection methods** (dashboard.py:356-424) - 5 min
-2. **Fix TUI startup blocking** (runners.py:91) - 15 min
-3. **Document content viewer modal** (documents.py) - 30 min
+
-### High Impact Fixes (Medium Effort)
-1. **R2R batch operation optimization** (storage.py:161-179) - Research R2R v3 API + implementation
-2. **Analytics tab visualization** (dashboard.py:164-189) - Choose visualization approach + implement
-3. **Backend initialization refactoring** (storage_manager.py:255-291) - Dataclass design + testing
+Common “describe/list/count” patterns across storages:
+R2RStorage, OpenWebUIStorage, and WeaviateStorage present similar collection/document listing and counting methods. Consider a small “collection view” mixin with shared helpers; each backend implements only the API-specific steps.
-### Technical Debt (Long-term)
-1. **R2R client consistency** (storage.py) - SDK analysis + refactoring
-2. **Prefect block loading helpers** (ingestion.py:266-311) - Generic typing + testing
-3. **URL validation enhancement** (ingestion.py:240-260) - Security + validation logic
+## Security & Reliability
-### Feature Enhancements (User Value)
-1. **Global search implementation** - Medium effort, requires search backend extension
-2. **R2R advanced features integration** - Showcase existing widget capabilities
-3. **Create collection dialog** - Nice-to-have administrative feature
+Input sanitization for LLM metadata:
+Your MetadataTagger sanitizes and bounds fields (e.g., max lengths, language whitelist). This is a strong pattern—keep it.
-## Agent Execution Notes
+
-**Context Efficiency Tips:**
-- Focus on one priority tier at a time
-- Read specific file ranges mentioned in line numbers
-- Use existing patterns (worker decorators, modal screens, async methods)
-- Test changes incrementally, especially async operations
-- Verify import dependencies before implementation
+Timeouts and typed HTTP clients:
+You standardize HTTP clients, headers, and timeouts. Good foundation for consistent behavior & observability.
-**Architecture Constraints:**
-- Maintain async/await patterns throughout
-- Follow Textual reactive widget patterns
-- Preserve Prefect flow structure for orchestration
-- Keep storage backend abstraction intact
+
-The codebase demonstrates excellent architectural foundations - these enhancements build upon existing strengths rather than requiring structural changes.
\ No newline at end of file
+## Suggested patches (drop-in)
+
+### 1) Don’t block the UI when closing the ingestion screen
+
+Current (blocking):
+
+import time
+time.sleep(2)
+cast("CollectionManagementApp", self.app).pop_screen()
+
+
+Safer (schedule via the app’s timer):
+
+def _pop() -> None:
+    try:
+        self.app.pop_screen()
+    except Exception:
+        pass
+
+# schedule from the worker thread
+cast("CollectionManagementApp", self.app).call_from_thread(
+    lambda: self.app.set_timer(2.0, _pop)
+)
+
+### 2) Offload Weaviate sync calls from async methods
+
+Example – insert (from WeaviateStorage.store)
+
+# before (sync in async method)
+collection.data.insert(properties=properties, vector=vector)
+
+# after
+await asyncio.to_thread(collection.data.insert, properties=properties, vector=vector)
+
+
+Example – fetch/delete by filter (from delete_by_filter)
+
+response = await asyncio.to_thread(
+    collection.query.fetch_objects, filters=where_filter, limit=1000
+)
+for obj in response.objects:
+    await asyncio.to_thread(collection.data.delete_by_id, obj.uuid)
+
+
+Example – connect during initialize
+
+await asyncio.to_thread(self.client.connect)
+
+### 3) Batch embeddings in store_batch
+
+Current (per‑doc):
+
+for doc in documents:
+    if doc.vector is None:
+        doc.vector = await self.vectorizer.vectorize(doc.content)
+
+
+Proposed (batch):
+
+# collect contents needing vectors
+to_embed_idxs = [i for i, d in enumerate(documents) if d.vector is None]
+if to_embed_idxs:
+    contents = [documents[i].content for i in to_embed_idxs]
+    vectors = await self.vectorizer.vectorize_batch(contents)
+    for j, idx in enumerate(to_embed_idxs):
+        documents[idx].vector = vectors[j]
+
+### 4) Close the Vectorizer HTTP client when storage closes
+
+Current: WeaviateStorage.close() only closes the Weaviate client.
+
+Add:
+
+async def close(self) -> None:
+    if self.client:
+        try:
+            cast(weaviate.WeaviateClient, self.client).close()
+        except Exception as e:
+            import logging
+
+            logging.warning("Error closing Weaviate client: %s", e)
+    # NEW: close vectorizer HTTP client too
+    try:
+        await self.vectorizer.client.aclose()
+    except Exception:
+        pass
+
+
+(Your Vectorizer owns an httpx.AsyncClient with headers/timeouts set.)
+
+### 5) Prefer safe_notify consistently
+
+Replace direct self.notify(...) calls inside TUI screens with cast(AppType, self.app).safe_notify(...) (same severity, markup=False by default), centralizing styling/sanitization.
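+
+For example, mirroring the substitution already applied in IngestionScreen:
+
+# before
+self.notify("⚠️ Select at least one storage backend", severity="warning")
+
+# after
+cast("CollectionManagementApp", self.app).safe_notify(
+    "⚠️ Select at least one storage backend", severity="warning"
+)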
+
+## Smaller improvements (quality of life)
+
+Quick search focus: Wire the EnhancedDataTable quick-search to a keyboard binding (e.g., /) so users can filter rows without reaching for the mouse.
+
+Refactor repeated properties dict: Extract the property construction in WeaviateStorage.store/store_batch into a helper to avoid drift and reduce line count.
+
+Styles organization: styles.py is hefty by design. Consider splitting “theme palette”, “components”, and “responsive” generators into separate modules to keep diffs small and reviews easier.
+
+## Architecture & Modularity
+
+Storage adapters: The three adapters share “describe/list/count” concepts. Introduce a tiny shared “CollectionIntrospection” mixin (interfaces + default fallbacks), and keep only API specifics in each adapter. This will simplify the TUI’s StorageManager as well.
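+
+One possible shape, kept deliberately small (the method names beyond the list/count/describe trio are assumptions):
+
+from abc import ABC, abstractmethod
+
+class CollectionIntrospection(ABC):
+    """Shared describe/list/count helpers; adapters override only what their API needs."""
+
+    @abstractmethod
+    async def list_collections(self) -> list[str]: ...
+
+    @abstractmethod
+    async def list_documents(self, collection_name: str) -> list[dict[str, object]]: ...
+
+    async def count_documents(self, collection_name: str) -> int:
+        # Default fallback: count by listing; adapters with a native count endpoint override this
+        return len(await self.list_documents(collection_name))
+
+    async def describe_collection(self, collection_name: str) -> dict[str, object]:
+        return {
+            "name": collection_name,
+            "document_count": await self.count_documents(collection_name),
+        }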
+
+Ingestion flows: Good use of Prefect tasks/flows with retries, variables, and tagging. The Firecrawl→R2R specialization cleanly reuses common steps. The batch boundaries are clear and progress updates are threaded back to the UI.
+
+## DX / Testing
+
+Unit tests: In this export I don’t see tests. Add lightweight tests for:
+
+Vector extraction/format parsing in Vectorizer (covers multiple providers).
+
+Weaviate adapters: property building, name normalization, vector extraction.
+
+MetadataTagger sanitization rules.
+
+TUI: use Textual’s pilot to test that notifications and timers trigger and that screens transition without blocking (verifies the sleep fix).
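+
+A rough pilot-based sketch (requires pytest-asyncio; the constructor arguments and the storage_manager_stub fixture are assumptions):
+
+import pytest
+
+from ingest_pipeline.cli.tui.app import CollectionManagementApp
+
+@pytest.mark.asyncio
+async def test_app_mounts_without_blocking(storage_manager_stub) -> None:
+    app = CollectionManagementApp(storage_manager_stub, None, None, None)  # arg shape is an assumption
+    async with app.run_test() as pilot:
+        await pilot.pause()            # let on_mount workers and timers schedule
+        assert app.screen is not None  # dashboard mounted without hanging the event loop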
+
+Static analysis: You already have excellent typing. Add ruff + mypy --strict and a pre-commit config to keep it consistent across contributors.
\ No newline at end of file
diff --git a/ingest_pipeline/automations/__init__.py b/ingest_pipeline/automations/__init__.py
index b50ca00..83a0e3b 100644
--- a/ingest_pipeline/automations/__init__.py
+++ b/ingest_pipeline/automations/__init__.py
@@ -19,7 +19,6 @@ actions:
source: inferred
enabled: true
""",
-
"retry_failed": """
name: Retry Failed Ingestion Flows
description: Retries failed ingestion flows with original parameters
@@ -39,7 +38,6 @@ actions:
validate_first: false
enabled: true
""",
-
"resource_monitoring": """
name: Manage Work Pool Based on Resources
description: Pauses work pool when system resources are constrained
diff --git a/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc
index 7f1d8f8..04f5255 100644
Binary files a/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc b/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc
index c0fabe5..bb38b8b 100644
Binary files a/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc and b/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/main.py b/ingest_pipeline/cli/main.py
index 9022ad8..b52c4a0 100644
--- a/ingest_pipeline/cli/main.py
+++ b/ingest_pipeline/cli/main.py
@@ -139,6 +139,7 @@ def ingest(
# If we're already in an event loop (e.g., in Jupyter), use nest_asyncio
try:
import nest_asyncio
+
nest_asyncio.apply()
result = asyncio.run(run_with_progress())
except ImportError:
@@ -449,7 +450,9 @@ def search(
def blocks_command() -> None:
"""🧩 List and manage Prefect Blocks."""
console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
- console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks")
+ console.print(
+ "Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks"
+ )
console.print("Use 'prefect block ls' to list available blocks")
@@ -507,7 +510,9 @@ async def run_list_collections() -> None:
weaviate_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=settings.weaviate_endpoint,
- api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
+ api_key=SecretStr(settings.weaviate_api_key)
+ if settings.weaviate_api_key is not None
+ else None,
collection_name="default",
)
weaviate = WeaviateStorage(weaviate_config)
@@ -528,7 +533,9 @@ async def run_list_collections() -> None:
openwebui_config = StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint=settings.openwebui_endpoint,
- api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None,
+ api_key=SecretStr(settings.openwebui_api_key)
+ if settings.openwebui_api_key is not None
+ else None,
collection_name="default",
)
openwebui = OpenWebUIStorage(openwebui_config)
@@ -593,7 +600,9 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in
weaviate_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=settings.weaviate_endpoint,
- api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
+ api_key=SecretStr(settings.weaviate_api_key)
+ if settings.weaviate_api_key is not None
+ else None,
collection_name=collection or "default",
)
weaviate = WeaviateStorage(weaviate_config)
diff --git a/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc
index 5b9dae0..c1b4d9c 100644
Binary files a/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc
index fd540df..3cd000f 100644
Binary files a/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc
index d2d5850..65dfb85 100644
Binary files a/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc
index 940720f..6c53f50 100644
Binary files a/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/app.py b/ingest_pipeline/cli/tui/app.py
index 3e6651d..62d8d9d 100644
--- a/ingest_pipeline/cli/tui/app.py
+++ b/ingest_pipeline/cli/tui/app.py
@@ -31,7 +31,6 @@ else: # pragma: no cover - optional dependency fallback
R2RStorage = BaseStorage
-
class CollectionManagementApp(App[None]):
"""Enhanced modern Textual application with comprehensive keyboard navigation."""
diff --git a/ingest_pipeline/cli/tui/layouts.py b/ingest_pipeline/cli/tui/layouts.py
index c458b01..db26aaf 100644
--- a/ingest_pipeline/cli/tui/layouts.py
+++ b/ingest_pipeline/cli/tui/layouts.py
@@ -57,7 +57,9 @@ class ResponsiveGrid(Container):
markup: bool = True,
) -> None:
"""Initialize responsive grid."""
- super().__init__(*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
+ super().__init__(
+ *children, name=name, id=id, classes=classes, disabled=disabled, markup=markup
+ )
self._columns: int = columns
self._auto_fit: bool = auto_fit
self._compact: bool = compact
@@ -327,7 +329,9 @@ class CardLayout(ResponsiveGrid):
) -> None:
"""Initialize card layout with default settings for cards."""
# Default to auto-fit cards with minimum width
- super().__init__(auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
+ super().__init__(
+ auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup
+ )
class SplitPane(Container):
@@ -396,7 +400,9 @@ class SplitPane(Container):
if self._vertical:
cast(Widget, self).add_class("vertical")
- pane_classes = ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
+ pane_classes = (
+ ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
+ )
yield Container(self._left_content, classes=pane_classes[0])
yield Static("", classes="splitter")
diff --git a/ingest_pipeline/cli/tui/models.py b/ingest_pipeline/cli/tui/models.py
index 3aeaeda..25d2ac2 100644
--- a/ingest_pipeline/cli/tui/models.py
+++ b/ingest_pipeline/cli/tui/models.py
@@ -1,7 +1,7 @@
"""Data models and TypedDict definitions for the TUI."""
from enum import IntEnum
-from typing import Any, TypedDict
+from typing import TypedDict
class StorageCapabilities(IntEnum):
@@ -47,7 +47,7 @@ class ChunkInfo(TypedDict):
content: str
start_index: int
end_index: int
- metadata: dict[str, Any]
+ metadata: dict[str, object]
class EntityInfo(TypedDict):
@@ -57,7 +57,7 @@ class EntityInfo(TypedDict):
name: str
type: str
confidence: float
- metadata: dict[str, Any]
+ metadata: dict[str, object]
class FirecrawlOptions(TypedDict, total=False):
@@ -77,7 +77,7 @@ class FirecrawlOptions(TypedDict, total=False):
max_depth: int
# Extraction options
- extract_schema: dict[str, Any] | None
+ extract_schema: dict[str, object] | None
extract_prompt: str | None
diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc
index 6847502..91587a4 100644
Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc
index 80b8710..f85215b 100644
Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc
index d7c9466..2b98192 100644
Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc
index 5d61cd6..737396f 100644
Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc
index a64eb9a..efaa925 100644
Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc
index a92317a..e42797b 100644
Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc
index 3ca3a76..2d58b0b 100644
Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/screens/base.py b/ingest_pipeline/cli/tui/screens/base.py
index b6c34df..a8d9742 100644
--- a/ingest_pipeline/cli/tui/screens/base.py
+++ b/ingest_pipeline/cli/tui/screens/base.py
@@ -29,7 +29,7 @@ class BaseScreen(Screen[object]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
- **kwargs: object
+ **kwargs: object,
) -> None:
"""Initialize base screen."""
super().__init__(name=name, id=id, classes=classes)
@@ -55,7 +55,7 @@ class CRUDScreen(BaseScreen, Generic[T]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
- **kwargs: object
+ **kwargs: object,
) -> None:
"""Initialize CRUD screen."""
super().__init__(storage_manager, name=name, id=id, classes=classes)
@@ -311,7 +311,7 @@ class FormScreen(ModalScreen[T], Generic[T]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
- **kwargs: object
+ **kwargs: object,
) -> None:
"""Initialize form screen."""
super().__init__(name=name, id=id, classes=classes)
diff --git a/ingest_pipeline/cli/tui/screens/dashboard.py b/ingest_pipeline/cli/tui/screens/dashboard.py
index dacd32e..b94f8de 100644
--- a/ingest_pipeline/cli/tui/screens/dashboard.py
+++ b/ingest_pipeline/cli/tui/screens/dashboard.py
@@ -1,7 +1,6 @@
"""Main dashboard screen with collections overview."""
import logging
-from datetime import datetime
from typing import TYPE_CHECKING, Final
from textual import work
@@ -357,7 +356,6 @@ class CollectionOverviewScreen(Screen[None]):
self.is_loading = False
loading_indicator.display = False
-
async def update_collections_table(self) -> None:
"""Update the collections table with enhanced formatting."""
table = self.query_one("#collections_table", EnhancedDataTable)
diff --git a/ingest_pipeline/cli/tui/screens/dialogs.py b/ingest_pipeline/cli/tui/screens/dialogs.py
index 8bbcc64..b90f526 100644
--- a/ingest_pipeline/cli/tui/screens/dialogs.py
+++ b/ingest_pipeline/cli/tui/screens/dialogs.py
@@ -82,7 +82,10 @@ class ConfirmDeleteScreen(Screen[None]):
try:
if self.collection["type"] == "weaviate" and self.parent_screen.weaviate:
# Delete Weaviate collection
- if self.parent_screen.weaviate.client and self.parent_screen.weaviate.client.collections:
+ if (
+ self.parent_screen.weaviate.client
+ and self.parent_screen.weaviate.client.collections
+ ):
self.parent_screen.weaviate.client.collections.delete(self.collection["name"])
self.notify(
f"Deleted Weaviate collection: {self.collection['name']}",
@@ -100,7 +103,7 @@ class ConfirmDeleteScreen(Screen[None]):
return
# Check if the storage backend supports collection deletion
- if not hasattr(storage_backend, 'delete_collection'):
+ if not hasattr(storage_backend, "delete_collection"):
self.notify(
f"❌ Collection deletion not supported for {self.collection['type']} backend",
severity="error",
@@ -113,10 +116,13 @@ class ConfirmDeleteScreen(Screen[None]):
collection_name = str(self.collection["name"])
collection_type = str(self.collection["type"])
- self.notify(f"Deleting {collection_type} collection: {collection_name}...", severity="information")
+ self.notify(
+ f"Deleting {collection_type} collection: {collection_name}...",
+ severity="information",
+ )
# Use the standard delete_collection method for all backends
- if hasattr(storage_backend, 'delete_collection'):
+ if hasattr(storage_backend, "delete_collection"):
success = await storage_backend.delete_collection(collection_name)
else:
self.notify("❌ Backend does not support collection deletion", severity="error")
@@ -149,7 +155,6 @@ class ConfirmDeleteScreen(Screen[None]):
self.parent_screen.refresh_collections()
-
class ConfirmDocumentDeleteScreen(Screen[None]):
"""Screen for confirming document deletion."""
@@ -223,11 +228,11 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
try:
results: dict[str, bool] = {}
- if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage:
+ if hasattr(self.parent_screen, "storage") and self.parent_screen.storage:
# Delete documents via storage
# The storage should have delete_documents method for weaviate
storage = self.parent_screen.storage
- if hasattr(storage, 'delete_documents'):
+ if hasattr(storage, "delete_documents"):
results = await storage.delete_documents(
self.doc_ids,
collection_name=self.collection["name"],
@@ -280,7 +285,9 @@ class LogViewerScreen(ModalScreen[None]):
yield Header(show_clock=True)
yield Container(
Static("📜 Live Application Logs", classes="title"),
- Static("Logs update in real time. Press S to reveal the log file path.", classes="subtitle"),
+ Static(
+ "Logs update in real time. Press S to reveal the log file path.", classes="subtitle"
+ ),
RichLog(id="log_stream", classes="log-stream", wrap=True, highlight=False),
Static("", id="log_file_path", classes="subtitle"),
classes="main_container log-viewer-container",
@@ -291,13 +298,13 @@ class LogViewerScreen(ModalScreen[None]):
"""Attach this viewer to the parent application once mounted."""
self._log_widget = self.query_one(RichLog)
- if hasattr(self.app, 'attach_log_viewer'):
+ if hasattr(self.app, "attach_log_viewer"):
self.app.attach_log_viewer(self) # type: ignore[arg-type]
def on_unmount(self) -> None:
"""Detach from the parent application when closed."""
- if hasattr(self.app, 'detach_log_viewer'):
+ if hasattr(self.app, "detach_log_viewer"):
self.app.detach_log_viewer(self) # type: ignore[arg-type]
def _get_log_widget(self) -> RichLog:
@@ -340,4 +347,6 @@ class LogViewerScreen(ModalScreen[None]):
if self._log_file is None:
self.notify("File logging is disabled for this session.", severity="warning")
else:
- self.notify(f"Log file available at: {self._log_file}", severity="information", markup=False)
+ self.notify(
+ f"Log file available at: {self._log_file}", severity="information", markup=False
+ )
diff --git a/ingest_pipeline/cli/tui/screens/documents.py b/ingest_pipeline/cli/tui/screens/documents.py
index 4d04e3c..2d24175 100644
--- a/ingest_pipeline/cli/tui/screens/documents.py
+++ b/ingest_pipeline/cli/tui/screens/documents.py
@@ -116,7 +116,9 @@ class DocumentManagementScreen(Screen[None]):
content_type=str(doc.get("content_type", "text/plain")),
content_preview=str(doc.get("content_preview", "")),
word_count=(
- lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0
+ lambda wc_val: int(wc_val)
+ if isinstance(wc_val, (int, str)) and str(wc_val).isdigit()
+ else 0
)(doc.get("word_count", 0)),
timestamp=str(doc.get("timestamp", "")),
)
@@ -126,7 +128,7 @@ class DocumentManagementScreen(Screen[None]):
# For storage backends that don't support document listing, show a message
self.notify(
f"Document listing not supported for {self.storage.__class__.__name__}",
- severity="information"
+ severity="information",
)
self.documents = []
@@ -330,7 +332,9 @@ class DocumentManagementScreen(Screen[None]):
"""View the content of the currently selected document."""
if doc := self.get_current_document():
if self.storage:
- self.app.push_screen(DocumentContentModal(doc, self.storage, self.collection["name"]))
+ self.app.push_screen(
+ DocumentContentModal(doc, self.storage, self.collection["name"])
+ )
else:
self.notify("No storage backend available", severity="error")
else:
@@ -381,13 +385,13 @@ class DocumentContentModal(ModalScreen[None]):
yield Container(
Static(
f"📄 Document: {self.document['title'][:60]}{'...' if len(self.document['title']) > 60 else ''}",
- classes="modal-header"
+ classes="modal-header",
),
ScrollableContainer(
Markdown("Loading document content...", id="document_content"),
LoadingIndicator(id="content_loading"),
- classes="modal-content"
- )
+ classes="modal-content",
+ ),
)
async def on_mount(self) -> None:
@@ -398,30 +402,29 @@ class DocumentContentModal(ModalScreen[None]):
try:
# Get full document content
doc_content = await self.storage.retrieve(
- self.document["id"],
- collection_name=self.collection_name
+ self.document["id"], collection_name=self.collection_name
)
# Format content for display
if isinstance(doc_content, str):
- formatted_content = f"""# {self.document['title']}
+ formatted_content = f"""# {self.document["title"]}
-**Source:** {self.document.get('source_url', 'N/A')}
-**Type:** {self.document.get('content_type', 'text/plain')}
-**Words:** {self.document.get('word_count', 0):,}
-**Timestamp:** {self.document.get('timestamp', 'N/A')}
+**Source:** {self.document.get("source_url", "N/A")}
+**Type:** {self.document.get("content_type", "text/plain")}
+**Words:** {self.document.get("word_count", 0):,}
+**Timestamp:** {self.document.get("timestamp", "N/A")}
---
{doc_content}
"""
else:
- formatted_content = f"""# {self.document['title']}
+ formatted_content = f"""# {self.document["title"]}
-**Source:** {self.document.get('source_url', 'N/A')}
-**Type:** {self.document.get('content_type', 'text/plain')}
-**Words:** {self.document.get('word_count', 0):,}
-**Timestamp:** {self.document.get('timestamp', 'N/A')}
+**Source:** {self.document.get("source_url", "N/A")}
+**Type:** {self.document.get("content_type", "text/plain")}
+**Words:** {self.document.get("word_count", 0):,}
+**Timestamp:** {self.document.get("timestamp", "N/A")}
---
@@ -431,6 +434,8 @@ class DocumentContentModal(ModalScreen[None]):
content_widget.update(formatted_content)
except Exception as e:
- content_widget.update(f"# Error Loading Document\n\nFailed to load document content: {e}")
+ content_widget.update(
+ f"# Error Loading Document\n\nFailed to load document content: {e}"
+ )
finally:
loading.display = False
diff --git a/ingest_pipeline/cli/tui/screens/ingestion.py b/ingest_pipeline/cli/tui/screens/ingestion.py
index 5608228..6149a19 100644
--- a/ingest_pipeline/cli/tui/screens/ingestion.py
+++ b/ingest_pipeline/cli/tui/screens/ingestion.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
from textual import work
from textual.app import ComposeResult
@@ -105,12 +105,23 @@ class IngestionScreen(ModalScreen[None]):
Label("📋 Source Type (Press 1/2/3):", classes="input-label"),
Horizontal(
Button("🌐 Web (1)", id="web_btn", variant="primary", classes="type-button"),
- Button("📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"),
- Button("📖 Documentation (3)", id="docs_btn", variant="default", classes="type-button"),
+ Button(
+ "📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"
+ ),
+ Button(
+ "📖 Documentation (3)",
+ id="docs_btn",
+ variant="default",
+ classes="type-button",
+ ),
classes="type_buttons",
),
Rule(line_style="dashed"),
- Label(f"🗄️ Target Storages ({len(self.available_backends)} available):", classes="input-label", id="backend_label"),
+ Label(
+ f"🗄️ Target Storages ({len(self.available_backends)} available):",
+ classes="input-label",
+ id="backend_label",
+ ),
Container(
*self._create_backend_checkbox_widgets(),
classes="backend-selection",
@@ -139,7 +150,6 @@ class IngestionScreen(ModalScreen[None]):
yield LoadingIndicator(id="loading", classes="pulse")
-
def _create_backend_checkbox_widgets(self) -> list[Checkbox]:
"""Create checkbox widgets for each available backend."""
checkboxes: list[Checkbox] = [
@@ -219,19 +229,26 @@ class IngestionScreen(ModalScreen[None]):
collection_name = collection_input.value.strip()
if not source_url:
- self.notify("🔍 Please enter a source URL", severity="error")
+ cast("CollectionManagementApp", self.app).safe_notify(
+ "🔍 Please enter a source URL", severity="error"
+ )
url_input.focus()
return
# Validate URL format
if not self._validate_url(source_url):
- self.notify("❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path", severity="error")
+ cast("CollectionManagementApp", self.app).safe_notify(
+ "❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path",
+ severity="error",
+ )
url_input.focus()
return
resolved_backends = self._resolve_selected_backends()
if not resolved_backends:
- self.notify("⚠️ Select at least one storage backend", severity="warning")
+ cast("CollectionManagementApp", self.app).safe_notify(
+ "⚠️ Select at least one storage backend", severity="warning"
+ )
return
self.selected_backends = resolved_backends
@@ -246,18 +263,16 @@ class IngestionScreen(ModalScreen[None]):
url_lower = url.lower()
# Allow HTTP/HTTPS URLs
- if url_lower.startswith(('http://', 'https://')):
+ if url_lower.startswith(("http://", "https://")):
# Additional validation could be added here
return True
# Allow file:// URLs for repository paths
- if url_lower.startswith('file://'):
+ if url_lower.startswith("file://"):
return True
# Allow local file paths that look like repositories
- return '/' in url and not url_lower.startswith(
- ('javascript:', 'data:', 'vbscript:')
- )
+ return "/" in url and not url_lower.startswith(("javascript:", "data:", "vbscript:"))
def _resolve_selected_backends(self) -> list[StorageBackend]:
selected: list[StorageBackend] = []
@@ -288,13 +303,14 @@ class IngestionScreen(ModalScreen[None]):
if not self.selected_backends:
status_widget.update("📋 Selected: None")
elif len(self.selected_backends) == 1:
- backend_name = BACKEND_LABELS.get(self.selected_backends[0], self.selected_backends[0].value)
+ backend_name = BACKEND_LABELS.get(
+ self.selected_backends[0], self.selected_backends[0].value
+ )
status_widget.update(f"📋 Selected: {backend_name}")
else:
# Multiple backends selected
backend_names = [
- BACKEND_LABELS.get(backend, backend.value)
- for backend in self.selected_backends
+ BACKEND_LABELS.get(backend, backend.value) for backend in self.selected_backends
]
if len(backend_names) <= 3:
# Show all names if 3 or fewer
@@ -319,7 +335,10 @@ class IngestionScreen(ModalScreen[None]):
for backend in BACKEND_ORDER:
if backend not in self.available_backends:
continue
- if backend.value.lower() == backend_name_lower or backend.name.lower() == backend_name_lower:
+ if (
+ backend.value.lower() == backend_name_lower
+ or backend.name.lower() == backend_name_lower
+ ):
matched_backends.append(backend)
break
return matched_backends or [self.available_backends[0]]
@@ -351,6 +370,7 @@ class IngestionScreen(ModalScreen[None]):
loading.display = False
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(_update)
def progress_reporter(percent: int, message: str) -> None:
@@ -362,6 +382,7 @@ class IngestionScreen(ModalScreen[None]):
progress_text.update(message)
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(_update_progress)
try:
@@ -377,13 +398,18 @@ class IngestionScreen(ModalScreen[None]):
for i, backend in enumerate(backends):
progress_percent = 20 + (60 * i) // len(backends)
- progress_reporter(progress_percent, f"🔗 Processing {backend.value} backend ({i+1}/{len(backends)})...")
+ progress_reporter(
+ progress_percent,
+ f"🔗 Processing {backend.value} backend ({i + 1}/{len(backends)})...",
+ )
try:
# Run the Prefect flow for this backend using asyncio.run with timeout
import asyncio
- async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult:
+ async def run_flow_with_timeout(
+ current_backend: StorageBackend = backend,
+ ) -> IngestionResult:
return await asyncio.wait_for(
create_ingestion_flow(
source_url=source_url,
@@ -392,7 +418,7 @@ class IngestionScreen(ModalScreen[None]):
collection_name=final_collection_name,
progress_callback=progress_reporter,
),
- timeout=600.0 # 10 minute timeout
+ timeout=600.0, # 10 minute timeout
)
result = asyncio.run(run_flow_with_timeout())
@@ -401,25 +427,33 @@ class IngestionScreen(ModalScreen[None]):
total_failed += result.documents_failed
if result.error_messages:
- flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages])
+ flow_errors.extend(
+ [f"{backend.value}: {err}" for err in result.error_messages]
+ )
except TimeoutError:
error_msg = f"{backend.value}: Timeout after 10 minutes"
flow_errors.append(error_msg)
progress_reporter(0, f"❌ {backend.value} timed out")
- def notify_timeout(msg: str = f"⏰ {backend.value} flow timed out after 10 minutes") -> None:
+
+ def notify_timeout(
+ msg: str = f"⏰ {backend.value} flow timed out after 10 minutes",
+ ) -> None:
try:
self.notify(msg, severity="error", markup=False)
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(notify_timeout)
except Exception as exc:
flow_errors.append(f"{backend.value}: {exc}")
+
def notify_error(msg: str = f"❌ {backend.value} flow failed: {exc}") -> None:
try:
self.notify(msg, severity="error", markup=False)
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
successful = total_successful
@@ -445,20 +479,37 @@ class IngestionScreen(ModalScreen[None]):
cast("CollectionManagementApp", self.app).call_from_thread(notify_results)
- import time
- time.sleep(2)
- cast("CollectionManagementApp", self.app).pop_screen()
+ def _pop() -> None:
+ try:
+ self.app.pop_screen()
+ except Exception:
+ pass
+
+ # Schedule screen pop via timer instead of blocking
+ cast("CollectionManagementApp", self.app).call_from_thread(
+ lambda: self.app.set_timer(2.0, _pop)
+ )
except Exception as exc: # pragma: no cover - defensive
progress_reporter(0, f"❌ Prefect flows error: {exc}")
+
def notify_error(msg: str = f"❌ Prefect flows failed: {exc}") -> None:
try:
self.notify(msg, severity="error")
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
- import time
- time.sleep(2)
+
+ def _pop_on_error() -> None:
+ try:
+ self.app.pop_screen()
+ except Exception:
+ pass
+
+ # Schedule screen pop via timer for error case too
+ cast("CollectionManagementApp", self.app).call_from_thread(
+ lambda: self.app.set_timer(2.0, _pop_on_error)
+ )
finally:
update_ui("hide_loading")
-
diff --git a/ingest_pipeline/cli/tui/screens/search.py b/ingest_pipeline/cli/tui/screens/search.py
index 3b8427f..916dd0f 100644
--- a/ingest_pipeline/cli/tui/screens/search.py
+++ b/ingest_pipeline/cli/tui/screens/search.py
@@ -43,12 +43,24 @@ class SearchScreen(Screen[None]):
@override
def compose(self) -> ComposeResult:
yield Header()
+ # Check if search is supported for this backend
+ backends = self.collection["backend"]
+ if isinstance(backends, str):
+ backends = [backends]
+ search_supported = "weaviate" in backends
+ search_indicator = "✅ Search supported" if search_supported else "❌ Search not supported"
+
yield Container(
Static(
- f"🔍 Search in: {self.collection['name']} ({self.collection['backend']})",
+ f"🔍 Search in: {self.collection['name']} ({', '.join(backends)}) - {search_indicator}",
classes="title",
),
- Static("Press / or Ctrl+F to focus search, Enter to search", classes="subtitle"),
+ Static(
+ "Press / or Ctrl+F to focus search, Enter to search"
+ if search_supported
+ else "Search functionality not available for this backend",
+ classes="subtitle",
+ ),
Input(placeholder="Enter search query... (press Enter to search)", id="search_input"),
Button("🔍 Search", id="search_btn", variant="primary"),
Button("🗑️ Clear Results", id="clear_btn", variant="default"),
@@ -127,7 +139,9 @@ class SearchScreen(Screen[None]):
finally:
loading.display = False
- def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None:
+ def _setup_search_ui(
+ self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str
+ ) -> None:
"""Setup the search UI elements."""
loading.display = True
status.update(f"🔍 Searching for '{query}'...")
@@ -139,10 +153,18 @@ class SearchScreen(Screen[None]):
if self.collection["type"] == "weaviate" and self.weaviate:
return await self.search_weaviate(query)
elif self.collection["type"] == "openwebui" and self.openwebui:
- return await self.search_openwebui(query)
+ # OpenWebUI search is not yet implemented
+ self.notify("Search not supported for OpenWebUI collections", severity="warning")
+ return []
+ elif self.collection["type"] == "r2r":
+ # R2R search would go here when implemented
+ self.notify("Search not supported for R2R collections", severity="warning")
+ return []
return []
- def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None:
+ def _populate_results_table(
+ self, table: EnhancedDataTable, results: list[dict[str, str | float]]
+ ) -> None:
"""Populate the results table with search results."""
for result in results:
row_data = self._format_result_row(result)
@@ -193,7 +215,11 @@ class SearchScreen(Screen[None]):
return str(score)
def _update_search_status(
- self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable
+ self,
+ status: Static,
+ query: str,
+ results: list[dict[str, str | float]],
+ table: EnhancedDataTable,
) -> None:
"""Update search status and notifications based on results."""
if not results:
diff --git a/ingest_pipeline/cli/tui/styles.py b/ingest_pipeline/cli/tui/styles.py
index 0095d1e..ac6b237 100644
--- a/ingest_pipeline/cli/tui/styles.py
+++ b/ingest_pipeline/cli/tui/styles.py
@@ -13,6 +13,8 @@ TextualApp = App[object]
class AppProtocol(Protocol):
"""Protocol for apps that support CSS and refresh."""
+ CSS: str
+
def refresh(self) -> None:
"""Refresh the app."""
...
@@ -1122,16 +1124,16 @@ def get_css_for_theme(theme_type: ThemeType) -> str:
def apply_theme_to_app(app: TextualApp | AppProtocol, theme_type: ThemeType) -> None:
"""Apply a theme to a Textual app instance."""
try:
- css = set_theme(theme_type)
- # Set CSS using the standard Textual approach
- if hasattr(app, "CSS") or isinstance(app, App):
- setattr(app, "CSS", css)
- # Refresh the app to apply new CSS
- if hasattr(app, "refresh"):
- app.refresh()
+ # Note: CSS class variable cannot be changed at runtime
+ # This function would need to be called during app initialization
+ # or implement a different approach for dynamic theming
+ _ = set_theme(theme_type) # Keep for future implementation
+ if hasattr(app, "refresh"):
+ app.refresh()
except Exception as e:
# Graceful fallback - log but don't crash the UI
import logging
+
logging.debug(f"Failed to apply theme to app: {e}")
@@ -1185,11 +1187,11 @@ class ThemeSwitcher:
# Responsive breakpoints for dynamic layout adaptation
RESPONSIVE_BREAKPOINTS = {
- "xs": 40, # Extra small terminals
- "sm": 60, # Small terminals
- "md": 100, # Medium terminals
- "lg": 140, # Large terminals
- "xl": 180, # Extra large terminals
+ "xs": 40, # Extra small terminals
+ "sm": 60, # Small terminals
+ "md": 100, # Medium terminals
+ "lg": 140, # Large terminals
+ "xl": 180, # Extra large terminals
}
diff --git a/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc
index 970fef0..bdfca59 100644
Binary files a/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc b/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc
index cd7f620..993c1c4 100644
Binary files a/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc and b/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc b/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc
index e62a89e..8bc2d6f 100644
Binary files a/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc and b/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc differ
diff --git a/ingest_pipeline/cli/tui/utils/runners.py b/ingest_pipeline/cli/tui/utils/runners.py
index fc629ea..5e6dc47 100644
--- a/ingest_pipeline/cli/tui/utils/runners.py
+++ b/ingest_pipeline/cli/tui/utils/runners.py
@@ -10,8 +10,9 @@ from pathlib import Path
from queue import Queue
from typing import NamedTuple
+import platformdirs
+
from ....config import configure_prefect, get_settings
-from ....core.models import StorageBackend
from .storage_manager import StorageManager
@@ -53,21 +54,36 @@ def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext:
log_file: Path | None = None
try:
+ # Try current directory first for development
log_dir = Path.cwd() / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / "tui.log"
- file_handler = RotatingFileHandler(
- log_file,
- maxBytes=2_000_000,
- backupCount=5,
- encoding="utf-8",
- )
- file_handler.setLevel(resolved_level)
- file_handler.setFormatter(formatter)
- root_logger.addHandler(file_handler)
- except OSError as exc: # pragma: no cover - filesystem specific
- fallback = logging.getLogger(__name__)
- fallback.warning("Failed to configure file logging for TUI: %s", exc)
+ except OSError:
+ # Fall back to user log directory
+ try:
+ log_dir = Path(platformdirs.user_log_dir("ingest-pipeline", "ingest-pipeline"))
+ log_dir.mkdir(parents=True, exist_ok=True)
+ log_file = log_dir / "tui.log"
+ except OSError as exc:
+ fallback = logging.getLogger(__name__)
+ fallback.warning("Failed to create log directory, file logging disabled: %s", exc)
+ log_file = None
+
+ if log_file:
+ try:
+ file_handler = RotatingFileHandler(
+ log_file,
+ maxBytes=2_000_000,
+ backupCount=5,
+ encoding="utf-8",
+ )
+ file_handler.setLevel(resolved_level)
+ file_handler.setFormatter(formatter)
+ root_logger.addHandler(file_handler)
+ except OSError as exc:
+ fallback = logging.getLogger(__name__)
+ fallback.warning("Failed to configure file logging for TUI: %s", exc)
+ log_file = None
_logging_context = _TuiLoggingContext(log_queue, formatter, log_file)
return _logging_context
@@ -93,6 +109,7 @@ async def run_textual_tui() -> None:
# Import here to avoid circular import
from ..app import CollectionManagementApp
+
app = CollectionManagementApp(
storage_manager,
None, # weaviate - will be available after initialization
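
The logging setup above now tries `./logs` first and falls back to the per-user log directory. The same fallback can be expressed as a small helper; a sketch assuming the `ingest-pipeline` app name used in the diff (`resolve_log_dir` is a hypothetical name):

```python
from pathlib import Path

import platformdirs


def resolve_log_dir(app_name: str = "ingest-pipeline") -> Path | None:
    """Return a writable log directory, preferring ./logs for development."""
    candidates = (
        Path.cwd() / "logs",                                  # development checkout
        Path(platformdirs.user_log_dir(app_name, app_name)),  # per-user fallback
    )
    for candidate in candidates:
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            return candidate
        except OSError:
            continue
    return None  # file logging disabled; queue/console handlers still work
```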
diff --git a/ingest_pipeline/cli/tui/utils/storage_manager.py b/ingest_pipeline/cli/tui/utils/storage_manager.py
index 42313e2..031c9f5 100644
--- a/ingest_pipeline/cli/tui/utils/storage_manager.py
+++ b/ingest_pipeline/cli/tui/utils/storage_manager.py
@@ -1,6 +1,5 @@
"""Storage management utilities for TUI applications."""
-
from __future__ import annotations
import asyncio
@@ -11,12 +10,11 @@ from pydantic import SecretStr
from ....core.exceptions import StorageError
from ....core.models import Document, StorageBackend, StorageConfig
-from ..models import CollectionInfo, StorageCapabilities
-
from ....storage.base import BaseStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.weaviate import WeaviateStorage
+from ..models import CollectionInfo, StorageCapabilities
if TYPE_CHECKING:
from ....config.settings import Settings
@@ -39,7 +37,6 @@ class StorageBackendProtocol(Protocol):
async def close(self) -> None: ...
-
class MultiStorageAdapter(BaseStorage):
"""Mirror writes to multiple storage backends."""
@@ -70,7 +67,10 @@ class MultiStorageAdapter(BaseStorage):
# Replicate to secondary backends concurrently
if len(self._storages) > 1:
- async def replicate_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
+
+ async def replicate_to_backend(
+ storage: BaseStorage,
+ ) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.store(document, collection_name=collection_name)
return storage, True, None
@@ -106,11 +106,16 @@ class MultiStorageAdapter(BaseStorage):
self, documents: list[Document], *, collection_name: str | None = None
) -> list[str]:
# Store in primary backend first
- primary_ids: list[str] = await self._primary.store_batch(documents, collection_name=collection_name)
+ primary_ids: list[str] = await self._primary.store_batch(
+ documents, collection_name=collection_name
+ )
# Replicate to secondary backends concurrently
if len(self._storages) > 1:
- async def replicate_batch_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
+
+ async def replicate_batch_to_backend(
+ storage: BaseStorage,
+ ) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.store_batch(documents, collection_name=collection_name)
return storage, True, None
@@ -135,7 +140,9 @@ class MultiStorageAdapter(BaseStorage):
if failures:
backends = ", ".join(failures)
- primary_error = errors[0] if errors else Exception("Unknown batch replication error")
+ primary_error = (
+ errors[0] if errors else Exception("Unknown batch replication error")
+ )
raise StorageError(
f"Batch stored in primary backend but replication failed for: {backends}"
) from primary_error
@@ -144,11 +151,16 @@ class MultiStorageAdapter(BaseStorage):
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
# Delete from primary backend first
- primary_deleted: bool = await self._primary.delete(document_id, collection_name=collection_name)
+ primary_deleted: bool = await self._primary.delete(
+ document_id, collection_name=collection_name
+ )
# Delete from secondary backends concurrently
if len(self._storages) > 1:
- async def delete_from_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
+
+ async def delete_from_backend(
+ storage: BaseStorage,
+ ) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.delete(document_id, collection_name=collection_name)
return storage, True, None
@@ -222,7 +234,6 @@ class MultiStorageAdapter(BaseStorage):
return class_name
-
class StorageManager:
"""Centralized manager for all storage backend operations."""
@@ -237,7 +248,9 @@ class StorageManager:
"""Initialize all available storage backends with timeout protection."""
results: dict[StorageBackend, bool] = {}
- async def init_backend(backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]) -> bool:
+ async def init_backend(
+ backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]
+ ) -> bool:
"""Initialize a single backend with timeout."""
try:
storage = storage_class(config)
@@ -261,10 +274,17 @@ class StorageManager:
config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=self.settings.weaviate_endpoint,
- api_key=SecretStr(self.settings.weaviate_api_key) if self.settings.weaviate_api_key else None,
+ api_key=SecretStr(self.settings.weaviate_api_key)
+ if self.settings.weaviate_api_key
+ else None,
collection_name="default",
)
- tasks.append((StorageBackend.WEAVIATE, init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage)))
+ tasks.append(
+ (
+ StorageBackend.WEAVIATE,
+ init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage),
+ )
+ )
else:
results[StorageBackend.WEAVIATE] = False
@@ -273,10 +293,17 @@ class StorageManager:
config = StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint=self.settings.openwebui_endpoint,
- api_key=SecretStr(self.settings.openwebui_api_key) if self.settings.openwebui_api_key else None,
+ api_key=SecretStr(self.settings.openwebui_api_key)
+ if self.settings.openwebui_api_key
+ else None,
collection_name="default",
)
- tasks.append((StorageBackend.OPEN_WEBUI, init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage)))
+ tasks.append(
+ (
+ StorageBackend.OPEN_WEBUI,
+ init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage),
+ )
+ )
else:
results[StorageBackend.OPEN_WEBUI] = False
@@ -295,7 +322,9 @@ class StorageManager:
# Execute initialization tasks concurrently
if tasks:
backend_types, task_coroutines = zip(*tasks, strict=False)
- task_results: Sequence[bool | BaseException] = await asyncio.gather(*task_coroutines, return_exceptions=True)
+ task_results: Sequence[bool | BaseException] = await asyncio.gather(
+ *task_coroutines, return_exceptions=True
+ )
for backend_type, task_result in zip(backend_types, task_results, strict=False):
results[backend_type] = task_result if isinstance(task_result, bool) else False
@@ -312,7 +341,9 @@ class StorageManager:
storages: list[BaseStorage] = []
seen: set[StorageBackend] = set()
for backend in backends:
- backend_enum = backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
+ backend_enum = (
+ backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
+ )
if backend_enum in seen:
continue
seen.add(backend_enum)
@@ -350,12 +381,18 @@ class StorageManager:
except StorageError as e:
# Storage-specific errors - log and use 0 count
import logging
- logging.warning(f"Failed to get count for {collection_name} on {backend_type.value}: {e}")
+
+ logging.warning(
+ f"Failed to get count for {collection_name} on {backend_type.value}: {e}"
+ )
count = 0
except Exception as e:
# Unexpected errors - log and skip this collection from this backend
import logging
- logging.warning(f"Unexpected error counting {collection_name} on {backend_type.value}: {e}")
+
+ logging.warning(
+ f"Unexpected error counting {collection_name} on {backend_type.value}: {e}"
+ )
continue
size_mb = count * 0.01 # Rough estimate: 10KB per document
@@ -446,7 +483,9 @@ class StorageManager:
storage = self.backends.get(StorageBackend.R2R)
return storage if isinstance(storage, R2RStorage) else None
- async def get_backend_status(self) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
+ async def get_backend_status(
+ self,
+ ) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
"""Get comprehensive status for all backends."""
status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}
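
Each `MultiStorageAdapter` method above repeats the same shape: write to the primary, fan out to the secondaries with `asyncio.gather`, and collapse any failures into one `StorageError`. A condensed sketch of that pattern against the `BaseStorage` interface shown in this diff (the standalone `replicate` helper is illustrative, not part of the adapter):

```python
import asyncio
from collections.abc import Awaitable, Callable

from ingest_pipeline.core.exceptions import StorageError
from ingest_pipeline.storage.base import BaseStorage


async def replicate(
    primary: BaseStorage,
    secondaries: list[BaseStorage],
    op: Callable[[BaseStorage], Awaitable[object]],
) -> None:
    """Apply op to the primary backend, then mirror it to every secondary."""
    await op(primary)  # the primary must succeed before replication starts
    results = await asyncio.gather(*(op(s) for s in secondaries), return_exceptions=True)
    failures = [
        type(storage).__name__
        for storage, result in zip(secondaries, results, strict=False)
        if isinstance(result, Exception)
    ]
    if failures:
        raise StorageError(
            f"Stored in primary backend but replication failed for: {', '.join(failures)}"
        )


# Usage: await replicate(primary, others, lambda s: s.store(doc, collection_name="docs"))
```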
diff --git a/ingest_pipeline/cli/tui/widgets/firecrawl_config.py b/ingest_pipeline/cli/tui/widgets/firecrawl_config.py
index f80aa0d..a1cc9fe 100644
--- a/ingest_pipeline/cli/tui/widgets/firecrawl_config.py
+++ b/ingest_pipeline/cli/tui/widgets/firecrawl_config.py
@@ -158,9 +158,7 @@ class ScrapeOptionsForm(Widget):
formats.append("screenshot")
options: dict[str, object] = {
"formats": formats,
- "only_main_content": self.query_one(
- "#only_main_content", Switch
- ).value,
+ "only_main_content": self.query_one("#only_main_content", Switch).value,
}
include_tags_input = self.query_one("#include_tags", Input).value
if include_tags_input.strip():
@@ -195,11 +193,15 @@ class ScrapeOptionsForm(Widget):
if include_tags := options.get("include_tags", []):
include_list = include_tags if isinstance(include_tags, list) else []
- self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list)
+ self.query_one("#include_tags", Input).value = ", ".join(
+ str(tag) for tag in include_list
+ )
if exclude_tags := options.get("exclude_tags", []):
exclude_list = exclude_tags if isinstance(exclude_tags, list) else []
- self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list)
+ self.query_one("#exclude_tags", Input).value = ", ".join(
+ str(tag) for tag in exclude_list
+ )
# Set performance
wait_for = options.get("wait_for")
@@ -503,7 +505,9 @@ class ExtractOptionsForm(Widget):
"content": "string",
"tags": ["string"]
}"""
- prompt_widget.text = "Extract article title, author, publication date, main content, and associated tags"
+ prompt_widget.text = (
+ "Extract article title, author, publication date, main content, and associated tags"
+ )
elif event.button.id == "preset_product":
schema_widget.text = """{
@@ -513,7 +517,9 @@ class ExtractOptionsForm(Widget):
"category": "string",
"availability": "string"
}"""
- prompt_widget.text = "Extract product name, price, description, category, and availability status"
+ prompt_widget.text = (
+ "Extract product name, price, description, category, and availability status"
+ )
elif event.button.id == "preset_contact":
schema_widget.text = """{
@@ -523,7 +529,9 @@ class ExtractOptionsForm(Widget):
"company": "string",
"position": "string"
}"""
- prompt_widget.text = "Extract contact information including name, email, phone, company, and position"
+ prompt_widget.text = (
+ "Extract contact information including name, email, phone, company, and position"
+ )
elif event.button.id == "preset_data":
schema_widget.text = """{
diff --git a/ingest_pipeline/cli/tui/widgets/indicators.py b/ingest_pipeline/cli/tui/widgets/indicators.py
index 268b9e3..7209569 100644
--- a/ingest_pipeline/cli/tui/widgets/indicators.py
+++ b/ingest_pipeline/cli/tui/widgets/indicators.py
@@ -26,8 +26,12 @@ class StatusIndicator(Static):
status_lower = status.lower()
- if (status_lower in {"active", "online", "connected", "✓ active"} or
- status_lower.endswith("active") or "✓" in status_lower and "active" in status_lower):
+ if (
+ status_lower in {"active", "online", "connected", "✓ active"}
+ or status_lower.endswith("active")
+ or ("✓" in status_lower and "active" in status_lower)
+ ):
self.add_class("status-active")
self.add_class("glow")
self.update(f"🟢 {status}")
diff --git a/ingest_pipeline/cli/tui/widgets/r2r_widgets.py b/ingest_pipeline/cli/tui/widgets/r2r_widgets.py
index b571302..f865936 100644
--- a/ingest_pipeline/cli/tui/widgets/r2r_widgets.py
+++ b/ingest_pipeline/cli/tui/widgets/r2r_widgets.py
@@ -99,10 +99,17 @@ class ChunkViewer(Widget):
"id": str(chunk_data.get("id", "")),
"document_id": self.document_id,
"content": str(chunk_data.get("text", "")),
- "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)),
- "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)),
+ "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(
+ chunk_data.get("start_index", 0)
+ ),
+ "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(
+ chunk_data.get("end_index", 0)
+ ),
"metadata": (
- dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {}
+ dict(metadata_val)
+ if (metadata_val := chunk_data.get("metadata"))
+ and isinstance(metadata_val, dict)
+ else {}
),
}
self.chunks.append(chunk_info)
@@ -275,10 +282,10 @@ class EntityGraph(Widget):
"""Show detailed information about an entity."""
details_widget = self.query_one("#entity_details", Static)
- details_text = f"""**Entity:** {entity['name']}
-**Type:** {entity['type']}
-**Confidence:** {entity['confidence']:.2%}
-**ID:** {entity['id']}
+ details_text = f"""**Entity:** {entity["name"]}
+**Type:** {entity["type"]}
+**Confidence:** {entity["confidence"]:.2%}
+**ID:** {entity["id"]}
**Metadata:**
"""
diff --git a/ingest_pipeline/config/__init__.py b/ingest_pipeline/config/__init__.py
index 782f56a..ccd473e 100644
--- a/ingest_pipeline/config/__init__.py
+++ b/ingest_pipeline/config/__init__.py
@@ -26,9 +26,15 @@ def _setup_prefect_settings() -> tuple[object, object, object]:
if registry is not None:
Setting = getattr(ps, "Setting", None)
if Setting is not None:
- api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None)
- api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None)
- work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None)
+ api_key = registry.get("PREFECT_API_KEY") or Setting(
+ "PREFECT_API_KEY", type_=str, default=None
+ )
+ api_url = registry.get("PREFECT_API_URL") or Setting(
+ "PREFECT_API_URL", type_=str, default=None
+ )
+ work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting(
+ "PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None
+ )
return api_key, api_url, work_pool
except ImportError:
@@ -37,6 +43,7 @@ def _setup_prefect_settings() -> tuple[object, object, object]:
# Ultimate fallback
return None, None, None
+
PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()
# Import after Prefect settings setup to avoid circular dependencies
@@ -53,20 +60,30 @@ def configure_prefect(settings: Settings) -> None:
overrides: dict[Setting, str] = {}
- if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting):
+ if (
+ settings.prefect_api_url is not None
+ and PREFECT_API_URL is not None
+ and isinstance(PREFECT_API_URL, Setting)
+ ):
overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
- if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting):
+ if (
+ settings.prefect_api_key
+ and PREFECT_API_KEY is not None
+ and isinstance(PREFECT_API_KEY, Setting)
+ ):
overrides[PREFECT_API_KEY] = settings.prefect_api_key
- if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting):
+ if (
+ settings.prefect_work_pool
+ and PREFECT_DEFAULT_WORK_POOL_NAME is not None
+ and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting)
+ ):
overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool
if not overrides:
return
filtered_overrides = {
- setting: value
- for setting, value in overrides.items()
- if setting.value() != value
+ setting: value for setting, value in overrides.items() if setting.value() != value
}
if not filtered_overrides:
diff --git a/ingest_pipeline/config/settings.py b/ingest_pipeline/config/settings.py
index dd3fdeb..f27877e 100644
--- a/ingest_pipeline/config/settings.py
+++ b/ingest_pipeline/config/settings.py
@@ -139,10 +139,11 @@ class Settings(BaseSettings):
key_name, key_value = required_keys[backend]
if not key_value:
import warnings
+
warnings.warn(
f"{key_name} not set - authentication may fail for {backend} backend",
UserWarning,
- stacklevel=2
+ stacklevel=2,
)
return self
@@ -165,16 +166,24 @@ class PrefectVariableConfig:
def __init__(self) -> None:
self._settings: Settings = get_settings()
self._variable_names: list[str] = [
- "default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages",
- "default_storage_backend", "default_collection_prefix", "max_concurrent_tasks",
- "request_timeout", "default_schedule_interval"
+ "default_batch_size",
+ "max_file_size",
+ "max_crawl_depth",
+ "max_crawl_pages",
+ "default_storage_backend",
+ "default_collection_prefix",
+ "max_concurrent_tasks",
+ "request_timeout",
+ "default_schedule_interval",
]
def _get_fallback_value(self, name: str, default_value: object = None) -> object:
"""Get fallback value from settings or default."""
return default_value or getattr(self._settings, name, default_value)
- def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
+ def get_with_fallback(
+ self, name: str, default_value: str | int | float | None = None
+ ) -> str | int | float | None:
"""Get variable value with fallback synchronously."""
fallback = self._get_fallback_value(name, default_value)
# Ensure fallback is a type that Variable expects
@@ -195,7 +204,9 @@ class PrefectVariableConfig:
return fallback
return str(fallback) if fallback is not None else None
- async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
+ async def get_with_fallback_async(
+ self, name: str, default_value: str | int | float | None = None
+ ) -> str | int | float | None:
"""Get variable value with fallback asynchronously."""
fallback = self._get_fallback_value(name, default_value)
variable_fallback = str(fallback) if fallback is not None else None
diff --git a/ingest_pipeline/core/models.py b/ingest_pipeline/core/models.py
index d2af2b2..e2ca503 100644
--- a/ingest_pipeline/core/models.py
+++ b/ingest_pipeline/core/models.py
@@ -12,35 +12,36 @@ from ..config import get_settings
def _default_embedding_model() -> str:
- return get_settings().embedding_model
+ return str(get_settings().embedding_model)
def _default_embedding_endpoint() -> HttpUrl:
- return get_settings().llm_endpoint
+ endpoint = get_settings().llm_endpoint
+ return endpoint if isinstance(endpoint, HttpUrl) else HttpUrl(str(endpoint))
def _default_embedding_dimension() -> int:
- return get_settings().embedding_dimension
+ return int(get_settings().embedding_dimension)
def _default_batch_size() -> int:
- return get_settings().default_batch_size
+ return int(get_settings().default_batch_size)
def _default_collection_name() -> str:
- return get_settings().default_collection_prefix
+ return str(get_settings().default_collection_prefix)
def _default_max_crawl_depth() -> int:
- return get_settings().max_crawl_depth
+ return int(get_settings().max_crawl_depth)
def _default_max_crawl_pages() -> int:
- return get_settings().max_crawl_pages
+ return int(get_settings().max_crawl_pages)
def _default_max_file_size() -> int:
- return get_settings().max_file_size
+ return int(get_settings().max_file_size)
class IngestionStatus(str, Enum):
@@ -84,13 +85,16 @@ class StorageConfig(Block):
_block_type_name: ClassVar[str | None] = "Storage Configuration"
_block_type_slug: ClassVar[str | None] = "storage-config"
- _description: ClassVar[str | None] = "Configures storage backend connections and settings for document ingestion"
+ _description: ClassVar[str | None] = (
+ "Configures storage backend connections and settings for document ingestion"
+ )
backend: StorageBackend
endpoint: HttpUrl
api_key: SecretStr | None = Field(default=None)
collection_name: str = Field(default_factory=_default_collection_name)
batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size)
+ grpc_port: int | None = Field(default=None, description="gRPC port for Weaviate connections")
class FirecrawlConfig(Block):
@@ -112,7 +116,9 @@ class RepomixConfig(Block):
_block_type_name: ClassVar[str | None] = "Repomix Configuration"
_block_type_slug: ClassVar[str | None] = "repomix-config"
- _description: ClassVar[str | None] = "Configures repository ingestion patterns and file processing settings"
+ _description: ClassVar[str | None] = (
+ "Configures repository ingestion patterns and file processing settings"
+ )
include_patterns: list[str] = Field(
default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"]
@@ -129,7 +135,9 @@ class R2RConfig(Block):
_block_type_name: ClassVar[str | None] = "R2R Configuration"
_block_type_slug: ClassVar[str | None] = "r2r-config"
- _description: ClassVar[str | None] = "Configures R2R-specific ingestion settings including chunking and graph enrichment"
+ _description: ClassVar[str | None] = (
+ "Configures R2R-specific ingestion settings including chunking and graph enrichment"
+ )
chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000
chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200
@@ -139,6 +147,7 @@ class R2RConfig(Block):
class DocumentMetadataRequired(TypedDict):
"""Required metadata fields for a document."""
+
source_url: str
timestamp: datetime
content_type: str
diff --git a/ingest_pipeline/flows/ingestion.py b/ingest_pipeline/flows/ingestion.py
index 53157bd..e18709e 100644
--- a/ingest_pipeline/flows/ingestion.py
+++ b/ingest_pipeline/flows/ingestion.py
@@ -103,8 +103,13 @@ async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
return storage
-@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"],
- cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"))
+@task(
+ name="map_firecrawl_site",
+ retries=2,
+ retry_delay_seconds=15,
+ tags=["firecrawl", "map"],
+ cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"),
+)
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
"""Map a site using Firecrawl and return discovered URLs."""
# Load block if string provided
@@ -118,8 +123,13 @@ async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str
return mapped or [source_url]
-@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"],
- cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list
+@task(
+ name="filter_existing_documents",
+ retries=1,
+ retry_delay_seconds=5,
+ tags=["dedup"],
+ cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls"),
+) # Cache based on URL list
async def filter_existing_documents_task(
urls: list[str],
storage_client: BaseStorage,
@@ -128,19 +138,40 @@ async def filter_existing_documents_task(
collection_name: str | None = None,
) -> list[str]:
"""Filter URLs to only those that need scraping (missing or stale in storage)."""
+ import asyncio
+
logger = get_run_logger()
- eligible: list[str] = []
- for url in urls:
- document_id = str(FirecrawlIngestor.compute_document_id(url))
- exists = await storage_client.check_exists(
- document_id,
- collection_name=collection_name,
- stale_after_days=stale_after_days
- )
+ # Use semaphore to limit concurrent existence checks
+ semaphore = asyncio.Semaphore(20)
- if not exists:
- eligible.append(url)
+ async def check_url_exists(url: str) -> tuple[str, bool]:
+ async with semaphore:
+ try:
+ document_id = str(FirecrawlIngestor.compute_document_id(url))
+ exists = await storage_client.check_exists(
+ document_id, collection_name=collection_name, stale_after_days=stale_after_days
+ )
+ return url, exists
+ except Exception as e:
+ logger.warning("Error checking existence for URL %s: %s", url, e)
+ # Assume the document doesn't exist on error so the URL still gets scraped
+ return url, False
+
+ # Check all URLs in parallel - use return_exceptions=True for partial failure handling
+ results = await asyncio.gather(*[check_url_exists(url) for url in urls], return_exceptions=True)
+
+ # Collect URLs that need scraping, handling any exceptions
+ eligible: list[str] = []
+ for result in results:
+ if isinstance(result, Exception):
+ logger.error("Unexpected error in parallel existence check: %s", result)
+ continue
+ # Type narrowing: result is now known to be tuple[str, bool]
+ if isinstance(result, tuple) and len(result) == 2:
+ url, exists = result
+ if not exists:
+ eligible.append(url)
skipped = len(urls) - len(eligible)
if skipped > 0:
@@ -260,7 +291,9 @@ async def ingest_documents_task(
if progress_callback:
progress_callback(40, "Starting document processing...")
- return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback)
+ return await _process_documents(
+ ingestor, storage, job, batch_size, collection_name, progress_callback
+ )
async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
@@ -287,7 +320,9 @@ async def _create_ingestor(job: IngestionJob, config_block_name: str | None = No
raise ValueError(f"Unsupported source: {job.source_type}")
-async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage:
+async def _create_storage(
+ job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None
+) -> BaseStorage:
"""Create and initialize storage client."""
if collection_name is None:
# Use variable for default collection prefix
@@ -303,6 +338,7 @@ async def _create_storage(job: IngestionJob, collection_name: str | None, storag
else:
# Fallback to building config from settings
from ..config import get_settings
+
settings = get_settings()
storage_config = _build_storage_config(job, settings, collection_name)
@@ -391,7 +427,7 @@ async def _process_documents(
progress_callback(45, "Ingesting documents from source...")
# Use smart ingestion with deduplication if storage supports it
- if hasattr(storage, 'check_exists'):
+ if hasattr(storage, "check_exists"):
try:
# Try to use the smart ingestion method
document_generator = ingestor.ingest_with_dedup(
@@ -412,7 +448,7 @@ async def _process_documents(
if progress_callback:
progress_callback(
45 + min(35, (batch_count * 10)),
- f"Processing batch {batch_count} ({total_documents} documents so far)..."
+ f"Processing batch {batch_count} ({total_documents} documents so far)...",
)
batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
@@ -482,7 +518,9 @@ async def _store_batch(
log_prints=True,
)
async def firecrawl_to_r2r_flow(
- job: IngestionJob, collection_name: str | None = None, progress_callback: Callable[[int, str], None] | None = None
+ job: IngestionJob,
+ collection_name: str | None = None,
+ progress_callback: Callable[[int, str], None] | None = None,
) -> tuple[int, int]:
"""Specialized flow for Firecrawl ingestion into R2R."""
logger = get_run_logger()
@@ -549,10 +587,8 @@ async def firecrawl_to_r2r_flow(
# Use asyncio.gather for concurrent scraping
import asyncio
- scrape_tasks = [
- scrape_firecrawl_batch_task(batch, firecrawl_config)
- for batch in url_batches
- ]
+
+ scrape_tasks = [scrape_firecrawl_batch_task(batch, firecrawl_config) for batch in url_batches]
batch_results = await asyncio.gather(*scrape_tasks)
scraped_pages: list[FirecrawlPage] = []
@@ -680,9 +716,13 @@ async def create_ingestion_flow(
progress_callback(30, "Starting document ingestion...")
print("Ingesting documents...")
if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R:
- processed, failed = await firecrawl_to_r2r_flow(job, collection_name, progress_callback=progress_callback)
+ processed, failed = await firecrawl_to_r2r_flow(
+ job, collection_name, progress_callback=progress_callback
+ )
else:
- processed, failed = await ingest_documents_task(job, collection_name, progress_callback=progress_callback)
+ processed, failed = await ingest_documents_task(
+ job, collection_name, progress_callback=progress_callback
+ )
if progress_callback:
progress_callback(90, "Finalizing ingestion...")
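
The parallel existence check in `filter_existing_documents_task` combines a semaphore with `asyncio.gather(..., return_exceptions=True)` so one slow or failing URL cannot stall or cancel the rest. The same bounded-concurrency shape generalises; a sketch with a hypothetical `bounded_gather` helper (the limit of 20 mirrors the semaphore above):

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import TypeVar

T = TypeVar("T")


async def bounded_gather(
    items: list[str],
    worker: Callable[[str], Awaitable[T]],
    limit: int = 20,
) -> list[T | BaseException]:
    """Run worker over items with at most `limit` tasks in flight."""
    semaphore = asyncio.Semaphore(limit)

    async def run(item: str) -> T:
        async with semaphore:
            return await worker(item)

    # return_exceptions keeps one failing item from cancelling the rest.
    return await asyncio.gather(*(run(item) for item in items), return_exceptions=True)
```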
diff --git a/ingest_pipeline/ingestors/firecrawl.py b/ingest_pipeline/ingestors/firecrawl.py
index 7e0275b..1db5cee 100644
--- a/ingest_pipeline/ingestors/firecrawl.py
+++ b/ingest_pipeline/ingestors/firecrawl.py
@@ -191,9 +191,10 @@ class FirecrawlIngestor(BaseIngestor):
"http://localhost"
):
# Self-hosted instance - try with api_url if supported
- self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(
- api_key=api_key, api_url=str(settings.firecrawl_endpoint)
- ))
+ self.client = cast(
+ AsyncFirecrawlClient,
+ AsyncFirecrawl(api_key=api_key, api_url=str(settings.firecrawl_endpoint)),
+ )
else:
# Cloud instance - use standard initialization
self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key))
@@ -256,9 +257,7 @@ class FirecrawlIngestor(BaseIngestor):
for check_url in site_map:
document_id = str(self.compute_document_id(check_url))
exists = await storage_client.check_exists(
- document_id,
- collection_name=collection_name,
- stale_after_days=stale_after_days
+ document_id, collection_name=collection_name, stale_after_days=stale_after_days
)
if not exists:
eligible_urls.append(check_url)
@@ -394,7 +393,9 @@ class FirecrawlIngestor(BaseIngestor):
# Extract basic metadata
title: str | None = getattr(metadata, "title", None) if metadata else None
- description: str | None = getattr(metadata, "description", None) if metadata else None
+ description: str | None = (
+ getattr(metadata, "description", None) if metadata else None
+ )
# Extract enhanced metadata if available
author: str | None = getattr(metadata, "author", None) if metadata else None
@@ -402,26 +403,38 @@ class FirecrawlIngestor(BaseIngestor):
sitemap_last_modified: str | None = (
getattr(metadata, "sitemap_last_modified", None) if metadata else None
)
- source_url: str | None = getattr(metadata, "sourceURL", None) if metadata else None
- keywords: str | list[str] | None = getattr(metadata, "keywords", None) if metadata else None
+ source_url: str | None = (
+ getattr(metadata, "sourceURL", None) if metadata else None
+ )
+ keywords: str | list[str] | None = (
+ getattr(metadata, "keywords", None) if metadata else None
+ )
robots: str | None = getattr(metadata, "robots", None) if metadata else None
# Open Graph metadata
og_title: str | None = getattr(metadata, "ogTitle", None) if metadata else None
- og_description: str | None = getattr(metadata, "ogDescription", None) if metadata else None
+ og_description: str | None = (
+ getattr(metadata, "ogDescription", None) if metadata else None
+ )
og_url: str | None = getattr(metadata, "ogUrl", None) if metadata else None
og_image: str | None = getattr(metadata, "ogImage", None) if metadata else None
# Twitter metadata
- twitter_card: str | None = getattr(metadata, "twitterCard", None) if metadata else None
- twitter_site: str | None = getattr(metadata, "twitterSite", None) if metadata else None
+ twitter_card: str | None = (
+ getattr(metadata, "twitterCard", None) if metadata else None
+ )
+ twitter_site: str | None = (
+ getattr(metadata, "twitterSite", None) if metadata else None
+ )
twitter_creator: str | None = (
getattr(metadata, "twitterCreator", None) if metadata else None
)
# Additional metadata
favicon: str | None = getattr(metadata, "favicon", None) if metadata else None
- status_code: int | None = getattr(metadata, "statusCode", None) if metadata else None
+ status_code: int | None = (
+ getattr(metadata, "statusCode", None) if metadata else None
+ )
return FirecrawlPage(
url=url,
@@ -594,10 +607,16 @@ class FirecrawlIngestor(BaseIngestor):
"site_name": domain_info["site_name"],
# Document structure
"heading_hierarchy": (
- list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else []
+ list(hierarchy_val)
+ if (hierarchy_val := structure_info.get("heading_hierarchy"))
+ and isinstance(hierarchy_val, (list, tuple))
+ else []
),
"section_depth": (
- int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0
+ int(depth_val)
+ if (depth_val := structure_info.get("section_depth"))
+ and isinstance(depth_val, (int, str))
+ else 0
),
"has_code_blocks": bool(structure_info.get("has_code_blocks", False)),
"has_images": bool(structure_info.get("has_images", False)),
@@ -632,7 +651,11 @@ class FirecrawlIngestor(BaseIngestor):
await self.client.close()
except Exception as e:
logging.debug(f"Error closing Firecrawl client: {e}")
- elif hasattr(self.client, "_session") and self.client._session and hasattr(self.client._session, "close"):
+ elif (
+ hasattr(self.client, "_session")
+ and self.client._session
+ and hasattr(self.client._session, "close")
+ ):
try:
await self.client._session.close()
except Exception as e:
diff --git a/ingest_pipeline/ingestors/repomix.py b/ingest_pipeline/ingestors/repomix.py
index 281a952..0448599 100644
--- a/ingest_pipeline/ingestors/repomix.py
+++ b/ingest_pipeline/ingestors/repomix.py
@@ -80,6 +80,7 @@ class RepomixIngestor(BaseIngestor):
return result.returncode == 0
except Exception as e:
import logging
+
logging.warning(f"Failed to validate repository {source_url}: {e}")
return False
@@ -97,9 +98,7 @@ class RepomixIngestor(BaseIngestor):
try:
with tempfile.TemporaryDirectory() as temp_dir:
# Shallow clone to get file count
- repo_path = await self._clone_repository(
- source_url, temp_dir, shallow=True
- )
+ repo_path = await self._clone_repository(source_url, temp_dir, shallow=True)
# Count files matching patterns
file_count = 0
@@ -111,6 +110,7 @@ class RepomixIngestor(BaseIngestor):
return file_count
except Exception as e:
import logging
+
logging.warning(f"Failed to estimate size for repository {source_url}: {e}")
return 0
@@ -179,9 +179,7 @@ class RepomixIngestor(BaseIngestor):
return output_file
- async def _parse_repomix_output(
- self, output_file: Path, job: IngestionJob
- ) -> list[Document]:
+ async def _parse_repomix_output(self, output_file: Path, job: IngestionJob) -> list[Document]:
"""
Parse repomix output into documents.
@@ -210,14 +208,17 @@ class RepomixIngestor(BaseIngestor):
chunks = self._chunk_content(file_content)
for i, chunk in enumerate(chunks):
doc = self._create_document(
- file_path, chunk, job, chunk_index=i,
- git_metadata=git_metadata, repo_info=repo_info
+ file_path,
+ chunk,
+ job,
+ chunk_index=i,
+ git_metadata=git_metadata,
+ repo_info=repo_info,
)
documents.append(doc)
else:
doc = self._create_document(
- file_path, file_content, job,
- git_metadata=git_metadata, repo_info=repo_info
+ file_path, file_content, job, git_metadata=git_metadata, repo_info=repo_info
)
documents.append(doc)
@@ -300,64 +301,64 @@ class RepomixIngestor(BaseIngestor):
# Map common extensions to languages
ext_map = {
- '.py': 'python',
- '.js': 'javascript',
- '.ts': 'typescript',
- '.jsx': 'javascript',
- '.tsx': 'typescript',
- '.java': 'java',
- '.c': 'c',
- '.cpp': 'cpp',
- '.cc': 'cpp',
- '.cxx': 'cpp',
- '.h': 'c',
- '.hpp': 'cpp',
- '.cs': 'csharp',
- '.go': 'go',
- '.rs': 'rust',
- '.php': 'php',
- '.rb': 'ruby',
- '.swift': 'swift',
- '.kt': 'kotlin',
- '.scala': 'scala',
- '.sh': 'shell',
- '.bash': 'shell',
- '.zsh': 'shell',
- '.sql': 'sql',
- '.html': 'html',
- '.css': 'css',
- '.scss': 'scss',
- '.less': 'less',
- '.yaml': 'yaml',
- '.yml': 'yaml',
- '.json': 'json',
- '.xml': 'xml',
- '.md': 'markdown',
- '.txt': 'text',
- '.cfg': 'config',
- '.ini': 'config',
- '.toml': 'toml',
+ ".py": "python",
+ ".js": "javascript",
+ ".ts": "typescript",
+ ".jsx": "javascript",
+ ".tsx": "typescript",
+ ".java": "java",
+ ".c": "c",
+ ".cpp": "cpp",
+ ".cc": "cpp",
+ ".cxx": "cpp",
+ ".h": "c",
+ ".hpp": "cpp",
+ ".cs": "csharp",
+ ".go": "go",
+ ".rs": "rust",
+ ".php": "php",
+ ".rb": "ruby",
+ ".swift": "swift",
+ ".kt": "kotlin",
+ ".scala": "scala",
+ ".sh": "shell",
+ ".bash": "shell",
+ ".zsh": "shell",
+ ".sql": "sql",
+ ".html": "html",
+ ".css": "css",
+ ".scss": "scss",
+ ".less": "less",
+ ".yaml": "yaml",
+ ".yml": "yaml",
+ ".json": "json",
+ ".xml": "xml",
+ ".md": "markdown",
+ ".txt": "text",
+ ".cfg": "config",
+ ".ini": "config",
+ ".toml": "toml",
}
if extension in ext_map:
return ext_map[extension]
# Try to detect from shebang
- if content.startswith('#!'):
- first_line = content.split('\n')[0]
- if 'python' in first_line:
- return 'python'
- elif 'node' in first_line or 'javascript' in first_line:
- return 'javascript'
- elif 'bash' in first_line or 'sh' in first_line:
- return 'shell'
+ if content.startswith("#!"):
+ first_line = content.split("\n")[0]
+ if "python" in first_line:
+ return "python"
+ elif "node" in first_line or "javascript" in first_line:
+ return "javascript"
+ elif "bash" in first_line or "sh" in first_line:
+ return "shell"
return None
@staticmethod
def _analyze_code_structure(content: str, language: str | None) -> dict[str, object]:
"""Analyze code structure and extract metadata."""
- lines = content.split('\n')
+ lines = content.split("\n")
# Basic metrics
has_functions = False
@@ -366,29 +367,35 @@ class RepomixIngestor(BaseIngestor):
has_comments = False
# Language-specific patterns
- if language == 'python':
- has_functions = bool(re.search(r'^\s*def\s+\w+', content, re.MULTILINE))
- has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE))
- has_imports = bool(re.search(r'^\s*(import|from)\s+', content, re.MULTILINE))
- has_comments = bool(re.search(r'^\s*#', content, re.MULTILINE))
- elif language in ['javascript', 'typescript']:
- has_functions = bool(re.search(r'(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)', content, re.MULTILINE))
- has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE))
- has_imports = bool(re.search(r'^\s*(import|require)', content, re.MULTILINE))
- has_comments = bool(re.search(r'//|/\*', content))
- elif language == 'java':
- has_functions = bool(re.search(r'(public|private|protected).*\w+\s*\(', content, re.MULTILINE))
- has_classes = bool(re.search(r'(public|private)?\s*class\s+\w+', content, re.MULTILINE))
- has_imports = bool(re.search(r'^\s*import\s+', content, re.MULTILINE))
- has_comments = bool(re.search(r'//|/\*', content))
+ if language == "python":
+ has_functions = bool(re.search(r"^\s*def\s+\w+", content, re.MULTILINE))
+ has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE))
+ has_imports = bool(re.search(r"^\s*(import|from)\s+", content, re.MULTILINE))
+ has_comments = bool(re.search(r"^\s*#", content, re.MULTILINE))
+ elif language in ["javascript", "typescript"]:
+ has_functions = bool(
+ re.search(
+ r"(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)", content, re.MULTILINE
+ )
+ )
+ has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE))
+ has_imports = bool(re.search(r"^\s*(import|require)", content, re.MULTILINE))
+ has_comments = bool(re.search(r"//|/\*", content))
+ elif language == "java":
+ has_functions = bool(
+ re.search(r"(public|private|protected).*\w+\s*\(", content, re.MULTILINE)
+ )
+ has_classes = bool(re.search(r"(public|private)?\s*class\s+\w+", content, re.MULTILINE))
+ has_imports = bool(re.search(r"^\s*import\s+", content, re.MULTILINE))
+ has_comments = bool(re.search(r"//|/\*", content))
return {
- 'has_functions': has_functions,
- 'has_classes': has_classes,
- 'has_imports': has_imports,
- 'has_comments': has_comments,
- 'line_count': len(lines),
- 'non_empty_lines': len([line for line in lines if line.strip()]),
+ "has_functions": has_functions,
+ "has_classes": has_classes,
+ "has_imports": has_imports,
+ "has_comments": has_comments,
+ "line_count": len(lines),
+ "non_empty_lines": len([line for line in lines if line.strip()]),
}
@staticmethod
@@ -399,18 +406,16 @@ class RepomixIngestor(BaseIngestor):
org_name = None
# Handle different URL formats
- if 'github.com' in repo_url or 'gitlab.com' in repo_url:
- if path_match := re.search(
- r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url
- ):
+ if "github.com" in repo_url or "gitlab.com" in repo_url:
+ if path_match := re.search(r"/([^/]+)/([^/]+?)(?:\.git)?/?$", repo_url):
org_name = path_match[1]
repo_name = path_match[2]
- elif path_match := re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url):
+ elif path_match := re.search(r"/([^/]+?)(?:\.git)?/?$", repo_url):
repo_name = path_match[1]
return {
- 'repository_name': repo_name or 'unknown',
- 'organization': org_name or 'unknown',
+ "repository_name": repo_name or "unknown",
+ "organization": org_name or "unknown",
}
async def _get_git_metadata(self, repo_path: Path) -> dict[str, str | None]:
@@ -418,31 +423,35 @@ class RepomixIngestor(BaseIngestor):
try:
# Get current branch
branch_result = await self._run_command(
- ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
- cwd=str(repo_path),
- timeout=5
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=str(repo_path), timeout=5
+ )
+ branch_name = (
+ branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None
)
- branch_name = branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None
# Get current commit hash
commit_result = await self._run_command(
- ['git', 'rev-parse', 'HEAD'],
- cwd=str(repo_path),
- timeout=5
+ ["git", "rev-parse", "HEAD"], cwd=str(repo_path), timeout=5
+ )
+ commit_hash = (
+ commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None
)
- commit_hash = commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None
return {
- 'branch_name': branch_name,
- 'commit_hash': commit_hash[:8] if commit_hash else None, # Short hash
+ "branch_name": branch_name,
+ "commit_hash": commit_hash[:8] if commit_hash else None, # Short hash
}
except Exception:
- return {'branch_name': None, 'commit_hash': None}
+ return {"branch_name": None, "commit_hash": None}
def _create_document(
- self, file_path: str, content: str, job: IngestionJob, chunk_index: int = 0,
+ self,
+ file_path: str,
+ content: str,
+ job: IngestionJob,
+ chunk_index: int = 0,
git_metadata: dict[str, str | None] | None = None,
- repo_info: dict[str, str] | None = None
+ repo_info: dict[str, str] | None = None,
) -> Document:
"""
Create a Document from repository content with enriched metadata.
@@ -466,19 +475,19 @@ class RepomixIngestor(BaseIngestor):
# Determine content type based on language
content_type_map = {
- 'python': 'text/x-python',
- 'javascript': 'text/javascript',
- 'typescript': 'text/typescript',
- 'java': 'text/x-java-source',
- 'html': 'text/html',
- 'css': 'text/css',
- 'json': 'application/json',
- 'yaml': 'text/yaml',
- 'markdown': 'text/markdown',
- 'shell': 'text/x-shellscript',
- 'sql': 'text/x-sql',
+ "python": "text/x-python",
+ "javascript": "text/javascript",
+ "typescript": "text/typescript",
+ "java": "text/x-java-source",
+ "html": "text/html",
+ "css": "text/css",
+ "json": "application/json",
+ "yaml": "text/yaml",
+ "markdown": "text/markdown",
+ "shell": "text/x-shellscript",
+ "sql": "text/x-sql",
}
- content_type = content_type_map.get(programming_language or 'text', 'text/plain')
+ content_type = content_type_map.get(programming_language or "text", "text/plain")
# Build rich metadata
metadata: DocumentMetadata = {
@@ -487,8 +496,7 @@ class RepomixIngestor(BaseIngestor):
"content_type": content_type,
"word_count": len(content.split()),
"char_count": len(content),
- "title": f"{file_path}"
- + (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
+ "title": f"{file_path}" + (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
"description": f"Repository file: {file_path}",
"category": "source_code" if programming_language else "documentation",
"language": programming_language or "text",
@@ -500,19 +508,23 @@ class RepomixIngestor(BaseIngestor):
# Add repository info if available
if repo_info:
- metadata["repository_name"] = repo_info.get('repository_name')
+ metadata["repository_name"] = repo_info.get("repository_name")
# Add git metadata if available
if git_metadata:
- metadata["branch_name"] = git_metadata.get('branch_name')
- metadata["commit_hash"] = git_metadata.get('commit_hash')
+ metadata["branch_name"] = git_metadata.get("branch_name")
+ metadata["commit_hash"] = git_metadata.get("commit_hash")
# Add code-specific metadata for programming files
if programming_language and structure_info:
# Calculate code quality score
- total_lines = structure_info.get('line_count', 1)
- non_empty_lines = structure_info.get('non_empty_lines', 0)
- if isinstance(total_lines, int) and isinstance(non_empty_lines, int) and total_lines > 0:
+ total_lines = structure_info.get("line_count", 1)
+ non_empty_lines = structure_info.get("non_empty_lines", 0)
+ if (
+ isinstance(total_lines, int)
+ and isinstance(non_empty_lines, int)
+ and total_lines > 0
+ ):
completeness_score = (non_empty_lines / total_lines) * 100
metadata["completeness_score"] = completeness_score
@@ -566,5 +578,6 @@ class RepomixIngestor(BaseIngestor):
await asyncio.wait_for(proc.wait(), timeout=2.0)
except TimeoutError:
import logging
+
logging.warning(f"Process {proc.pid} did not terminate cleanly")
raise IngestionError(f"Command timed out: {' '.join(cmd)}") from e
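
The repository-info extraction above pulls `org/repo` from github.com and gitlab.com URLs and falls back to the last path segment for other hosts. A quick standalone check of those patterns (the `extract_repo_info` function and the example URLs are illustrative):

```python
import re


def extract_repo_info(repo_url: str) -> dict[str, str]:
    """Pull organisation and repository names out of common Git hosting URLs."""
    repo_name = org_name = None
    if "github.com" in repo_url or "gitlab.com" in repo_url:
        if match := re.search(r"/([^/]+)/([^/]+?)(?:\.git)?/?$", repo_url):
            org_name, repo_name = match[1], match[2]
    elif match := re.search(r"/([^/]+?)(?:\.git)?/?$", repo_url):
        repo_name = match[1]
    return {"repository_name": repo_name or "unknown", "organization": org_name or "unknown"}


print(extract_repo_info("https://github.com/acme/widgets.git"))  # acme / widgets
print(extract_repo_info("https://git.example.com/tools/"))       # unknown org, repo "tools"
```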
diff --git a/ingest_pipeline/storage/__init__.py b/ingest_pipeline/storage/__init__.py
index 2e1bbfc..a580b0e 100644
--- a/ingest_pipeline/storage/__init__.py
+++ b/ingest_pipeline/storage/__init__.py
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
try:
from .r2r.storage import R2RStorage as _RuntimeR2RStorage
+
R2RStorage: type[BaseStorage] | None = _RuntimeR2RStorage
except ImportError:
R2RStorage = None
diff --git a/ingest_pipeline/storage/base.py b/ingest_pipeline/storage/base.py
index 2e3e8fa..34a5b36 100644
--- a/ingest_pipeline/storage/base.py
+++ b/ingest_pipeline/storage/base.py
@@ -1,10 +1,12 @@
"""Base storage interface."""
+import asyncio
import logging
+import random
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
-from typing import Final
from types import TracebackType
+from typing import Final
import httpx
from pydantic import SecretStr
@@ -37,6 +39,8 @@ class TypedHttpClient:
api_key: SecretStr | None = None,
timeout: float = 30.0,
headers: dict[str, str] | None = None,
+ max_connections: int = 100,
+ max_keepalive_connections: int = 20,
):
"""
Initialize the typed HTTP client.
@@ -46,6 +50,8 @@ class TypedHttpClient:
api_key: Optional API key for authentication
timeout: Request timeout in seconds
headers: Additional headers to include with requests
+ max_connections: Maximum total connections in pool
+ max_keepalive_connections: Maximum keepalive connections
"""
self._base_url = base_url
@@ -54,14 +60,14 @@ class TypedHttpClient:
if api_key:
client_headers["Authorization"] = f"Bearer {api_key.get_secret_value()}"
- # Note: Pylance incorrectly reports "No parameter named 'base_url'"
- # but base_url is a valid AsyncClient parameter (see HTTPX docs)
- client_kwargs: dict[str, str | dict[str, str] | float] = {
- "base_url": base_url,
- "headers": client_headers,
- "timeout": timeout,
- }
- self.client = httpx.AsyncClient(**client_kwargs) # type: ignore
+ # Create typed client configuration with connection pooling
+ limits = httpx.Limits(
+ max_connections=max_connections, max_keepalive_connections=max_keepalive_connections
+ )
+ timeout_config = httpx.Timeout(connect=5.0, read=timeout, write=30.0, pool=10.0)
+ self.client = httpx.AsyncClient(
+ base_url=base_url, headers=client_headers, timeout=timeout_config, limits=limits
+ )
async def request(
self,
@@ -73,44 +79,92 @@ class TypedHttpClient:
data: dict[str, object] | None = None,
files: dict[str, tuple[str, bytes, str]] | None = None,
params: dict[str, str | bool] | None = None,
+ max_retries: int = 3,
+ retry_delay: float = 1.0,
) -> httpx.Response | None:
"""
- Perform an HTTP request with consistent error handling.
+ Perform an HTTP request with consistent error handling and retries.
Args:
method: HTTP method (GET, POST, DELETE, etc.)
path: URL path relative to base_url
allow_404: If True, return None for 404 responses instead of raising
- **kwargs: Arguments passed to httpx request
+ json: JSON data to send
+ data: Form data to send
+ files: Files to upload
+ params: Query parameters
+ max_retries: Maximum number of retry attempts
+ retry_delay: Base delay between retries in seconds
Returns:
HTTP response object, or None if allow_404=True and status is 404
Raises:
- StorageError: If request fails
+ StorageError: If request fails after retries
"""
- try:
- response = await self.client.request( # type: ignore
- method, path, json=json, data=data, files=files, params=params
- )
- response.raise_for_status() # type: ignore
- return response # type: ignore
- except Exception as e:
- # Handle 404 as special case if requested
- if allow_404 and hasattr(e, 'response') and getattr(e.response, 'status_code', None) == 404: # type: ignore
- LOGGER.debug("Resource not found (404): %s %s", method, path)
- return None
+ last_exception: Exception | None = None
- # Convert all HTTP-related exceptions to StorageError
- error_name = e.__class__.__name__
- if 'HTTP' in error_name or 'Connect' in error_name or 'Request' in error_name:
- if hasattr(e, 'response') and hasattr(e.response, 'status_code'): # type: ignore
- status_code = getattr(e.response, 'status_code', 'unknown') # type: ignore
- raise StorageError(f"HTTP {status_code} error from {self._base_url}: {e}") from e
- else:
- raise StorageError(f"Request failed to {self._base_url}: {e}") from e
- # Re-raise non-HTTP exceptions
- raise
+ for attempt in range(max_retries + 1):
+ try:
+ response = await self.client.request(
+ method, path, json=json, data=data, files=files, params=params
+ )
+ response.raise_for_status()
+ return response
+ except httpx.HTTPStatusError as e:
+ # Handle 404 as special case if requested
+ if allow_404 and e.response.status_code == 404:
+ LOGGER.debug("Resource not found (404): %s %s", method, path)
+ return None
+
+ # Don't retry client errors (4xx except for specific cases)
+ if 400 <= e.response.status_code < 500 and e.response.status_code not in [429, 408]:
+ raise StorageError(
+ f"HTTP {e.response.status_code} error from {self._base_url}: {e}"
+ ) from e
+
+ last_exception = e
+ if attempt < max_retries:
+ # Exponential backoff with jitter for retryable errors
+ delay = retry_delay * (2**attempt) + random.uniform(0, 1)
+ LOGGER.warning(
+ "HTTP %d error on attempt %d/%d, retrying in %.2fs: %s",
+ e.response.status_code,
+ attempt + 1,
+ max_retries + 1,
+ delay,
+ e,
+ )
+ await asyncio.sleep(delay)
+
+ except httpx.HTTPError as e:
+ last_exception = e
+ if attempt < max_retries:
+ # Retry transport errors with backoff
+ delay = retry_delay * (2**attempt) + random.uniform(0, 1)
+ LOGGER.warning(
+ "HTTP transport error on attempt %d/%d, retrying in %.2fs: %s",
+ attempt + 1,
+ max_retries + 1,
+ delay,
+ e,
+ )
+ await asyncio.sleep(delay)
+
+ # All retries exhausted - last_exception should always be set if we reach here
+ if last_exception is None:
+ raise StorageError(
+ f"Request to {self._base_url} failed after {max_retries + 1} attempts with unknown error"
+ )
+
+ if isinstance(last_exception, httpx.HTTPStatusError):
+ raise StorageError(
+ f"HTTP {last_exception.response.status_code} error from {self._base_url} after {max_retries + 1} attempts: {last_exception}"
+ ) from last_exception
+ else:
+ raise StorageError(
+ f"HTTP transport error to {self._base_url} after {max_retries + 1} attempts: {last_exception}"
+ ) from last_exception
async def close(self) -> None:
"""Close the HTTP client and cleanup resources."""
@@ -127,7 +181,7 @@ class TypedHttpClient:
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
- exc_tb: TracebackType | None
+ exc_tb: TracebackType | None,
) -> None:
"""Async context manager exit."""
await self.close()
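To make the retry schedule above concrete, here is a minimal standalone sketch of the delay calculation used in `request` (the `backoff_delays` helper is hypothetical, shown only to illustrate the `retry_delay * 2**attempt + jitter` formula from the diff):

```python
import random

def backoff_delays(max_retries: int = 3, retry_delay: float = 1.0) -> list[float]:
    """Sleep before each retry attempt: retry_delay * 2**attempt plus up to 1s of jitter."""
    return [retry_delay * (2**attempt) + random.uniform(0, 1) for attempt in range(max_retries)]

# With the defaults, three retries wait roughly 1-2s, 2-3s and 4-5s before giving up.
print(backoff_delays())
```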
@@ -224,11 +278,30 @@ class BaseStorage(ABC):
# Check staleness if timestamp is available
if "timestamp" in document.metadata:
from datetime import UTC, datetime, timedelta
+
timestamp_obj = document.metadata["timestamp"]
+
+ # Handle both datetime objects and ISO strings
if isinstance(timestamp_obj, datetime):
timestamp = timestamp_obj
- cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
- return timestamp >= cutoff
+ # Ensure timezone awareness
+ if timestamp.tzinfo is None:
+ timestamp = timestamp.replace(tzinfo=UTC)
+ elif isinstance(timestamp_obj, str):
+ try:
+ timestamp = datetime.fromisoformat(timestamp_obj)
+ # Ensure timezone awareness
+ if timestamp.tzinfo is None:
+ timestamp = timestamp.replace(tzinfo=UTC)
+ except ValueError:
+ # If parsing fails, assume document is stale
+ return False
+ else:
+ # Unknown timestamp format, assume stale
+ return False
+
+ cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
+ return timestamp >= cutoff
# If no timestamp, assume it exists and is valid
return True
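The timestamp handling above accepts both datetime objects and ISO strings and coerces naive values to UTC; a small self-contained sketch of that logic, assuming the same 30-day default (the `is_fresh` helper is illustrative, not part of the codebase):

```python
from datetime import UTC, datetime, timedelta

def is_fresh(timestamp_obj: object, stale_after_days: int = 30) -> bool:
    """Mirror of the timestamp handling in check_exists (illustrative)."""
    if isinstance(timestamp_obj, datetime):
        timestamp = timestamp_obj
    elif isinstance(timestamp_obj, str):
        try:
            timestamp = datetime.fromisoformat(timestamp_obj)
        except ValueError:
            return False  # unparseable timestamps are treated as stale
    else:
        return False  # unknown formats are treated as stale
    if timestamp.tzinfo is None:
        timestamp = timestamp.replace(tzinfo=UTC)
    return timestamp >= datetime.now(UTC) - timedelta(days=stale_after_days)

assert is_fresh(datetime.now(UTC))
assert not is_fresh("not-a-date")
```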
diff --git a/ingest_pipeline/storage/openwebui.py b/ingest_pipeline/storage/openwebui.py
index 8e42d0c..6b0b8e2 100644
--- a/ingest_pipeline/storage/openwebui.py
+++ b/ingest_pipeline/storage/openwebui.py
@@ -1,10 +1,10 @@
"""Open WebUI storage adapter."""
-
import asyncio
import contextlib
import logging
-from typing import Final, TypedDict, cast
+import time
+from typing import Final, NamedTuple, TypedDict
from typing_extensions import override
@@ -18,6 +18,7 @@ LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class OpenWebUIFileResponse(TypedDict, total=False):
"""OpenWebUI API file response structure."""
+
id: str
filename: str
name: str
@@ -29,6 +30,7 @@ class OpenWebUIFileResponse(TypedDict, total=False):
class OpenWebUIKnowledgeBase(TypedDict, total=False):
"""OpenWebUI knowledge base response structure."""
+
id: str
name: str
description: str
@@ -38,13 +40,19 @@ class OpenWebUIKnowledgeBase(TypedDict, total=False):
updated_at: str
+class CacheEntry(NamedTuple):
+ """Cache entry with value and expiration time."""
+
+ value: str
+ expires_at: float
class OpenWebUIStorage(BaseStorage):
"""Storage adapter for Open WebUI knowledge endpoints."""
http_client: TypedHttpClient
- _knowledge_cache: dict[str, str]
+ _knowledge_cache: dict[str, CacheEntry]
+ _cache_ttl: float
def __init__(self, config: StorageConfig):
"""
@@ -61,6 +69,7 @@ class OpenWebUIStorage(BaseStorage):
timeout=30.0,
)
self._knowledge_cache = {}
+ self._cache_ttl = 300.0 # 5 minutes TTL
@override
async def initialize(self) -> None:
@@ -106,9 +115,25 @@ class OpenWebUIStorage(BaseStorage):
return []
normalized: list[OpenWebUIKnowledgeBase] = []
for item in data:
- if isinstance(item, dict):
- # Cast to our expected structure
- kb_item = cast(OpenWebUIKnowledgeBase, item)
+ if (
+ isinstance(item, dict)
+ and "id" in item
+ and "name" in item
+ and isinstance(item["id"], str)
+ and isinstance(item["name"], str)
+ ):
+ # Create a new dict with known structure
+ kb_item: OpenWebUIKnowledgeBase = {
+ "id": item["id"],
+ "name": item["name"],
+ "description": item.get("description", ""),
+ "created_at": item.get("created_at", ""),
+ "updated_at": item.get("updated_at", ""),
+ }
+ if "files" in item and isinstance(item["files"], list):
+ kb_item["files"] = item["files"]
+ if "data" in item and isinstance(item["data"], dict):
+ kb_item["data"] = item["data"]
normalized.append(kb_item)
return normalized
@@ -124,22 +149,29 @@ class OpenWebUIStorage(BaseStorage):
if not target:
raise StorageError("Knowledge base name is required")
- if cached := self._knowledge_cache.get(target):
- return cached
+ # Check cache with TTL
+ if cached_entry := self._knowledge_cache.get(target):
+ if time.time() < cached_entry.expires_at:
+ return cached_entry.value
+ else:
+ # Entry expired, remove it
+ del self._knowledge_cache[target]
knowledge_bases = await self._fetch_knowledge_bases()
for kb in knowledge_bases:
if kb.get("name") == target:
kb_id = kb.get("id")
if isinstance(kb_id, str):
- self._knowledge_cache[target] = kb_id
+ expires_at = time.time() + self._cache_ttl
+ self._knowledge_cache[target] = CacheEntry(kb_id, expires_at)
return kb_id
if not create:
return None
knowledge_id = await self._create_collection(target)
- self._knowledge_cache[target] = knowledge_id
+ expires_at = time.time() + self._cache_ttl
+ self._knowledge_cache[target] = CacheEntry(knowledge_id, expires_at)
return knowledge_id
@override
@@ -165,7 +197,7 @@ class OpenWebUIStorage(BaseStorage):
# Use document title from metadata if available, otherwise fall back to ID
filename = document.metadata.get("title") or f"doc_{document.id}"
# Ensure filename has proper extension
- if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
+ if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
filename = f"{filename}.txt"
files = {"file": (filename, document.content.encode(), "text/plain")}
response = await self.http_client.request(
@@ -185,11 +217,9 @@ class OpenWebUIStorage(BaseStorage):
# Step 2: Add file to knowledge base
response = await self.http_client.request(
- "POST",
- f"/api/v1/knowledge/{knowledge_id}/file/add",
- json={"file_id": file_id}
+ "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
)
-
+
return str(file_id)
except Exception as e:
@@ -220,7 +250,7 @@ class OpenWebUIStorage(BaseStorage):
# Use document title from metadata if available, otherwise fall back to ID
filename = doc.metadata.get("title") or f"doc_{doc.id}"
# Ensure filename has proper extension
- if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
+ if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
filename = f"{filename}.txt"
files = {"file": (filename, doc.content.encode(), "text/plain")}
upload_response = await self.http_client.request(
@@ -230,7 +260,9 @@ class OpenWebUIStorage(BaseStorage):
params={"process": True, "process_in_background": False},
)
if upload_response is None:
- raise StorageError(f"Unexpected None response from file upload for document {doc.id}")
+ raise StorageError(
+ f"Unexpected None response from file upload for document {doc.id}"
+ )
file_data = upload_response.json()
file_id = file_data.get("id")
@@ -241,9 +273,7 @@ class OpenWebUIStorage(BaseStorage):
)
await self.http_client.request(
- "POST",
- f"/api/v1/knowledge/{knowledge_id}/file/add",
- json={"file_id": file_id}
+ "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
)
return str(file_id)
@@ -259,7 +289,8 @@ class OpenWebUIStorage(BaseStorage):
if isinstance(result, Exception):
failures.append(f"{doc.id}: {result}")
else:
- file_ids.append(cast(str, result))
+ if isinstance(result, str):
+ file_ids.append(result)
if failures:
LOGGER.warning(
@@ -289,10 +320,86 @@ class OpenWebUIStorage(BaseStorage):
"""
_ = document_id, collection_name # Mark as used
# OpenWebUI uses file-based storage without direct document retrieval
- # This will cause the base check_exists method to return False,
- # which means documents will always be re-scraped for OpenWebUI
raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID")
+ @override
+ async def check_exists(
+ self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
+ ) -> bool:
+ """
+ Check if a document exists in OpenWebUI knowledge base by searching files.
+
+ Args:
+ document_id: Document ID to check (usually based on source URL)
+ collection_name: Knowledge base name
+ stale_after_days: Consider document stale after this many days
+
+ Returns:
+ True if document exists and is not stale, False otherwise
+ """
+ try:
+ from datetime import UTC, datetime, timedelta
+
+ # Get knowledge base
+ knowledge_id = await self._get_knowledge_id(collection_name, create=False)
+ if not knowledge_id:
+ return False
+
+ # Get detailed knowledge base info to access files
+ response = await self.http_client.request("GET", f"/api/v1/knowledge/{knowledge_id}")
+ if response is None:
+ return False
+
+ kb_data = response.json()
+ files = kb_data.get("files", [])
+
+ # Look for file with matching document ID or source URL in metadata
+ cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
+
+ def _parse_openwebui_timestamp(timestamp_str: str) -> datetime | None:
+ """Parse OpenWebUI timestamp with proper timezone handling."""
+ try:
+ # Handle both 'Z' suffix and explicit timezone
+ normalized = timestamp_str.replace("Z", "+00:00")
+ parsed = datetime.fromisoformat(normalized)
+ # Ensure timezone awareness
+ if parsed.tzinfo is None:
+ parsed = parsed.replace(tzinfo=UTC)
+ return parsed
+ except (ValueError, AttributeError):
+ return None
+
+ def _check_file_freshness(file_info: dict[str, object]) -> bool:
+ """Check if file is fresh enough based on creation date."""
+ created_at = file_info.get("created_at")
+ if not isinstance(created_at, str):
+ # No date info available, consider stale to be safe
+ return False
+
+ file_date = _parse_openwebui_timestamp(created_at)
+ return file_date is not None and file_date >= cutoff
+
+ for file_info in files:
+ if not isinstance(file_info, dict):
+ continue
+
+ file_id = file_info.get("id")
+ if str(file_id) == document_id:
+ return _check_file_freshness(file_info)
+
+ # Also check meta.source_url if available for URL-based document IDs
+ meta = file_info.get("meta", {})
+ if isinstance(meta, dict):
+ source_url = meta.get("source_url")
+ if source_url and document_id in str(source_url):
+ return _check_file_freshness(file_info)
+
+ return False
+
+ except Exception as e:
+ LOGGER.debug("Error checking document existence in OpenWebUI: %s", e)
+ return False
+
@override
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
"""
@@ -316,14 +423,10 @@ class OpenWebUIStorage(BaseStorage):
await self.http_client.request(
"POST",
f"/api/v1/knowledge/{knowledge_id}/file/remove",
- json={"file_id": document_id}
+ json={"file_id": document_id},
)
- await self.http_client.request(
- "DELETE",
- f"/api/v1/files/{document_id}",
- allow_404=True
- )
+ await self.http_client.request("DELETE", f"/api/v1/files/{document_id}", allow_404=True)
return True
except Exception as exc:
LOGGER.error("Error deleting file %s from OpenWebUI", document_id, exc_info=exc)
@@ -366,9 +469,7 @@ class OpenWebUIStorage(BaseStorage):
# Delete the knowledge base using the OpenWebUI API
await self.http_client.request(
- "DELETE",
- f"/api/v1/knowledge/{knowledge_id}/delete",
- allow_404=True
+ "DELETE", f"/api/v1/knowledge/{knowledge_id}/delete", allow_404=True
)
# Remove from cache if it exists
@@ -379,13 +480,16 @@ class OpenWebUIStorage(BaseStorage):
return True
except Exception as e:
- if hasattr(e, 'response'):
- response_attr = getattr(e, 'response', None)
- if response_attr is not None and hasattr(response_attr, 'status_code'):
+ if hasattr(e, "response"):
+ response_attr = getattr(e, "response", None)
+ if response_attr is not None and hasattr(response_attr, "status_code"):
with contextlib.suppress(Exception):
- status_code = response_attr.status_code # type: ignore[attr-defined]
+ status_code = response_attr.status_code
if status_code == 404:
- LOGGER.info("Knowledge base %s was already deleted or not found", collection_name)
+ LOGGER.info(
+ "Knowledge base %s was already deleted or not found",
+ collection_name,
+ )
return True
LOGGER.error(
"Error deleting knowledge base %s from OpenWebUI",
@@ -394,8 +498,6 @@ class OpenWebUIStorage(BaseStorage):
)
return False
-
-
async def _get_knowledge_base_count(self, kb: OpenWebUIKnowledgeBase) -> int:
"""Get the file count for a knowledge base."""
kb_id = kb.get("id")
@@ -411,14 +513,13 @@ class OpenWebUIStorage(BaseStorage):
files = kb.get("files", [])
return len(files) if isinstance(files, list) and files is not None else 0
- async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase) -> int:
+ async def _count_files_from_detailed_info(
+ self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase
+ ) -> int:
"""Count files by fetching detailed knowledge base info."""
try:
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
- detail_response = await self.http_client.request(
- "GET",
- f"/api/v1/knowledge/{kb_id}"
- )
+ detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
if detail_response is None:
LOGGER.warning(f"Knowledge base '{name}' (ID: {kb_id}) not found")
return self._count_files_from_basic_info(kb)
@@ -489,10 +590,7 @@ class OpenWebUIStorage(BaseStorage):
return 0
# Get detailed knowledge base information to get accurate file count
- detail_response = await self.http_client.request(
- "GET",
- f"/api/v1/knowledge/{kb_id}"
- )
+ detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
if detail_response is None:
LOGGER.warning(f"Knowledge base '{collection_name}' (ID: {kb_id}) not found")
return self._count_files_from_basic_info(kb)
@@ -524,14 +622,28 @@ class OpenWebUIStorage(BaseStorage):
return None
knowledge_bases = response.json()
- return next(
- (
- cast(OpenWebUIKnowledgeBase, kb)
- for kb in knowledge_bases
- if isinstance(kb, dict) and kb.get("name") == name
- ),
- None,
- )
+ # Find and properly type the knowledge base
+ for kb in knowledge_bases:
+ if (
+ isinstance(kb, dict)
+ and kb.get("name") == name
+ and "id" in kb
+ and isinstance(kb["id"], str)
+ ):
+ # Create properly typed response
+ result: OpenWebUIKnowledgeBase = {
+ "id": kb["id"],
+ "name": str(kb["name"]),
+ "description": kb.get("description", ""),
+ "created_at": kb.get("created_at", ""),
+ "updated_at": kb.get("updated_at", ""),
+ }
+ if "files" in kb and isinstance(kb["files"], list):
+ result["files"] = kb["files"]
+ if "data" in kb and isinstance(kb["data"], dict):
+ result["data"] = kb["data"]
+ return result
+ return None
except Exception as e:
raise StorageError(f"Failed to get knowledge base by name: {e}") from e
@@ -623,7 +735,9 @@ class OpenWebUIStorage(BaseStorage):
elif isinstance(file_info.get("meta"), dict):
meta = file_info.get("meta")
if isinstance(meta, dict):
- filename = meta.get("name")
+ filename_value = meta.get("name")
+ if isinstance(filename_value, str):
+ filename = filename_value
# Final fallback
if not filename:
@@ -635,9 +749,11 @@ class OpenWebUIStorage(BaseStorage):
size = 0
meta = file_info.get("meta")
if isinstance(meta, dict):
- size = meta.get("size", 0)
+ size_value = meta.get("size", 0)
+ size = int(size_value) if isinstance(size_value, (int, float)) else 0
else:
- size = file_info.get("size", 0)
+ size_value = file_info.get("size", 0)
+ size = int(size_value) if isinstance(size_value, (int, float)) else 0
# Estimate word count from file size (very rough approximation)
word_count = max(1, int(size / 6)) if isinstance(size, (int, float)) else 0
@@ -650,9 +766,7 @@ class OpenWebUIStorage(BaseStorage):
"content_type": str(file_info.get("content_type", "text/plain")),
"content_preview": f"File uploaded to OpenWebUI: {filename}",
"word_count": word_count,
- "timestamp": str(
- file_info.get("created_at") or file_info.get("timestamp", "")
- ),
+ "timestamp": str(file_info.get("created_at") or file_info.get("timestamp", "")),
}
documents.append(doc_info)
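The knowledge-base ID lookup above now caches entries with a TTL rather than indefinitely; a minimal sketch of the same pattern in isolation (the `TTLCache` class is hypothetical, mirroring the `CacheEntry` expiry check from the diff):

```python
import time
from typing import NamedTuple

class CacheEntry(NamedTuple):
    value: str
    expires_at: float

class TTLCache:
    """Tiny TTL cache mirroring the expiry check in _get_knowledge_id."""

    def __init__(self, ttl: float = 300.0) -> None:
        self._ttl = ttl
        self._entries: dict[str, CacheEntry] = {}

    def get(self, key: str) -> str | None:
        entry = self._entries.get(key)
        if entry is not None and time.time() < entry.expires_at:
            return entry.value
        self._entries.pop(key, None)  # drop expired (or missing) entries
        return None

    def set(self, key: str, value: str) -> None:
        self._entries[key] = CacheEntry(value, time.time() + self._ttl)

cache = TTLCache(ttl=300.0)
cache.set("docs", "kb-123")
assert cache.get("docs") == "kb-123"
```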
diff --git a/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc
index b4c1153..bf0ba16 100644
Binary files a/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc differ
diff --git a/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc b/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc
index bfe63ed..a3e6e82 100644
Binary files a/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc and b/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc differ
diff --git a/ingest_pipeline/storage/r2r/storage.py b/ingest_pipeline/storage/r2r/storage.py
index 1c43302..bbe0719 100644
--- a/ingest_pipeline/storage/r2r/storage.py
+++ b/ingest_pipeline/storage/r2r/storage.py
@@ -4,9 +4,10 @@ from __future__ import annotations
import asyncio
import contextlib
+import logging
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from datetime import UTC, datetime
-from typing import Self, TypeVar, cast
+from typing import Final, Self, TypeVar, cast
from uuid import UUID, uuid4
# Direct imports for runtime and type checking
@@ -19,6 +20,8 @@ from ...core.models import Document, DocumentMetadata, IngestionSource, StorageC
from ..base import BaseStorage
from ..types import DocumentInfo
+LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
+
T = TypeVar("T")
@@ -183,8 +186,10 @@ class R2RStorage(BaseStorage):
) -> list[str]:
"""Store multiple documents efficiently with connection reuse."""
collection_id = await self._resolve_collection_id(collection_name)
- print(
- f"Using collection ID: {collection_id} for collection: {collection_name or self.config.collection_name}"
+ LOGGER.info(
+ "Using collection ID: %s for collection: %s",
+ collection_id,
+ collection_name or self.config.collection_name,
)
# Filter valid documents upfront
@@ -218,7 +223,7 @@ class R2RStorage(BaseStorage):
if isinstance(result, str):
stored_ids.append(result)
elif isinstance(result, Exception):
- print(f"Document upload failed: {result}")
+ LOGGER.error("Document upload failed: %s", result)
return stored_ids
@@ -239,12 +244,14 @@ class R2RStorage(BaseStorage):
requested_id = str(document.id)
if not document.content or not document.content.strip():
- print(f"Skipping document {requested_id}: empty content")
+ LOGGER.warning("Skipping document %s: empty content", requested_id)
return False
if len(document.content) > 1_000_000: # 1MB limit
- print(
- f"Skipping document {requested_id}: content too large ({len(document.content)} chars)"
+ LOGGER.warning(
+ "Skipping document %s: content too large (%d chars)",
+ requested_id,
+ len(document.content),
)
return False
@@ -263,7 +270,7 @@ class R2RStorage(BaseStorage):
) -> str | None:
"""Store a single document with retry logic using provided HTTP client."""
requested_id = str(document.id)
- print(f"Creating document with ID: {requested_id}")
+ LOGGER.debug("Creating document with ID: %s", requested_id)
max_retries = 3
retry_delay = 1.0
@@ -313,7 +320,7 @@ class R2RStorage(BaseStorage):
requested_id = str(document.id)
metadata = self._build_metadata(document)
- print(f"Built metadata for document {requested_id}: {metadata}")
+ LOGGER.debug("Built metadata for document %s: %s", requested_id, metadata)
files = {
"raw_text": (None, document.content),
@@ -324,10 +331,12 @@ class R2RStorage(BaseStorage):
if collection_id:
files["collection_ids"] = (None, json.dumps([collection_id]))
- print(f"Creating document {requested_id} with collection_ids: [{collection_id}]")
+ LOGGER.debug(
+ "Creating document %s with collection_ids: [%s]", requested_id, collection_id
+ )
- print(f"Sending to R2R - files keys: {list(files.keys())}")
- print(f"Metadata JSON: {files['metadata'][1]}")
+ LOGGER.debug("Sending to R2R - files keys: %s", list(files.keys()))
+ LOGGER.debug("Metadata JSON: %s", files["metadata"][1])
response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) # type: ignore[call-arg]
@@ -346,15 +355,17 @@ class R2RStorage(BaseStorage):
error_detail = (
getattr(response, "json", lambda: {})() if hasattr(response, "json") else {}
)
- print(f"R2R validation error for document {requested_id}: {error_detail}")
- print(f"Document metadata sent: {metadata}")
- print(f"Response status: {getattr(response, 'status_code', 'unknown')}")
- print(f"Response headers: {dict(getattr(response, 'headers', {}))}")
+ LOGGER.error("R2R validation error for document %s: %s", requested_id, error_detail)
+ LOGGER.error("Document metadata sent: %s", metadata)
+ LOGGER.error("Response status: %s", getattr(response, "status_code", "unknown"))
+ LOGGER.error("Response headers: %s", dict(getattr(response, "headers", {})))
except Exception:
- print(
- f"R2R validation error for document {requested_id}: {getattr(response, 'text', 'unknown error')}"
+ LOGGER.error(
+ "R2R validation error for document %s: %s",
+ requested_id,
+ getattr(response, "text", "unknown error"),
)
- print(f"Document metadata sent: {metadata}")
+ LOGGER.error("Document metadata sent: %s", metadata)
def _process_document_response(
self, doc_response: dict[str, object], requested_id: str, collection_id: str
@@ -363,14 +374,16 @@ class R2RStorage(BaseStorage):
response_payload = doc_response.get("results", doc_response)
doc_id = _extract_id(response_payload, requested_id)
- print(f"R2R returned document ID: {doc_id}")
+ LOGGER.info("R2R returned document ID: %s", doc_id)
if doc_id != requested_id:
- print(f"Warning: Requested ID {requested_id} but got {doc_id}")
+ LOGGER.warning("Requested ID %s but got %s", requested_id, doc_id)
if collection_id:
- print(
- f"Document {doc_id} should be assigned to collection {collection_id} via creation API"
+ LOGGER.info(
+ "Document %s should be assigned to collection %s via creation API",
+ doc_id,
+ collection_id,
)
return doc_id
@@ -387,7 +400,7 @@ class R2RStorage(BaseStorage):
if attempt >= max_retries - 1:
return False
- print(f"Timeout for document {requested_id}, retrying in {retry_delay}s...")
+ LOGGER.warning("Timeout for document %s, retrying in %ss...", requested_id, retry_delay)
await asyncio.sleep(retry_delay)
return True
@@ -404,23 +417,26 @@ class R2RStorage(BaseStorage):
if status_code < 500 or attempt >= max_retries - 1:
return False
- print(
- f"Server error {status_code} for document {requested_id}, retrying in {retry_delay}s..."
+ LOGGER.warning(
+ "Server error %s for document %s, retrying in %ss...",
+ status_code,
+ requested_id,
+ retry_delay,
)
await asyncio.sleep(retry_delay)
return True
def _log_document_error(self, document_id: object, exc: Exception) -> None:
"""Log document storage errors with specific categorization."""
- print(f"Failed to store document {document_id}: {exc}")
+ LOGGER.error("Failed to store document %s: %s", document_id, exc)
exc_str = str(exc)
if "422" in exc_str:
- print(" → Data validation issue - check document content and metadata format")
+ LOGGER.error(" → Data validation issue - check document content and metadata format")
elif "timeout" in exc_str.lower():
- print(" → Network timeout - R2R may be overloaded")
+ LOGGER.error(" → Network timeout - R2R may be overloaded")
elif "500" in exc_str:
- print(" → Server error - R2R internal issue")
+ LOGGER.error(" → Server error - R2R internal issue")
else:
import traceback
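The `print` calls above are replaced with module-level logging using lazy %-formatting; a short sketch of how a caller might configure and observe these messages (the handler and format choices are illustrative):

```python
import logging

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)

logger = logging.getLogger("ingest_pipeline.storage.r2r.storage")
# With %-style lazy formatting, arguments are only interpolated
# if the record passes the level filter and is actually emitted.
logger.debug("Creating document with ID: %s", "doc-123")
logger.info("Using collection ID: %s for collection: %s", "abc-1", "default")
```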
diff --git a/ingest_pipeline/storage/types.py b/ingest_pipeline/storage/types.py
index 5e5a4f4..62ad0c8 100644
--- a/ingest_pipeline/storage/types.py
+++ b/ingest_pipeline/storage/types.py
@@ -5,6 +5,7 @@ from typing import TypedDict
class CollectionSummary(TypedDict):
"""Collection metadata for describe_collections."""
+
name: str
count: int
size_mb: float
@@ -12,6 +13,7 @@ class CollectionSummary(TypedDict):
class DocumentInfo(TypedDict):
"""Document information for list_documents."""
+
id: str
title: str
source_url: str
@@ -19,4 +21,4 @@ class DocumentInfo(TypedDict):
content_type: str
content_preview: str
word_count: int
- timestamp: str
\ No newline at end of file
+ timestamp: str
diff --git a/ingest_pipeline/storage/weaviate.py b/ingest_pipeline/storage/weaviate.py
index 2ee92ec..1cdfbc5 100644
--- a/ingest_pipeline/storage/weaviate.py
+++ b/ingest_pipeline/storage/weaviate.py
@@ -1,8 +1,9 @@
"""Weaviate storage adapter."""
-from collections.abc import AsyncGenerator, Mapping, Sequence
+import asyncio
+from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from datetime import UTC, datetime
-from typing import Literal, Self, TypeAlias, cast, overload
+from typing import Literal, Self, TypeAlias, TypeVar, cast, overload
from uuid import UUID
import weaviate
@@ -24,6 +25,7 @@ from .base import BaseStorage
from .types import CollectionSummary, DocumentInfo
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
+T = TypeVar("T")
class WeaviateStorage(BaseStorage):
@@ -45,6 +47,28 @@ class WeaviateStorage(BaseStorage):
self.vectorizer = Vectorizer(config)
self._default_collection = self._normalize_collection_name(config.collection_name)
+ async def _run_sync(self, func: Callable[..., T], *args: object, **kwargs: object) -> T:
+ """
+        Run synchronous Weaviate operations in a thread pool to avoid blocking the event loop.
+
+ Args:
+ func: Synchronous function to run
+ *args: Positional arguments for the function
+ **kwargs: Keyword arguments for the function
+
+ Returns:
+ Result of the function call
+
+ Raises:
+ StorageError: If the operation fails
+ """
+ try:
+ return await asyncio.to_thread(func, *args, **kwargs)
+ except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
+ raise StorageError(f"Weaviate operation failed: {e}") from e
+ except Exception as e:
+ raise StorageError(f"Unexpected error in Weaviate operation: {e}") from e
+
@override
async def initialize(self) -> None:
"""Initialize Weaviate client and create collection if needed."""
@@ -53,7 +77,7 @@ class WeaviateStorage(BaseStorage):
self.client = weaviate.WeaviateClient(
connection_params=weaviate.connect.ConnectionParams.from_url(
url=str(self.config.endpoint),
- grpc_port=50051, # Default gRPC port
+ grpc_port=self.config.grpc_port or 50051,
),
additional_config=weaviate.classes.init.AdditionalConfig(
timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120),
@@ -61,7 +85,7 @@ class WeaviateStorage(BaseStorage):
)
# Connect to the client
- self.client.connect()
+ await self._run_sync(self.client.connect)
# Ensure the default collection exists
await self._ensure_collection(self._default_collection)
@@ -76,8 +100,8 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
try:
- client = cast(weaviate.WeaviateClient, self.client)
- client.collections.create(
+ await self._run_sync(
+ self.client.collections.create,
name=collection_name,
properties=[
Property(
@@ -106,7 +130,7 @@ class WeaviateStorage(BaseStorage):
],
vectorizer_config=Configure.Vectorizer.none(),
)
- except Exception as e:
+ except (WeaviateConnectionError, WeaviateBatchError) as e:
raise StorageError(f"Failed to create collection: {e}") from e
@staticmethod
@@ -114,13 +138,9 @@ class WeaviateStorage(BaseStorage):
"""Normalize vector payloads returned by Weaviate into a float list."""
if isinstance(vector_raw, Mapping):
default_vector = vector_raw.get("default")
- return WeaviateStorage._extract_vector(
- cast(VectorContainer, default_vector)
- )
+ return WeaviateStorage._extract_vector(cast(VectorContainer, default_vector))
- if not isinstance(vector_raw, Sequence) or isinstance(
- vector_raw, (str, bytes, bytearray)
- ):
+ if not isinstance(vector_raw, Sequence) or isinstance(vector_raw, (str, bytes, bytearray)):
return None
items = list(vector_raw)
@@ -135,9 +155,7 @@ class WeaviateStorage(BaseStorage):
except (TypeError, ValueError):
return None
- if isinstance(first_item, Sequence) and not isinstance(
- first_item, (str, bytes, bytearray)
- ):
+ if isinstance(first_item, Sequence) and not isinstance(first_item, (str, bytes, bytearray)):
inner_items = list(first_item)
if all(isinstance(item, (int, float)) for item in inner_items):
try:
@@ -168,8 +186,7 @@ class WeaviateStorage(BaseStorage):
properties: object,
*,
context: str,
- ) -> Mapping[str, object]:
- ...
+ ) -> Mapping[str, object]: ...
@staticmethod
@overload
@@ -178,8 +195,7 @@ class WeaviateStorage(BaseStorage):
*,
context: str,
allow_missing: Literal[False],
- ) -> Mapping[str, object]:
- ...
+ ) -> Mapping[str, object]: ...
@staticmethod
@overload
@@ -188,8 +204,7 @@ class WeaviateStorage(BaseStorage):
*,
context: str,
allow_missing: Literal[True],
- ) -> Mapping[str, object] | None:
- ...
+ ) -> Mapping[str, object] | None: ...
@staticmethod
def _coerce_properties(
@@ -211,6 +226,29 @@ class WeaviateStorage(BaseStorage):
return cast(Mapping[str, object], properties)
+ @staticmethod
+ def _build_document_properties(doc: Document) -> dict[str, object]:
+ """
+ Build Weaviate properties dict from document.
+
+ Args:
+ doc: Document to build properties for
+
+ Returns:
+ Properties dict suitable for Weaviate
+ """
+ return {
+ "content": doc.content,
+ "source_url": doc.metadata["source_url"],
+ "title": doc.metadata.get("title", ""),
+ "description": doc.metadata.get("description", ""),
+ "timestamp": doc.metadata["timestamp"].isoformat(),
+ "content_type": doc.metadata["content_type"],
+ "word_count": doc.metadata["word_count"],
+ "char_count": doc.metadata["char_count"],
+ "source": doc.source.value,
+ }
+
def _normalize_collection_name(self, collection_name: str | None) -> str:
"""Return a canonicalized collection name, defaulting to configured value."""
candidate = collection_name or self.config.collection_name
@@ -227,7 +265,7 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
existing = client.collections.list_all()
if collection_name not in existing:
await self._create_collection(collection_name)
@@ -247,7 +285,7 @@ class WeaviateStorage(BaseStorage):
if ensure_exists:
await self._ensure_collection(normalized)
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
return client.collections.get(normalized), normalized
@override
@@ -271,26 +309,19 @@ class WeaviateStorage(BaseStorage):
)
# Prepare properties
- properties = {
- "content": document.content,
- "source_url": document.metadata["source_url"],
- "title": document.metadata.get("title", ""),
- "description": document.metadata.get("description", ""),
- "timestamp": document.metadata["timestamp"].isoformat(),
- "content_type": document.metadata["content_type"],
- "word_count": document.metadata["word_count"],
- "char_count": document.metadata["char_count"],
- "source": document.source.value,
- }
+ properties = self._build_document_properties(document)
# Insert with vector
- result = collection.data.insert(
- properties=properties, vector=document.vector, uuid=str(document.id)
+ result = await self._run_sync(
+ collection.data.insert,
+ properties=properties,
+ vector=document.vector,
+ uuid=str(document.id),
)
return str(result)
- except Exception as e:
+ except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
raise StorageError(f"Failed to store document: {e}") from e
@override
@@ -311,32 +342,24 @@ class WeaviateStorage(BaseStorage):
collection_name, ensure_exists=True
)
- # Vectorize documents without vectors
- for doc in documents:
- if doc.vector is None:
- doc.vector = await self.vectorizer.vectorize(doc.content)
+ # Vectorize documents without vectors using batch processing
+ to_vectorize = [(i, doc) for i, doc in enumerate(documents) if doc.vector is None]
+ if to_vectorize:
+ contents = [doc.content for _, doc in to_vectorize]
+ vectors = await self.vectorizer.vectorize_batch(contents)
+ for (idx, _), vector in zip(to_vectorize, vectors, strict=False):
+ documents[idx].vector = vector
# Prepare batch data for insert_many
batch_objects = []
for doc in documents:
- properties = {
- "content": doc.content,
- "source_url": doc.metadata["source_url"],
- "title": doc.metadata.get("title", ""),
- "description": doc.metadata.get("description", ""),
- "timestamp": doc.metadata["timestamp"].isoformat(),
- "content_type": doc.metadata["content_type"],
- "word_count": doc.metadata["word_count"],
- "char_count": doc.metadata["char_count"],
- "source": doc.source.value,
- }
-
+ properties = self._build_document_properties(doc)
batch_objects.append(
DataObject(properties=properties, vector=doc.vector, uuid=str(doc.id))
)
# Insert batch using insert_many
- response = collection.data.insert_many(batch_objects)
+ response = await self._run_sync(collection.data.insert_many, batch_objects)
successful_ids: list[str] = []
error_indices = set(response.errors.keys()) if response else set()
@@ -361,11 +384,7 @@ class WeaviateStorage(BaseStorage):
return successful_ids
- except WeaviateBatchError as e:
- raise StorageError(f"Batch operation failed: {e}") from e
- except WeaviateConnectionError as e:
- raise StorageError(f"Connection to Weaviate failed: {e}") from e
- except Exception as e:
+ except (WeaviateBatchError, WeaviateConnectionError, WeaviateQueryError) as e:
raise StorageError(f"Failed to store batch: {e}") from e
@override
@@ -385,7 +404,7 @@ class WeaviateStorage(BaseStorage):
collection, resolved_name = await self._prepare_collection(
collection_name, ensure_exists=False
)
- result = collection.query.fetch_object_by_id(document_id)
+ result = await self._run_sync(collection.query.fetch_object_by_id, document_id)
if not result:
return None
@@ -395,13 +414,30 @@ class WeaviateStorage(BaseStorage):
result.properties,
context="fetch_object_by_id",
)
+ # Parse timestamp to datetime for consistent metadata format
+ from datetime import UTC, datetime
+
+ timestamp_raw = props.get("timestamp")
+ timestamp_parsed: datetime
+ try:
+ if isinstance(timestamp_raw, str):
+ timestamp_parsed = datetime.fromisoformat(timestamp_raw)
+ if timestamp_parsed.tzinfo is None:
+ timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
+ elif isinstance(timestamp_raw, datetime):
+ timestamp_parsed = timestamp_raw
+ if timestamp_parsed.tzinfo is None:
+ timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
+ else:
+ timestamp_parsed = datetime.now(UTC)
+ except (ValueError, TypeError):
+ timestamp_parsed = datetime.now(UTC)
+
metadata_dict = {
"source_url": str(props["source_url"]),
"title": str(props.get("title")) if props.get("title") else None,
- "description": str(props.get("description"))
- if props.get("description")
- else None,
- "timestamp": str(props["timestamp"]),
+ "description": str(props.get("description")) if props.get("description") else None,
+ "timestamp": timestamp_parsed,
"content_type": str(props["content_type"]),
"word_count": int(str(props["word_count"])),
"char_count": int(str(props["char_count"])),
@@ -424,11 +460,13 @@ class WeaviateStorage(BaseStorage):
except WeaviateConnectionError as e:
# Connection issues should be logged and return None
import logging
+
logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}")
return None
except Exception as e:
# Log unexpected errors for debugging
import logging
+
logging.warning(f"Unexpected error retrieving document {document_id}: {e}")
return None
@@ -437,9 +475,7 @@ class WeaviateStorage(BaseStorage):
metadata_dict = {
"source_url": str(props["source_url"]),
"title": str(props.get("title")) if props.get("title") else None,
- "description": str(props.get("description"))
- if props.get("description")
- else None,
+ "description": str(props.get("description")) if props.get("description") else None,
"timestamp": str(props["timestamp"]),
"content_type": str(props["content_type"]),
"word_count": int(str(props["word_count"])),
@@ -462,6 +498,7 @@ class WeaviateStorage(BaseStorage):
return max(0.0, 1.0 - distance_value)
except (TypeError, ValueError) as e:
import logging
+
logging.debug(f"Invalid distance value {raw_distance}: {e}")
return None
@@ -506,37 +543,39 @@ class WeaviateStorage(BaseStorage):
collection_name: str | None = None,
) -> AsyncGenerator[Document, None]:
"""
- Search for documents in Weaviate.
+ Search for documents in Weaviate using hybrid search.
Args:
query: Search query
limit: Maximum results
- threshold: Similarity threshold
+ threshold: Similarity threshold (not used in hybrid search)
Yields:
Matching documents
"""
try:
- query_vector = await self.vectorizer.vectorize(query)
+ if not self.client:
+ raise StorageError("Weaviate client not initialized")
+
collection, resolved_name = await self._prepare_collection(
collection_name, ensure_exists=False
)
- results = collection.query.near_vector(
- near_vector=query_vector,
- limit=limit,
- distance=1 - threshold,
- return_metadata=["distance"],
- )
+ # Try hybrid search first, fall back to BM25 keyword search
+ try:
+ response = await self._run_sync(
+ collection.query.hybrid, query=query, limit=limit, return_metadata=["score"]
+ )
+ except (WeaviateQueryError, StorageError):
+ # Fall back to BM25 if hybrid search is not supported or fails
+ response = await self._run_sync(
+ collection.query.bm25, query=query, limit=limit, return_metadata=["score"]
+ )
- for result in results.objects:
- yield self._build_search_document(result, resolved_name)
+ for obj in response.objects:
+ yield self._build_document_from_search(obj, resolved_name)
- except WeaviateQueryError as e:
- raise StorageError(f"Search query failed: {e}") from e
- except WeaviateConnectionError as e:
- raise StorageError(f"Connection to Weaviate failed during search: {e}") from e
- except Exception as e:
+ except (WeaviateQueryError, WeaviateConnectionError) as e:
raise StorageError(f"Search failed: {e}") from e
@override
@@ -552,7 +591,7 @@ class WeaviateStorage(BaseStorage):
"""
try:
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
- collection.data.delete_by_id(document_id)
+ await self._run_sync(collection.data.delete_by_id, document_id)
return True
except WeaviateQueryError as e:
raise StorageError(f"Delete operation failed: {e}") from e
@@ -589,10 +628,10 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
return list(client.collections.list_all())
- except Exception as e:
+ except WeaviateConnectionError as e:
raise StorageError(f"Failed to list collections: {e}") from e
async def describe_collections(self) -> list[CollectionSummary]:
@@ -601,7 +640,7 @@ class WeaviateStorage(BaseStorage):
raise StorageError("Weaviate client not initialized")
try:
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
collections: list[CollectionSummary] = []
for name in client.collections.list_all():
collection_obj = client.collections.get(name)
@@ -639,7 +678,7 @@ class WeaviateStorage(BaseStorage):
)
# Query for sample documents
- response = collection.query.fetch_objects(limit=limit)
+ response = await self._run_sync(collection.query.fetch_objects, limit=limit)
documents = []
for obj in response.objects:
@@ -712,9 +751,7 @@ class WeaviateStorage(BaseStorage):
return {
"source_url": str(props.get("source_url", "")),
"title": str(props.get("title", "")) if props.get("title") else None,
- "description": str(props.get("description", ""))
- if props.get("description")
- else None,
+ "description": str(props.get("description", "")) if props.get("description") else None,
"timestamp": datetime.fromisoformat(
str(props.get("timestamp", datetime.now(UTC).isoformat()))
),
@@ -737,6 +774,7 @@ class WeaviateStorage(BaseStorage):
return float(raw_score)
except (TypeError, ValueError) as e:
import logging
+
logging.debug(f"Invalid score value {raw_score}: {e}")
return None
@@ -780,31 +818,11 @@ class WeaviateStorage(BaseStorage):
Returns:
List of matching documents
"""
- try:
- if not self.client:
- raise StorageError("Weaviate client not initialized")
-
- collection, resolved_name = await self._prepare_collection(
- collection_name, ensure_exists=False
- )
-
- # Try hybrid search first, fall back to BM25 keyword search
- try:
- response = collection.query.hybrid(
- query=query, limit=limit, return_metadata=["score"]
- )
- except Exception:
- response = collection.query.bm25(
- query=query, limit=limit, return_metadata=["score"]
- )
-
- return [
- self._build_document_from_search(obj, resolved_name)
- for obj in response.objects
- ]
-
- except Exception as e:
- raise StorageError(f"Failed to search documents: {e}") from e
+ # Delegate to the unified search method
+ results = []
+ async for document in self.search(query, limit=limit, collection_name=collection_name):
+ results.append(document)
+ return results
async def list_documents(
self,
@@ -830,8 +848,11 @@ class WeaviateStorage(BaseStorage):
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
# Query documents with pagination
- response = collection.query.fetch_objects(
- limit=limit, offset=offset, return_metadata=["creation_time"]
+ response = await self._run_sync(
+ collection.query.fetch_objects,
+ limit=limit,
+ offset=offset,
+ return_metadata=["creation_time"],
)
documents: list[DocumentInfo] = []
@@ -896,7 +917,9 @@ class WeaviateStorage(BaseStorage):
)
delete_filter = Filter.by_id().contains_any(document_ids)
- response = collection.data.delete_many(where=delete_filter, verbose=True)
+ response = await self._run_sync(
+ collection.data.delete_many, where=delete_filter, verbose=True
+ )
if objects := getattr(response, "objects", None):
for result_obj in objects:
@@ -938,20 +961,22 @@ class WeaviateStorage(BaseStorage):
# Get documents matching filter
if where_filter:
- response = collection.query.fetch_objects(
+ response = await self._run_sync(
+ collection.query.fetch_objects,
filters=where_filter,
limit=1000, # Max batch size
)
else:
- response = collection.query.fetch_objects(
- limit=1000 # Max batch size
+ response = await self._run_sync(
+ collection.query.fetch_objects,
+ limit=1000, # Max batch size
)
# Delete matching documents
deleted_count = 0
for obj in response.objects:
try:
- collection.data.delete_by_id(obj.uuid)
+ await self._run_sync(collection.data.delete_by_id, obj.uuid)
deleted_count += 1
except Exception:
continue
@@ -975,7 +1000,7 @@ class WeaviateStorage(BaseStorage):
target = self._normalize_collection_name(collection_name)
# Delete the collection using the client's collections API
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
client.collections.delete(target)
return True
@@ -997,20 +1022,29 @@ class WeaviateStorage(BaseStorage):
await self.close()
async def close(self) -> None:
- """Close client connection."""
+ """Close client connection and vectorizer HTTP client."""
if self.client:
try:
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
client.close()
- except Exception as e:
+ except (WeaviateConnectionError, AttributeError) as e:
import logging
+
logging.warning(f"Error closing Weaviate client: {e}")
+ # Close vectorizer HTTP client to prevent resource leaks
+ try:
+ await self.vectorizer.close()
+ except (AttributeError, OSError) as e:
+ import logging
+
+ logging.warning(f"Error closing vectorizer client: {e}")
+
def __del__(self) -> None:
"""Clean up client connection as fallback."""
if self.client:
try:
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
client.close()
except Exception:
pass # Ignore errors in destructor
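The `_run_sync` helper above offloads blocking Weaviate client calls to a worker thread so the event loop stays responsive; a minimal standalone sketch of the pattern (the `slow_sync_call` function is a made-up stand-in for a client call):

```python
import asyncio
import time

def slow_sync_call(seconds: float) -> str:
    """Made-up stand-in for a blocking Weaviate client call."""
    time.sleep(seconds)
    return "done"

async def main() -> None:
    # asyncio.to_thread runs the blocking call in a worker thread,
    # so other coroutines keep making progress in the meantime.
    result = await asyncio.to_thread(slow_sync_call, 0.5)
    print(result)

asyncio.run(main())
```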
diff --git a/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc
index 1ecac22..c6334c5 100644
Binary files a/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc b/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc
index 88ac753..daf5fc0 100644
Binary files a/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc and b/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc differ
diff --git a/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc b/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc
index 48c83db..d822e83 100644
Binary files a/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc and b/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc differ
diff --git a/ingest_pipeline/utils/async_helpers.py b/ingest_pipeline/utils/async_helpers.py
new file mode 100644
index 0000000..d4593fb
--- /dev/null
+++ b/ingest_pipeline/utils/async_helpers.py
@@ -0,0 +1,117 @@
+"""Async utilities for task management and backpressure control."""
+
+import asyncio
+import logging
+from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
+from contextlib import asynccontextmanager
+from typing import Final, TypeVar
+
+LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+class AsyncTaskManager:
+ """Manages concurrent tasks with backpressure control."""
+
+ def __init__(self, max_concurrent: int = 10):
+ """
+ Initialize task manager.
+
+ Args:
+ max_concurrent: Maximum number of concurrent tasks
+ """
+ self.semaphore = asyncio.Semaphore(max_concurrent)
+ self.max_concurrent = max_concurrent
+
+ @asynccontextmanager
+ async def acquire(self) -> AsyncGenerator[None, None]:
+ """Acquire a slot for task execution."""
+ async with self.semaphore:
+ yield
+
+ async def run_tasks(
+ self, tasks: Iterable[Awaitable[T]], return_exceptions: bool = False
+ ) -> list[T | BaseException]:
+ """
+ Run multiple tasks with backpressure control.
+
+ Args:
+ tasks: Iterable of awaitable tasks
+ return_exceptions: Whether to return exceptions or raise them
+
+ Returns:
+ List of task results or exceptions
+ """
+
+ async def _controlled_task(task: Awaitable[T]) -> T:
+ async with self.acquire():
+ return await task
+
+ controlled_tasks = [_controlled_task(task) for task in tasks]
+
+ if return_exceptions:
+ results = await asyncio.gather(*controlled_tasks, return_exceptions=True)
+ return list(results)
+ else:
+ results = await asyncio.gather(*controlled_tasks)
+ return list(results)
+
+ async def map_async(
+ self, func: Callable[[T], Awaitable[T]], items: Iterable[T], return_exceptions: bool = False
+ ) -> list[T | BaseException]:
+ """
+ Apply async function to items with backpressure control.
+
+ Args:
+ func: Async function to apply
+ items: Items to process
+ return_exceptions: Whether to return exceptions or raise them
+
+ Returns:
+ List of processed results or exceptions
+ """
+ tasks = [func(item) for item in items]
+ return await self.run_tasks(tasks, return_exceptions=return_exceptions)
+
+
+async def run_with_semaphore(semaphore: asyncio.Semaphore, coro: Awaitable[T]) -> T:
+ """Run coroutine with semaphore-controlled concurrency."""
+ async with semaphore:
+ return await coro
+
+
+async def batch_process(
+ items: list[T],
+ processor: Callable[[T], Awaitable[T]],
+ batch_size: int = 50,
+ max_concurrent: int = 5,
+) -> list[T]:
+ """
+ Process items in batches with controlled concurrency.
+
+ Args:
+ items: Items to process
+ processor: Async function to process each item
+ batch_size: Number of items per batch
+ max_concurrent: Maximum concurrent tasks per batch
+
+ Returns:
+ List of processed results
+ """
+ task_manager = AsyncTaskManager(max_concurrent)
+ results: list[T] = []
+
+ for i in range(0, len(items), batch_size):
+ batch = items[i : i + batch_size]
+ LOGGER.debug(
+ "Processing batch %d-%d of %d items", i, min(i + batch_size, len(items)), len(items)
+ )
+
+ batch_results = await task_manager.map_async(processor, batch, return_exceptions=False)
+ # If return_exceptions=False, exceptions would have been raised, so all results are successful
+ # Type checker doesn't know this, so we need to cast
+ successful_results: list[T] = [r for r in batch_results if not isinstance(r, BaseException)]
+ results.extend(successful_results)
+
+ return results
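A brief usage sketch for the new helpers, assuming they are importable as `ingest_pipeline.utils.async_helpers` per this diff; the `double` coroutine is illustrative:

```python
import asyncio

from ingest_pipeline.utils.async_helpers import AsyncTaskManager, batch_process

async def double(value: int) -> int:
    await asyncio.sleep(0.01)  # simulate an I/O-bound call
    return value * 2

async def main() -> None:
    manager = AsyncTaskManager(max_concurrent=5)
    print(await manager.map_async(double, range(10)))

    # batch_process chunks the work and bounds concurrency within each chunk.
    print(await batch_process(list(range(100)), double, batch_size=25, max_concurrent=5))

asyncio.run(main())
```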
diff --git a/ingest_pipeline/utils/metadata_tagger.py b/ingest_pipeline/utils/metadata_tagger.py
index 9beb2d2..ebc9318 100644
--- a/ingest_pipeline/utils/metadata_tagger.py
+++ b/ingest_pipeline/utils/metadata_tagger.py
@@ -6,12 +6,12 @@ from typing import Final, Protocol, TypedDict, cast
import httpx
+from ..config import get_settings
from ..core.exceptions import IngestionError
from ..core.models import Document
JSON_CONTENT_TYPE: Final[str] = "application/json"
AUTHORIZATION_HEADER: Final[str] = "Authorization"
-from ..config import get_settings
class HttpResponse(Protocol):
@@ -24,12 +24,7 @@ class HttpResponse(Protocol):
class AsyncHttpClient(Protocol):
"""Protocol for async HTTP client."""
- async def post(
- self,
- url: str,
- *,
- json: dict[str, object] | None = None
- ) -> HttpResponse: ...
+ async def post(self, url: str, *, json: dict[str, object] | None = None) -> HttpResponse: ...
async def aclose(self) -> None: ...
@@ -45,16 +40,19 @@ class AsyncHttpClient(Protocol):
class LlmResponse(TypedDict):
"""Type for LLM API response structure."""
+
choices: list[dict[str, object]]
class LlmChoice(TypedDict):
"""Type for individual choice in LLM response."""
+
message: dict[str, object]
class LlmMessage(TypedDict):
"""Type for message in LLM choice."""
+
content: str
@@ -96,7 +94,7 @@ class MetadataTagger:
"""
settings = get_settings()
endpoint_value = llm_endpoint or str(settings.llm_endpoint)
- self.endpoint = endpoint_value.rstrip('/')
+ self.endpoint = endpoint_value.rstrip("/")
self.model = model or settings.metadata_model
resolved_timeout = timeout if timeout is not None else float(settings.request_timeout)
@@ -161,12 +159,16 @@ class MetadataTagger:
# Build a proper DocumentMetadata instance with only valid keys
new_metadata: CoreDocumentMetadata = {
"source_url": str(updated_metadata.get("source_url", "")),
- "timestamp": (
- lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC)
- )(updated_metadata.get("timestamp", datetime.now(UTC))),
+ "timestamp": (lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC))(
+ updated_metadata.get("timestamp", datetime.now(UTC))
+ ),
"content_type": str(updated_metadata.get("content_type", "text/plain")),
- "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)),
- "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)),
+ "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(
+ updated_metadata.get("word_count", 0)
+ ),
+ "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(
+ updated_metadata.get("char_count", 0)
+ ),
}
# Add optional fields if they exist
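The inline lambdas above coerce loosely typed metadata values into the expected types; an equivalent standalone sketch (helper names are illustrative):

```python
from datetime import UTC, datetime

def coerce_int(value: object, default: int = 0) -> int:
    """Mirror of the word_count/char_count lambdas (illustrative helper name)."""
    return int(value) if isinstance(value, (int, str)) else default

def coerce_timestamp(value: object) -> datetime:
    """Mirror of the timestamp lambda: keep datetimes, otherwise fall back to now."""
    return value if isinstance(value, datetime) else datetime.now(UTC)

assert coerce_int("42") == 42
assert coerce_int(3.7) == 0  # floats are not coerced, matching the lambda
assert isinstance(coerce_timestamp("2024-01-01"), datetime)  # strings are replaced, not parsed
```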
diff --git a/ingest_pipeline/utils/vectorizer.py b/ingest_pipeline/utils/vectorizer.py
index adc63f6..87e597f 100644
--- a/ingest_pipeline/utils/vectorizer.py
+++ b/ingest_pipeline/utils/vectorizer.py
@@ -1,20 +1,79 @@
"""Vectorizer utility for generating embeddings."""
+import asyncio
from types import TracebackType
-from typing import Final, Self, cast
+from typing import Final, NotRequired, Self, TypedDict
import httpx
-from typings import EmbeddingResponse
-
+from ..config import get_settings
from ..core.exceptions import VectorizationError
from ..core.models import StorageConfig, VectorConfig
-from ..config import get_settings
JSON_CONTENT_TYPE: Final[str] = "application/json"
AUTHORIZATION_HEADER: Final[str] = "Authorization"
+class EmbeddingData(TypedDict):
+ """Structure for embedding data from providers."""
+
+ embedding: list[float]
+ index: NotRequired[int]
+ object: NotRequired[str]
+
+
+class EmbeddingResponse(TypedDict):
+ """Embedding response format for multiple providers."""
+
+ data: list[EmbeddingData]
+ model: NotRequired[str]
+ object: NotRequired[str]
+ usage: NotRequired[dict[str, int]]
+ # Alternative formats
+ embedding: NotRequired[list[float]]
+ vector: NotRequired[list[float]]
+ embeddings: NotRequired[list[list[float]]]
+
+
+def _extract_embedding_from_response(response_data: dict[str, object]) -> list[float]:
+ """Extract embedding vector from provider response."""
+ # OpenAI/Ollama format: {"data": [{"embedding": [...]}]}
+ if "data" in response_data:
+ data_list = response_data["data"]
+ if isinstance(data_list, list) and data_list:
+ first_item = data_list[0]
+ if isinstance(first_item, dict) and "embedding" in first_item:
+ embedding = first_item["embedding"]
+ if isinstance(embedding, list) and all(
+ isinstance(x, (int, float)) for x in embedding
+ ):
+ return [float(x) for x in embedding]
+
+ # Direct embedding format: {"embedding": [...]}
+ if "embedding" in response_data:
+ embedding = response_data["embedding"]
+ if isinstance(embedding, list) and all(isinstance(x, (int, float)) for x in embedding):
+ return [float(x) for x in embedding]
+
+ # Vector format: {"vector": [...]}
+ if "vector" in response_data:
+ vector = response_data["vector"]
+ if isinstance(vector, list) and all(isinstance(x, (int, float)) for x in vector):
+ return [float(x) for x in vector]
+
+ # Embeddings array format: {"embeddings": [[...]]}
+ if "embeddings" in response_data:
+ embeddings = response_data["embeddings"]
+ if isinstance(embeddings, list) and embeddings:
+ first_embedding = embeddings[0]
+ if isinstance(first_embedding, list) and all(
+ isinstance(x, (int, float)) for x in first_embedding
+ ):
+ return [float(x) for x in first_embedding]
+
+ raise VectorizationError("Unrecognized embedding response format")
+
+
class Vectorizer:
"""Handles text vectorization using LLM endpoints."""
@@ -72,21 +131,34 @@ class Vectorizer:
async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
"""
- Generate embeddings for multiple texts.
+ Generate embeddings for multiple texts in parallel.
Args:
texts: List of texts to vectorize
Returns:
List of embedding vectors
+
+ Raises:
+ VectorizationError: If any vectorization fails
"""
- vectors: list[list[float]] = []
- for text in texts:
- vector = await self.vectorize(text)
- vectors.append(vector)
+ if not texts:
+ return []
- return vectors
+        # Use a semaphore to limit concurrent requests and avoid overwhelming the endpoint
+ semaphore = asyncio.Semaphore(20)
+
+ async def vectorize_with_semaphore(text: str) -> list[float]:
+ async with semaphore:
+ return await self.vectorize(text)
+
+ try:
+ # Execute all vectorization requests concurrently
+ vectors = await asyncio.gather(*[vectorize_with_semaphore(text) for text in texts])
+ return list(vectors)
+ except Exception as e:
+ raise VectorizationError(f"Batch vectorization failed: {e}") from e
async def _ollama_embed(self, text: str) -> list[float]:
"""
@@ -112,23 +184,12 @@ class Vectorizer:
_ = response.raise_for_status()
response_json = response.json()
- # Response is expected to be dict[str, object] from our type stub
+ if not isinstance(response_json, dict):
+ raise VectorizationError("Invalid JSON response format")
- response_data = cast(EmbeddingResponse, cast(object, response_json))
+ # Extract embedding using type-safe helper
+ embedding = _extract_embedding_from_response(response_json)
- # Parse OpenAI-compatible response format
- embeddings_list = response_data.get("data", [])
- if not embeddings_list:
- raise VectorizationError("No embeddings returned")
-
- first_embedding = embeddings_list[0]
- embedding_raw = first_embedding.get("embedding")
- if not embedding_raw:
- raise VectorizationError("Invalid embedding format")
-
- # Convert to float list and validate
- embedding: list[float] = []
- embedding.extend(float(item) for item in embedding_raw)
# Ensure correct dimension
if len(embedding) != self.dimension:
raise VectorizationError(
@@ -157,22 +218,12 @@ class Vectorizer:
_ = response.raise_for_status()
response_json = response.json()
- # Response is expected to be dict[str, object] from our type stub
+ if not isinstance(response_json, dict):
+ raise VectorizationError("Invalid JSON response format")
- response_data = cast(EmbeddingResponse, cast(object, response_json))
+ # Extract embedding using type-safe helper
+ embedding = _extract_embedding_from_response(response_json)
- embeddings_list = response_data.get("data", [])
- if not embeddings_list:
- raise VectorizationError("No embeddings returned")
-
- first_embedding = embeddings_list[0]
- embedding_raw = first_embedding.get("embedding")
- if not embedding_raw:
- raise VectorizationError("Invalid embedding format")
-
- # Convert to float list and validate
- embedding: list[float] = []
- embedding.extend(float(item) for item in embedding_raw)
# Ensure correct dimension
if len(embedding) != self.dimension:
raise VectorizationError(
@@ -185,6 +236,14 @@ class Vectorizer:
"""Async context manager entry."""
return self
+ async def close(self) -> None:
+ """Close the HTTP client connection."""
+ try:
+ await self.client.aclose()
+ except Exception:
+ # Already closed or connection lost
+ pass
+
async def __aexit__(
self,
exc_type: type[BaseException] | None,
@@ -192,4 +251,4 @@ class Vectorizer:
exc_tb: TracebackType | None,
) -> None:
"""Async context manager exit."""
- await self.client.aclose()
+ await self.close()
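# --- Reviewer sketch, not part of the diff: intended Vectorizer lifecycle with the new close() ---
# Assumes only the Vectorizer API shown above; the caller supplies an already-constructed instance.
async def _demo_vectorizer_lifecycle(vectorizer: Vectorizer) -> None:
    # "async with" now funnels through close(), which tolerates an already-closed client,
    # so backend teardown code can call it even if the HTTP client was shut down earlier.
    async with vectorizer:
        vectors = await vectorizer.vectorize_batch(["first text", "second text"])
        assert len(vectors) == 2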
diff --git a/notebooks/welcome.py b/notebooks/welcome.py
new file mode 100644
index 0000000..4b7ad91
--- /dev/null
+++ b/notebooks/welcome.py
@@ -0,0 +1,27 @@
+import marimo
+
+__generated_with = "0.16.0"
+app = marimo.App()
+
+
+@app.cell
+def __():
+ import marimo as mo
+
+ return (mo,)
+
+
+@app.cell
+def __(mo):
+ mo.md("# Welcome to Marimo!")
+ return
+
+
+@app.cell
+def __(mo):
+ mo.md("This is your interactive notebook environment.")
+ return
+
+
+if __name__ == "__main__":
+ app.run()
diff --git a/repomix-output.xml b/repomix-output.xml
index afa1c62..35e2616 100644
--- a/repomix-output.xml
+++ b/repomix-output.xml
@@ -4,7 +4,7 @@ This file is a merged representation of a subset of the codebase, containing spe
This section contains a summary of this file.
-This file contains a packed representation of the entire repository's contents.
+This file contains a packed representation of a subset of the repository's contents that is considered the most important context.
It is designed to be easily consumable by AI systems for analysis, code review,
or other automated processes.
@@ -14,7 +14,8 @@ The content is organized as follows:
1. This summary section
2. Repository information
3. Directory structure
-4. Repository files, each consisting of:
+4. Repository files (if enabled)
+5. Multiple file entries, each consisting of:
- File path as an attribute
- Full contents of the file
@@ -37,10 +38,6 @@ The content is organized as follows:
- Files are sorted by Git change count (files with more changes are at the bottom)
-
-
-
-
@@ -99,9 +96,11 @@ ingest_pipeline/
__init__.py
base.py
openwebui.py
+ types.py
weaviate.py
utils/
__init__.py
+ async_helpers.py
metadata_tagger.py
vectorizer.py
__main__.py
@@ -110,6 +109,126 @@ ingest_pipeline/
This section contains the contents of the repository's files.
+
+"""Async utilities for task management and backpressure control."""
+
+import asyncio
+import logging
+from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
+from contextlib import asynccontextmanager
+from typing import Final, TypeVar
+
+LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+class AsyncTaskManager:
+ """Manages concurrent tasks with backpressure control."""
+
+ def __init__(self, max_concurrent: int = 10):
+ """
+ Initialize task manager.
+
+ Args:
+ max_concurrent: Maximum number of concurrent tasks
+ """
+ self.semaphore = asyncio.Semaphore(max_concurrent)
+ self.max_concurrent = max_concurrent
+
+ @asynccontextmanager
+ async def acquire(self) -> AsyncGenerator[None, None]:
+ """Acquire a slot for task execution."""
+ async with self.semaphore:
+ yield
+
+ async def run_tasks(
+ self, tasks: Iterable[Awaitable[T]], return_exceptions: bool = False
+ ) -> list[T | BaseException]:
+ """
+ Run multiple tasks with backpressure control.
+
+ Args:
+ tasks: Iterable of awaitable tasks
+ return_exceptions: Whether to return exceptions or raise them
+
+ Returns:
+ List of task results or exceptions
+ """
+
+ async def _controlled_task(task: Awaitable[T]) -> T:
+ async with self.acquire():
+ return await task
+
+ controlled_tasks = [_controlled_task(task) for task in tasks]
+
+ if return_exceptions:
+ results = await asyncio.gather(*controlled_tasks, return_exceptions=True)
+ return list(results)
+ else:
+ results = await asyncio.gather(*controlled_tasks)
+ return list(results)
+
+ async def map_async(
+ self, func: Callable[[T], Awaitable[T]], items: Iterable[T], return_exceptions: bool = False
+ ) -> list[T | BaseException]:
+ """
+ Apply async function to items with backpressure control.
+
+ Args:
+ func: Async function to apply
+ items: Items to process
+ return_exceptions: Whether to return exceptions or raise them
+
+ Returns:
+ List of processed results or exceptions
+ """
+ tasks = [func(item) for item in items]
+ return await self.run_tasks(tasks, return_exceptions=return_exceptions)
+
+
+async def run_with_semaphore(semaphore: asyncio.Semaphore, coro: Awaitable[T]) -> T:
+ """Run coroutine with semaphore-controlled concurrency."""
+ async with semaphore:
+ return await coro
+
+
+async def batch_process(
+ items: list[T],
+ processor: Callable[[T], Awaitable[T]],
+ batch_size: int = 50,
+ max_concurrent: int = 5,
+) -> list[T]:
+ """
+ Process items in batches with controlled concurrency.
+
+ Args:
+ items: Items to process
+ processor: Async function to process each item
+ batch_size: Number of items per batch
+ max_concurrent: Maximum concurrent tasks per batch
+
+ Returns:
+ List of processed results
+ """
+ task_manager = AsyncTaskManager(max_concurrent)
+ results: list[T] = []
+
+ for i in range(0, len(items), batch_size):
+ batch = items[i : i + batch_size]
+ LOGGER.debug(
+ "Processing batch %d-%d of %d items", i, min(i + batch_size, len(items)), len(items)
+ )
+
+ batch_results = await task_manager.map_async(processor, batch, return_exceptions=False)
+ # With return_exceptions=False any failure would already have raised above, so every result is a T;
+ # the isinstance filter below only exists to narrow the type for the checker (no cast is involved)
+ successful_results: list[T] = [r for r in batch_results if not isinstance(r, BaseException)]
+ results.extend(successful_results)
+
+ return results
+
+
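# --- Reviewer sketch, not part of the diff: composing the helpers above ---
# _double is made-up demo work; only AsyncTaskManager, batch_process, and asyncio are assumed.
async def _double(value: int) -> int:
    await asyncio.sleep(0)  # stand-in for real async I/O
    return value * 2

async def _demo_async_helpers() -> None:
    manager = AsyncTaskManager(max_concurrent=3)
    mapped = await manager.map_async(_double, [1, 2, 3])  # order-preserving -> [2, 4, 6]
    chunked = await batch_process(list(range(10)), _double, batch_size=4, max_concurrent=2)
    assert mapped == [2, 4, 6] and chunked == [x * 2 for x in range(10)]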
"""Prefect Automations for ingestion pipeline monitoring and management."""
@@ -132,7 +251,6 @@ actions:
source: inferred
enabled: true
""",
-
"retry_failed": """
name: Retry Failed Ingestion Flows
description: Retries failed ingestion flows with original parameters
@@ -152,7 +270,6 @@ actions:
validate_first: false
enabled: true
""",
-
"resource_monitoring": """
name: Manage Work Pool Based on Resources
description: Pauses work pool when system resources are constrained
@@ -519,6 +636,33 @@ from .storage import R2RStorage
__all__ = ["R2RStorage"]
+
+"""Shared types for storage adapters."""
+
+from typing import TypedDict
+
+
+class CollectionSummary(TypedDict):
+ """Collection metadata for describe_collections."""
+
+ name: str
+ count: int
+ size_mb: float
+
+
+class DocumentInfo(TypedDict):
+ """Document information for list_documents."""
+
+ id: str
+ title: str
+ source_url: str
+ description: str
+ content_type: str
+ content_preview: str
+ word_count: int
+ timestamp: str
+
+
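# --- Reviewer sketch, not part of the diff: the row shape list_documents() is expected to return ---
# All field values below are made-up sample data.
_example_row: DocumentInfo = {
    "id": "doc-123",
    "title": "Example page",
    "source_url": "https://example.com/page",
    "description": "Short human-readable summary",
    "content_type": "text/markdown",
    "content_preview": "First few sentences of the document...",
    "word_count": 420,
    "timestamp": "2024-01-01T00:00:00Z",
}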
"""Utility modules."""
@@ -592,7 +736,7 @@ class BaseScreen(Screen[object]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
- **kwargs: object
+ **kwargs: object,
) -> None:
"""Initialize base screen."""
super().__init__(name=name, id=id, classes=classes)
@@ -618,7 +762,7 @@ class CRUDScreen(BaseScreen, Generic[T]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
- **kwargs: object
+ **kwargs: object,
) -> None:
"""Initialize CRUD screen."""
super().__init__(storage_manager, name=name, id=id, classes=classes)
@@ -874,7 +1018,7 @@ class FormScreen(ModalScreen[T], Generic[T]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
- **kwargs: object
+ **kwargs: object,
) -> None:
"""Initialize form screen."""
super().__init__(name=name, id=id, classes=classes)
@@ -947,503 +1091,6 @@ class FormScreen(ModalScreen[T], Generic[T]):
self.dismiss(None)
-
-"""Storage management utilities for TUI applications."""
-
-
-from __future__ import annotations
-
-import asyncio
-from collections.abc import AsyncGenerator, Sequence
-from typing import TYPE_CHECKING, Protocol
-
-from ....core.exceptions import StorageError
-from ....core.models import Document, StorageBackend, StorageConfig
-from ..models import CollectionInfo, StorageCapabilities
-
-from ....storage.base import BaseStorage
-from ....storage.openwebui import OpenWebUIStorage
-from ....storage.r2r.storage import R2RStorage
-from ....storage.weaviate import WeaviateStorage
-
-if TYPE_CHECKING:
- from ....config.settings import Settings
-
-
-class StorageBackendProtocol(Protocol):
- """Protocol defining storage backend interface."""
-
- async def initialize(self) -> None: ...
- async def count(self, *, collection_name: str | None = None) -> int: ...
- async def list_collections(self) -> list[str]: ...
- async def search(
- self,
- query: str,
- limit: int = 10,
- threshold: float = 0.7,
- *,
- collection_name: str | None = None,
- ) -> AsyncGenerator[Document, None]: ...
- async def close(self) -> None: ...
-
-
-
-class MultiStorageAdapter(BaseStorage):
- """Mirror writes to multiple storage backends."""
-
- def __init__(self, storages: Sequence[BaseStorage]) -> None:
- if not storages:
- raise ValueError("MultiStorageAdapter requires at least one storage backend")
-
- unique: list[BaseStorage] = []
- seen_ids: set[int] = set()
- for storage in storages:
- storage_id = id(storage)
- if storage_id in seen_ids:
- continue
- seen_ids.add(storage_id)
- unique.append(storage)
-
- self._storages = unique
- self._primary = unique[0]
- super().__init__(self._primary.config)
-
- async def initialize(self) -> None:
- for storage in self._storages:
- await storage.initialize()
-
- async def store(self, document: Document, *, collection_name: str | None = None) -> str:
- # Store in primary backend first
- primary_id: str = await self._primary.store(document, collection_name=collection_name)
-
- # Replicate to secondary backends concurrently
- if len(self._storages) > 1:
- async def replicate_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
- try:
- await storage.store(document, collection_name=collection_name)
- return storage, True, None
- except Exception as exc:
- return storage, False, exc
-
- tasks = [replicate_to_backend(storage) for storage in self._storages[1:]]
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- failures: list[str] = []
- errors: list[Exception] = []
-
- for result in results:
- if isinstance(result, tuple):
- storage, success, error = result
- if not success and error is not None:
- failures.append(self._format_backend_label(storage))
- errors.append(error)
- elif isinstance(result, Exception):
- failures.append("unknown")
- errors.append(result)
-
- if failures:
- backends = ", ".join(failures)
- primary_error = errors[0] if errors else Exception("Unknown replication error")
- raise StorageError(
- f"Document stored in primary backend but replication failed for: {backends}"
- ) from primary_error
-
- return primary_id
-
- async def store_batch(
- self, documents: list[Document], *, collection_name: str | None = None
- ) -> list[str]:
- # Store in primary backend first
- primary_ids: list[str] = await self._primary.store_batch(documents, collection_name=collection_name)
-
- # Replicate to secondary backends concurrently
- if len(self._storages) > 1:
- async def replicate_batch_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
- try:
- await storage.store_batch(documents, collection_name=collection_name)
- return storage, True, None
- except Exception as exc:
- return storage, False, exc
-
- tasks = [replicate_batch_to_backend(storage) for storage in self._storages[1:]]
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- failures: list[str] = []
- errors: list[Exception] = []
-
- for result in results:
- if isinstance(result, tuple):
- storage, success, error = result
- if not success and error is not None:
- failures.append(self._format_backend_label(storage))
- errors.append(error)
- elif isinstance(result, Exception):
- failures.append("unknown")
- errors.append(result)
-
- if failures:
- backends = ", ".join(failures)
- primary_error = errors[0] if errors else Exception("Unknown batch replication error")
- raise StorageError(
- f"Batch stored in primary backend but replication failed for: {backends}"
- ) from primary_error
-
- return primary_ids
-
- async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
- # Delete from primary backend first
- primary_deleted: bool = await self._primary.delete(document_id, collection_name=collection_name)
-
- # Delete from secondary backends concurrently
- if len(self._storages) > 1:
- async def delete_from_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
- try:
- await storage.delete(document_id, collection_name=collection_name)
- return storage, True, None
- except Exception as exc:
- return storage, False, exc
-
- tasks = [delete_from_backend(storage) for storage in self._storages[1:]]
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- failures: list[str] = []
- errors: list[Exception] = []
-
- for result in results:
- if isinstance(result, tuple):
- storage, success, error = result
- if not success and error is not None:
- failures.append(self._format_backend_label(storage))
- errors.append(error)
- elif isinstance(result, Exception):
- failures.append("unknown")
- errors.append(result)
-
- if failures:
- backends = ", ".join(failures)
- primary_error = errors[0] if errors else Exception("Unknown deletion error")
- raise StorageError(
- f"Document deleted from primary backend but failed for: {backends}"
- ) from primary_error
-
- return primary_deleted
-
- async def count(self, *, collection_name: str | None = None) -> int:
- count_result: int = await self._primary.count(collection_name=collection_name)
- return count_result
-
- async def list_collections(self) -> list[str]:
- list_fn = getattr(self._primary, "list_collections", None)
- if list_fn is None:
- return []
- collections_result: list[str] = await list_fn()
- return collections_result
-
- async def search(
- self,
- query: str,
- limit: int = 10,
- threshold: float = 0.7,
- *,
- collection_name: str | None = None,
- ) -> AsyncGenerator[Document, None]:
- async for item in self._primary.search(
- query,
- limit=limit,
- threshold=threshold,
- collection_name=collection_name,
- ):
- yield item
-
- async def close(self) -> None:
- for storage in self._storages:
- close_fn = getattr(storage, "close", None)
- if close_fn is not None:
- await close_fn()
-
- def _format_backend_label(self, storage: BaseStorage) -> str:
- backend = getattr(storage.config, "backend", None)
- if isinstance(backend, StorageBackend):
- backend_value: str = backend.value
- return backend_value
- class_name: str = storage.__class__.__name__
- return class_name
-
-
-
-class StorageManager:
- """Centralized manager for all storage backend operations."""
-
- def __init__(self, settings: Settings) -> None:
- """Initialize storage manager with application settings."""
- self.settings = settings
- self.backends: dict[StorageBackend, BaseStorage] = {}
- self.capabilities: dict[StorageBackend, StorageCapabilities] = {}
- self._initialized = False
-
- async def initialize_all_backends(self) -> dict[StorageBackend, bool]:
- """Initialize all available storage backends with timeout protection."""
- results: dict[StorageBackend, bool] = {}
-
- async def init_backend(backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]) -> bool:
- """Initialize a single backend with timeout."""
- try:
- storage = storage_class(config)
- await asyncio.wait_for(storage.initialize(), timeout=30.0)
- self.backends[backend_type] = storage
- if backend_type == StorageBackend.WEAVIATE:
- self.capabilities[backend_type] = StorageCapabilities.VECTOR_SEARCH
- elif backend_type == StorageBackend.OPEN_WEBUI:
- self.capabilities[backend_type] = StorageCapabilities.KNOWLEDGE_BASE
- elif backend_type == StorageBackend.R2R:
- self.capabilities[backend_type] = StorageCapabilities.FULL_FEATURED
- return True
- except (TimeoutError, Exception):
- return False
-
- # Initialize backends concurrently with timeout protection
- tasks = []
-
- # Try Weaviate
- if self.settings.weaviate_endpoint:
- config = StorageConfig(
- backend=StorageBackend.WEAVIATE,
- endpoint=self.settings.weaviate_endpoint,
- api_key=self.settings.weaviate_api_key,
- collection_name="default",
- )
- tasks.append((StorageBackend.WEAVIATE, init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage)))
- else:
- results[StorageBackend.WEAVIATE] = False
-
- # Try OpenWebUI
- if self.settings.openwebui_endpoint and self.settings.openwebui_api_key:
- config = StorageConfig(
- backend=StorageBackend.OPEN_WEBUI,
- endpoint=self.settings.openwebui_endpoint,
- api_key=self.settings.openwebui_api_key,
- collection_name="default",
- )
- tasks.append((StorageBackend.OPEN_WEBUI, init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage)))
- else:
- results[StorageBackend.OPEN_WEBUI] = False
-
- # Try R2R
- if self.settings.r2r_endpoint:
- config = StorageConfig(
- backend=StorageBackend.R2R,
- endpoint=self.settings.r2r_endpoint,
- api_key=self.settings.r2r_api_key,
- collection_name="default",
- )
- tasks.append((StorageBackend.R2R, init_backend(StorageBackend.R2R, config, R2RStorage)))
- else:
- results[StorageBackend.R2R] = False
-
- # Execute initialization tasks concurrently
- if tasks:
- backend_types, task_coroutines = zip(*tasks, strict=False)
- task_results = await asyncio.gather(*task_coroutines, return_exceptions=True)
-
- for backend_type, task_result in zip(backend_types, task_results, strict=False):
- results[backend_type] = task_result if isinstance(task_result, bool) else False
- self._initialized = True
- return results
-
- def get_backend(self, backend_type: StorageBackend) -> BaseStorage | None:
- """Get storage backend by type."""
- return self.backends.get(backend_type)
-
- def build_multi_storage_adapter(
- self, backends: Sequence[StorageBackend]
- ) -> MultiStorageAdapter:
- storages: list[BaseStorage] = []
- seen: set[StorageBackend] = set()
- for backend in backends:
- backend_enum = backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
- if backend_enum in seen:
- continue
- seen.add(backend_enum)
- storage = self.backends.get(backend_enum)
- if storage is None:
- raise ValueError(f"Storage backend {backend_enum.value} is not initialized")
- storages.append(storage)
- return MultiStorageAdapter(storages)
-
- def get_available_backends(self) -> list[StorageBackend]:
- """Get list of successfully initialized backends."""
- return list(self.backends.keys())
-
- def has_capability(self, backend: StorageBackend, capability: StorageCapabilities) -> bool:
- """Check if backend has specific capability."""
- backend_caps = self.capabilities.get(backend, StorageCapabilities.BASIC)
- return capability.value <= backend_caps.value
-
- async def get_all_collections(self) -> list[CollectionInfo]:
- """Get collections from all available backends, merging collections with same name."""
- collection_map: dict[str, CollectionInfo] = {}
-
- for backend_type, storage in self.backends.items():
- try:
- backend_collections = await storage.list_collections()
- for collection_name in backend_collections:
- # Validate collection name
- if not collection_name or not isinstance(collection_name, str):
- continue
-
- try:
- count = await storage.count(collection_name=collection_name)
- # Validate count is non-negative
- count = max(count, 0)
- except StorageError as e:
- # Storage-specific errors - log and use 0 count
- import logging
- logging.warning(f"Failed to get count for {collection_name} on {backend_type.value}: {e}")
- count = 0
- except Exception as e:
- # Unexpected errors - log and skip this collection from this backend
- import logging
- logging.warning(f"Unexpected error counting {collection_name} on {backend_type.value}: {e}")
- continue
-
- size_mb = count * 0.01 # Rough estimate: 10KB per document
-
- if collection_name in collection_map:
- # Merge with existing collection
- existing = collection_map[collection_name]
- existing_backends = existing["backend"]
- backend_value = backend_type.value
-
- if isinstance(existing_backends, str):
- existing["backend"] = [existing_backends, backend_value]
- elif isinstance(existing_backends, list):
- # Prevent duplicates
- if backend_value not in existing_backends:
- existing_backends.append(backend_value)
-
- # Aggregate counts and sizes
- existing["count"] += count
- existing["size_mb"] += size_mb
- else:
- # Create new collection entry
- collection_info: CollectionInfo = {
- "name": collection_name,
- "type": self._get_collection_type(collection_name, backend_type),
- "count": count,
- "backend": backend_type.value,
- "status": "active",
- "last_updated": "2024-01-01T00:00:00Z",
- "size_mb": size_mb,
- }
- collection_map[collection_name] = collection_info
- except Exception:
- continue
-
- return list(collection_map.values())
-
- def _get_collection_type(self, collection_name: str, backend: StorageBackend) -> str:
- """Determine collection type based on name and backend."""
- # Prioritize definitive backend type first
- if backend == StorageBackend.R2R:
- return "r2r"
- elif backend == StorageBackend.WEAVIATE:
- return "weaviate"
- elif backend == StorageBackend.OPEN_WEBUI:
- return "openwebui"
-
- # Fallback to name-based guessing if backend is not specific
- name_lower = collection_name.lower()
- if "web" in name_lower or "doc" in name_lower:
- return "documentation"
- elif "repo" in name_lower or "code" in name_lower:
- return "repository"
- else:
- return "general"
-
- async def search_across_backends(
- self,
- query: str,
- limit: int = 10,
- backends: list[StorageBackend] | None = None,
- ) -> dict[StorageBackend, list[Document]]:
- """Search across multiple backends and return grouped results."""
- if backends is None:
- backends = self.get_available_backends()
-
- results: dict[StorageBackend, list[Document]] = {}
-
- async def search_backend(backend_type: StorageBackend) -> None:
- storage = self.backends.get(backend_type)
- if storage:
- try:
- documents = []
- async for doc in storage.search(query, limit=limit):
- documents.append(doc)
- results[backend_type] = documents
- except Exception:
- results[backend_type] = []
-
- # Run searches in parallel
- tasks = [search_backend(backend) for backend in backends]
- await asyncio.gather(*tasks, return_exceptions=True)
-
- return results
-
- def get_r2r_storage(self) -> R2RStorage | None:
- """Get R2R storage instance if available."""
- storage = self.backends.get(StorageBackend.R2R)
- return storage if isinstance(storage, R2RStorage) else None
-
- async def get_backend_status(self) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
- """Get comprehensive status for all backends."""
- status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}
-
- for backend_type, storage in self.backends.items():
- try:
- collections = await storage.list_collections()
- total_docs = 0
- for collection in collections:
- total_docs += await storage.count(collection_name=collection)
-
- backend_status = {
- "available": True,
- "collections": len(collections),
- "total_documents": total_docs,
- "capabilities": self.capabilities.get(backend_type, StorageCapabilities.BASIC),
- "endpoint": getattr(storage.config, "endpoint", "unknown"),
- }
- status[backend_type] = backend_status
- except Exception as e:
- status[backend_type] = {
- "available": False,
- "error": str(e),
- "capabilities": StorageCapabilities.NONE,
- }
-
- return status
-
- async def close_all(self) -> None:
- """Close all storage connections."""
- for storage in self.backends.values():
- try:
- await storage.close()
- except Exception:
- pass
-
- self.backends.clear()
- self.capabilities.clear()
- self._initialized = False
-
- @property
- def is_initialized(self) -> bool:
- """Check if storage manager is initialized."""
- return self._initialized
-
- def supports_advanced_features(self, backend: StorageBackend) -> bool:
- """Check if backend supports advanced features like chunks and entities."""
- return self.has_capability(backend, StorageCapabilities.FULL_FEATURED)
-
-
"""Status indicators and progress bars with enhanced visual feedback."""
@@ -1473,8 +1120,12 @@ class StatusIndicator(Static):
status_lower = status.lower()
- if (status_lower in {"active", "online", "connected", "✓ active"} or
- status_lower.endswith("active") or "✓" in status_lower and "active" in status_lower):
+ if (
+ status_lower in {"active", "online", "connected", "✓ active"}
+ or status_lower.endswith("active")
+ or "✓" in status_lower
+ and "active" in status_lower
+ ):
self.add_class("status-active")
self.add_class("glow")
self.update(f"🟢 {status}")
@@ -1540,7 +1191,7 @@ class EnhancedProgressBar(Static):
"""Data models and TypedDict definitions for the TUI."""
from enum import IntEnum
-from typing import Any, TypedDict
+from typing import TypedDict
class StorageCapabilities(IntEnum):
@@ -1586,7 +1237,7 @@ class ChunkInfo(TypedDict):
content: str
start_index: int
end_index: int
- metadata: dict[str, Any]
+ metadata: dict[str, object]
class EntityInfo(TypedDict):
@@ -1596,7 +1247,7 @@ class EntityInfo(TypedDict):
name: str
type: str
confidence: float
- metadata: dict[str, Any]
+ metadata: dict[str, object]
class FirecrawlOptions(TypedDict, total=False):
@@ -1616,7 +1267,7 @@ class FirecrawlOptions(TypedDict, total=False):
max_depth: int
# Extraction options
- extract_schema: dict[str, Any] | None
+ extract_schema: dict[str, object] | None
extract_prompt: str | None
@@ -2158,6 +1809,7 @@ if TYPE_CHECKING:
try:
from .r2r.storage import R2RStorage as _RuntimeR2RStorage
+
R2RStorage: type[BaseStorage] | None = _RuntimeR2RStorage
except ImportError:
R2RStorage = None
@@ -2171,341 +1823,12 @@ __all__ = [
]
-
-"""Document management screen with enhanced navigation."""
-
-from datetime import datetime
-
-from textual.app import ComposeResult
-from textual.binding import Binding
-from textual.containers import Container, Horizontal
-from textual.screen import Screen
-from textual.widgets import Button, Footer, Header, Label, LoadingIndicator, Static
-from typing_extensions import override
-
-from ....storage.base import BaseStorage
-from ..models import CollectionInfo, DocumentInfo
-from ..widgets import EnhancedDataTable
-
-
-class DocumentManagementScreen(Screen[None]):
- """Screen for managing documents within a collection with enhanced keyboard navigation."""
-
- collection: CollectionInfo
- storage: BaseStorage | None
- documents: list[DocumentInfo]
- selected_docs: set[str]
- current_offset: int
- page_size: int
-
- BINDINGS = [
- Binding("escape", "app.pop_screen", "Back"),
- Binding("r", "refresh", "Refresh"),
- Binding("delete", "delete_selected", "Delete Selected"),
- Binding("a", "select_all", "Select All"),
- Binding("ctrl+a", "select_all", "Select All"),
- Binding("n", "select_none", "Clear Selection"),
- Binding("ctrl+shift+a", "select_none", "Clear Selection"),
- Binding("space", "toggle_selection", "Toggle Selection"),
- Binding("ctrl+d", "delete_selected", "Delete Selected"),
- Binding("pageup", "prev_page", "Previous Page"),
- Binding("pagedown", "next_page", "Next Page"),
- Binding("home", "first_page", "First Page"),
- Binding("end", "last_page", "Last Page"),
- ]
-
- def __init__(self, collection: CollectionInfo, storage: BaseStorage | None):
- super().__init__()
- self.collection = collection
- self.storage = storage
- self.documents: list[DocumentInfo] = []
- self.selected_docs: set[str] = set()
- self.current_offset = 0
- self.page_size = 50
-
- @override
- def compose(self) -> ComposeResult:
- yield Header()
- yield Container(
- Static(f"📄 Document Management: {self.collection['name']}", classes="title"),
- Static(
- f"Total Documents: {self.collection['count']:,} | Use Space to select, Delete to remove",
- classes="subtitle",
- ),
- Label(f"Page size: {self.page_size} documents"),
- EnhancedDataTable(id="documents_table", classes="enhanced-table"),
- Horizontal(
- Button("🔄 Refresh", id="refresh_docs_btn", variant="primary"),
- Button("🗑️ Delete Selected", id="delete_selected_btn", variant="error"),
- Button("✅ Select All", id="select_all_btn", variant="default"),
- Button("❌ Clear Selection", id="clear_selection_btn", variant="default"),
- Button("⬅️ Previous Page", id="prev_page_btn", variant="default"),
- Button("➡️ Next Page", id="next_page_btn", variant="default"),
- classes="button_bar",
- ),
- Label("", id="selection_status"),
- Static("", id="page_info", classes="status-text"),
- LoadingIndicator(id="loading"),
- classes="main_container",
- )
- yield Footer()
-
- async def on_mount(self) -> None:
- """Initialize the screen."""
- self.query_one("#loading").display = False
-
- # Setup documents table with enhanced columns
- table = self.query_one("#documents_table", EnhancedDataTable)
- table.add_columns(
- "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID"
- )
-
- # Set up message handling for table events
- table.can_focus = True
-
- await self.load_documents()
-
- async def load_documents(self) -> None:
- """Load documents from the collection."""
- loading = self.query_one("#loading")
- loading.display = True
-
- try:
- if self.storage:
- # Try to load documents using the storage backend
- try:
- raw_docs = await self.storage.list_documents(
- limit=self.page_size,
- offset=self.current_offset,
- collection_name=self.collection["name"],
- )
- # Cast to proper type with type checking
- self.documents = [
- DocumentInfo(
- id=str(doc.get("id", f"doc_{i}")),
- title=str(doc.get("title", "Untitled Document")),
- source_url=str(doc.get("source_url", "")),
- description=str(doc.get("description", "")),
- content_type=str(doc.get("content_type", "text/plain")),
- content_preview=str(doc.get("content_preview", "")),
- word_count=(
- lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0
- )(doc.get("word_count", 0)),
- timestamp=str(doc.get("timestamp", "")),
- )
- for i, doc in enumerate(raw_docs)
- ]
- except NotImplementedError:
- # For storage backends that don't support document listing, show a message
- self.notify(
- f"Document listing not supported for {self.storage.__class__.__name__}",
- severity="information"
- )
- self.documents = []
-
- await self.update_table()
- self.update_selection_status()
- self.update_page_info()
-
- except Exception as e:
- self.notify(f"Error loading documents: {e}", severity="error", markup=False)
- finally:
- loading.display = False
-
- async def update_table(self) -> None:
- """Update the documents table with enhanced metadata display."""
- table = self.query_one("#documents_table", EnhancedDataTable)
- table.clear(columns=True)
-
- # Add enhanced columns with more metadata
- table.add_columns(
- "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID"
- )
-
- # Add rows with enhanced metadata
- for doc in self.documents:
- selected = "✓" if doc["id"] in self.selected_docs else ""
-
- # Get additional metadata from the raw docs
- description = str(doc.get("description") or "").strip()[:40]
- if not description:
- description = "[dim]No description[/dim]"
- elif len(str(doc.get("description") or "")) > 40:
- description += "..."
-
- # Format content type with appropriate icon
- content_type = doc.get("content_type", "text/plain")
- if "markdown" in content_type.lower():
- type_display = "📝 md"
- elif "html" in content_type.lower():
- type_display = "🌐 html"
- elif "text" in content_type.lower():
- type_display = "📄 txt"
- else:
- type_display = f"📄 {content_type.split('/')[-1][:5]}"
-
- # Format timestamp to be more readable
- timestamp = doc.get("timestamp", "")
- if timestamp:
- try:
- # Parse ISO format timestamp
- dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
- timestamp = dt.strftime("%m/%d %H:%M")
- except Exception:
- timestamp = str(timestamp)[:16] # Fallback
- table.add_row(
- selected,
- doc.get("title", "Untitled")[:40],
- doc.get("source_url", "")[:35],
- description,
- type_display,
- str(doc.get("word_count", 0)),
- timestamp,
- doc["id"][:8] + "...", # Show truncated ID
- )
-
- def update_selection_status(self) -> None:
- """Update the selection status label."""
- status_label = self.query_one("#selection_status", Label)
- total_selected = len(self.selected_docs)
- status_label.update(f"Selected: {total_selected} documents")
-
- def update_page_info(self) -> None:
- """Update the page information."""
- page_info = self.query_one("#page_info", Static)
- total_docs = self.collection["count"]
- start = self.current_offset + 1
- end = min(self.current_offset + len(self.documents), total_docs)
- page_num = (self.current_offset // self.page_size) + 1
- total_pages = (total_docs + self.page_size - 1) // self.page_size
-
- page_info.update(
- f"Showing {start:,}-{end:,} of {total_docs:,} documents (Page {page_num} of {total_pages})"
- )
-
- def get_current_document(self) -> DocumentInfo | None:
- """Get the currently selected document."""
- table = self.query_one("#documents_table", EnhancedDataTable)
- try:
- if 0 <= table.cursor_coordinate.row < len(self.documents):
- return self.documents[table.cursor_coordinate.row]
- except (AttributeError, IndexError):
- pass
- return None
-
- # Action methods
- def action_refresh(self) -> None:
- """Refresh the document list."""
- self.run_worker(self.load_documents())
-
- def action_toggle_selection(self) -> None:
- """Toggle selection of current row."""
- if doc := self.get_current_document():
- doc_id = doc["id"]
- if doc_id in self.selected_docs:
- self.selected_docs.remove(doc_id)
- else:
- self.selected_docs.add(doc_id)
-
- self.run_worker(self.update_table())
- self.update_selection_status()
-
- def action_select_all(self) -> None:
- """Select all documents on current page."""
- for doc in self.documents:
- self.selected_docs.add(doc["id"])
- self.run_worker(self.update_table())
- self.update_selection_status()
-
- def action_select_none(self) -> None:
- """Clear all selections."""
- self.selected_docs.clear()
- self.run_worker(self.update_table())
- self.update_selection_status()
-
- def action_delete_selected(self) -> None:
- """Delete selected documents."""
- if self.selected_docs:
- from .dialogs import ConfirmDocumentDeleteScreen
-
- self.app.push_screen(
- ConfirmDocumentDeleteScreen(list(self.selected_docs), self.collection, self)
- )
- else:
- self.notify("No documents selected", severity="warning")
-
- def action_next_page(self) -> None:
- """Go to next page."""
- if self.current_offset + self.page_size < self.collection["count"]:
- self.current_offset += self.page_size
- self.run_worker(self.load_documents())
-
- def action_prev_page(self) -> None:
- """Go to previous page."""
- if self.current_offset >= self.page_size:
- self.current_offset -= self.page_size
- self.run_worker(self.load_documents())
-
- def action_first_page(self) -> None:
- """Go to first page."""
- if self.current_offset > 0:
- self.current_offset = 0
- self.run_worker(self.load_documents())
-
- def action_last_page(self) -> None:
- """Go to last page."""
- total_docs = self.collection["count"]
- last_offset = ((total_docs - 1) // self.page_size) * self.page_size
- if self.current_offset != last_offset:
- self.current_offset = last_offset
- self.run_worker(self.load_documents())
-
- def on_button_pressed(self, event: Button.Pressed) -> None:
- """Handle button presses."""
- if event.button.id == "refresh_docs_btn":
- self.action_refresh()
- elif event.button.id == "delete_selected_btn":
- self.action_delete_selected()
- elif event.button.id == "select_all_btn":
- self.action_select_all()
- elif event.button.id == "clear_selection_btn":
- self.action_select_none()
- elif event.button.id == "next_page_btn":
- self.action_next_page()
- elif event.button.id == "prev_page_btn":
- self.action_prev_page()
-
- def on_enhanced_data_table_row_toggled(self, event: EnhancedDataTable.RowToggled) -> None:
- """Handle row toggle from enhanced table."""
- if 0 <= event.row_index < len(self.documents):
- doc = self.documents[event.row_index]
- doc_id = doc["id"]
-
- if doc_id in self.selected_docs:
- self.selected_docs.remove(doc_id)
- else:
- self.selected_docs.add(doc_id)
-
- self.run_worker(self.update_table())
- self.update_selection_status()
-
- def on_enhanced_data_table_select_all(self, event: EnhancedDataTable.SelectAll) -> None:
- """Handle select all from enhanced table."""
- self.action_select_all()
-
- def on_enhanced_data_table_clear_selection(
- self, event: EnhancedDataTable.ClearSelection
- ) -> None:
- """Handle clear selection from enhanced table."""
- self.action_select_none()
-
-
"""Enhanced ingestion screen with multi-storage support."""
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
from textual import work
from textual.app import ComposeResult
@@ -2608,12 +1931,23 @@ class IngestionScreen(ModalScreen[None]):
Label("📋 Source Type (Press 1/2/3):", classes="input-label"),
Horizontal(
Button("🌐 Web (1)", id="web_btn", variant="primary", classes="type-button"),
- Button("📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"),
- Button("📖 Documentation (3)", id="docs_btn", variant="default", classes="type-button"),
+ Button(
+ "📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"
+ ),
+ Button(
+ "📖 Documentation (3)",
+ id="docs_btn",
+ variant="default",
+ classes="type-button",
+ ),
classes="type_buttons",
),
Rule(line_style="dashed"),
- Label(f"🗄️ Target Storages ({len(self.available_backends)} available):", classes="input-label", id="backend_label"),
+ Label(
+ f"🗄️ Target Storages ({len(self.available_backends)} available):",
+ classes="input-label",
+ id="backend_label",
+ ),
Container(
*self._create_backend_checkbox_widgets(),
classes="backend-selection",
@@ -2642,7 +1976,6 @@ class IngestionScreen(ModalScreen[None]):
yield LoadingIndicator(id="loading", classes="pulse")
-
def _create_backend_checkbox_widgets(self) -> list[Checkbox]:
"""Create checkbox widgets for each available backend."""
checkboxes: list[Checkbox] = [
@@ -2722,19 +2055,26 @@ class IngestionScreen(ModalScreen[None]):
collection_name = collection_input.value.strip()
if not source_url:
- self.notify("🔍 Please enter a source URL", severity="error")
+ cast("CollectionManagementApp", self.app).safe_notify(
+ "🔍 Please enter a source URL", severity="error"
+ )
url_input.focus()
return
# Validate URL format
if not self._validate_url(source_url):
- self.notify("❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path", severity="error")
+ cast("CollectionManagementApp", self.app).safe_notify(
+ "❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path",
+ severity="error",
+ )
url_input.focus()
return
resolved_backends = self._resolve_selected_backends()
if not resolved_backends:
- self.notify("⚠️ Select at least one storage backend", severity="warning")
+ cast("CollectionManagementApp", self.app).safe_notify(
+ "⚠️ Select at least one storage backend", severity="warning"
+ )
return
self.selected_backends = resolved_backends
@@ -2749,18 +2089,16 @@ class IngestionScreen(ModalScreen[None]):
url_lower = url.lower()
# Allow HTTP/HTTPS URLs
- if url_lower.startswith(('http://', 'https://')):
+ if url_lower.startswith(("http://", "https://")):
# Additional validation could be added here
return True
# Allow file:// URLs for repository paths
- if url_lower.startswith('file://'):
+ if url_lower.startswith("file://"):
return True
# Allow local file paths that look like repositories
- return '/' in url and not url_lower.startswith(
- ('javascript:', 'data:', 'vbscript:')
- )
+ return "/" in url and not url_lower.startswith(("javascript:", "data:", "vbscript:"))
def _resolve_selected_backends(self) -> list[StorageBackend]:
selected: list[StorageBackend] = []
@@ -2791,13 +2129,14 @@ class IngestionScreen(ModalScreen[None]):
if not self.selected_backends:
status_widget.update("📋 Selected: None")
elif len(self.selected_backends) == 1:
- backend_name = BACKEND_LABELS.get(self.selected_backends[0], self.selected_backends[0].value)
+ backend_name = BACKEND_LABELS.get(
+ self.selected_backends[0], self.selected_backends[0].value
+ )
status_widget.update(f"📋 Selected: {backend_name}")
else:
# Multiple backends selected
backend_names = [
- BACKEND_LABELS.get(backend, backend.value)
- for backend in self.selected_backends
+ BACKEND_LABELS.get(backend, backend.value) for backend in self.selected_backends
]
if len(backend_names) <= 3:
# Show all names if 3 or fewer
@@ -2822,7 +2161,10 @@ class IngestionScreen(ModalScreen[None]):
for backend in BACKEND_ORDER:
if backend not in self.available_backends:
continue
- if backend.value.lower() == backend_name_lower or backend.name.lower() == backend_name_lower:
+ if (
+ backend.value.lower() == backend_name_lower
+ or backend.name.lower() == backend_name_lower
+ ):
matched_backends.append(backend)
break
return matched_backends or [self.available_backends[0]]
@@ -2854,6 +2196,7 @@ class IngestionScreen(ModalScreen[None]):
loading.display = False
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(_update)
def progress_reporter(percent: int, message: str) -> None:
@@ -2865,6 +2208,7 @@ class IngestionScreen(ModalScreen[None]):
progress_text.update(message)
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(_update_progress)
try:
@@ -2880,13 +2224,18 @@ class IngestionScreen(ModalScreen[None]):
for i, backend in enumerate(backends):
progress_percent = 20 + (60 * i) // len(backends)
- progress_reporter(progress_percent, f"🔗 Processing {backend.value} backend ({i+1}/{len(backends)})...")
+ progress_reporter(
+ progress_percent,
+ f"🔗 Processing {backend.value} backend ({i + 1}/{len(backends)})...",
+ )
try:
# Run the Prefect flow for this backend using asyncio.run with timeout
import asyncio
- async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult:
+ async def run_flow_with_timeout(
+ current_backend: StorageBackend = backend,
+ ) -> IngestionResult:
return await asyncio.wait_for(
create_ingestion_flow(
source_url=source_url,
@@ -2895,7 +2244,7 @@ class IngestionScreen(ModalScreen[None]):
collection_name=final_collection_name,
progress_callback=progress_reporter,
),
- timeout=600.0 # 10 minute timeout
+ timeout=600.0, # 10 minute timeout
)
result = asyncio.run(run_flow_with_timeout())
@@ -2904,25 +2253,33 @@ class IngestionScreen(ModalScreen[None]):
total_failed += result.documents_failed
if result.error_messages:
- flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages])
+ flow_errors.extend(
+ [f"{backend.value}: {err}" for err in result.error_messages]
+ )
except TimeoutError:
error_msg = f"{backend.value}: Timeout after 10 minutes"
flow_errors.append(error_msg)
progress_reporter(0, f"❌ {backend.value} timed out")
- def notify_timeout(msg: str = f"⏰ {backend.value} flow timed out after 10 minutes") -> None:
+
+ def notify_timeout(
+ msg: str = f"⏰ {backend.value} flow timed out after 10 minutes",
+ ) -> None:
try:
self.notify(msg, severity="error", markup=False)
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(notify_timeout)
except Exception as exc:
flow_errors.append(f"{backend.value}: {exc}")
+
def notify_error(msg: str = f"❌ {backend.value} flow failed: {exc}") -> None:
try:
self.notify(msg, severity="error", markup=False)
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
successful = total_successful
@@ -2948,20 +2305,38 @@ class IngestionScreen(ModalScreen[None]):
cast("CollectionManagementApp", self.app).call_from_thread(notify_results)
- import time
- time.sleep(2)
- cast("CollectionManagementApp", self.app).pop_screen()
+ def _pop() -> None:
+ try:
+ self.app.pop_screen()
+ except Exception:
+ pass
+
+ # Schedule screen pop via timer instead of blocking
+ cast("CollectionManagementApp", self.app).call_from_thread(
+ lambda: self.app.set_timer(2.0, _pop)
+ )
except Exception as exc: # pragma: no cover - defensive
progress_reporter(0, f"❌ Prefect flows error: {exc}")
+
def notify_error(msg: str = f"❌ Prefect flows failed: {exc}") -> None:
try:
self.notify(msg, severity="error")
except Exception:
pass
+
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
- import time
- time.sleep(2)
+
+ def _pop_on_error() -> None:
+ try:
+ self.app.pop_screen()
+ except Exception:
+ pass
+
+ # Schedule screen pop via timer for error case too
+ cast("CollectionManagementApp", self.app).call_from_thread(
+ lambda: self.app.set_timer(2.0, _pop_on_error)
+ )
finally:
update_ui("hide_loading")
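# --- Reviewer sketch, not part of the diff: the non-blocking dismissal pattern in one place ---
# "screen" stands for the IngestionScreen instance inside its thread worker; the 2.0 s delay
# mirrors the old time.sleep(2). call_from_thread and set_timer are standard Textual APIs.
def _schedule_screen_pop(screen: IngestionScreen, delay: float = 2.0) -> None:
    def _pop() -> None:
        try:
            screen.app.pop_screen()
        except Exception:
            pass
    # Hop to the UI thread, then let the event loop fire the pop after `delay` seconds
    # instead of blocking the worker thread.
    screen.app.call_from_thread(lambda: screen.app.set_timer(delay, _pop))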
@@ -3012,12 +2387,24 @@ class SearchScreen(Screen[None]):
@override
def compose(self) -> ComposeResult:
yield Header()
+ # Check if search is supported for this backend
+ backends = self.collection["backend"]
+ if isinstance(backends, str):
+ backends = [backends]
+ search_supported = "weaviate" in backends
+ search_indicator = "✅ Search supported" if search_supported else "❌ Search not supported"
+
yield Container(
Static(
- f"🔍 Search in: {self.collection['name']} ({self.collection['backend']})",
+ f"🔍 Search in: {self.collection['name']} ({', '.join(backends)}) - {search_indicator}",
classes="title",
),
- Static("Press / or Ctrl+F to focus search, Enter to search", classes="subtitle"),
+ Static(
+ "Press / or Ctrl+F to focus search, Enter to search"
+ if search_supported
+ else "Search functionality not available for this backend",
+ classes="subtitle",
+ ),
Input(placeholder="Enter search query... (press Enter to search)", id="search_input"),
Button("🔍 Search", id="search_btn", variant="primary"),
Button("🗑️ Clear Results", id="clear_btn", variant="default"),
@@ -3096,7 +2483,9 @@ class SearchScreen(Screen[None]):
finally:
loading.display = False
- def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None:
+ def _setup_search_ui(
+ self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str
+ ) -> None:
"""Setup the search UI elements."""
loading.display = True
status.update(f"🔍 Searching for '{query}'...")
@@ -3108,10 +2497,18 @@ class SearchScreen(Screen[None]):
if self.collection["type"] == "weaviate" and self.weaviate:
return await self.search_weaviate(query)
elif self.collection["type"] == "openwebui" and self.openwebui:
- return await self.search_openwebui(query)
+ # OpenWebUI search is not yet implemented
+ self.notify("Search not supported for OpenWebUI collections", severity="warning")
+ return []
+ elif self.collection["type"] == "r2r":
+ # R2R search would go here when implemented
+ self.notify("Search not supported for R2R collections", severity="warning")
+ return []
return []
- def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None:
+ def _populate_results_table(
+ self, table: EnhancedDataTable, results: list[dict[str, str | float]]
+ ) -> None:
"""Populate the results table with search results."""
for result in results:
row_data = self._format_result_row(result)
@@ -3162,7 +2559,11 @@ class SearchScreen(Screen[None]):
return str(score)
def _update_search_status(
- self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable
+ self,
+ status: Static,
+ query: str,
+ results: list[dict[str, str | float]],
+ table: EnhancedDataTable,
) -> None:
"""Update search status and notifications based on results."""
if not results:
@@ -3232,153 +2633,542 @@ class SearchScreen(Screen[None]):
return []
-
-"""TUI runner functions and initialization."""
+
+"""Storage management utilities for TUI applications."""
from __future__ import annotations
import asyncio
-import logging
-from logging import Logger
-from logging.handlers import QueueHandler, RotatingFileHandler
-from pathlib import Path
-from queue import Queue
-from typing import NamedTuple
+from collections.abc import AsyncGenerator, Coroutine, Sequence
+from typing import TYPE_CHECKING, Protocol
-from ....config import configure_prefect, get_settings
-from ....core.models import StorageBackend
-from .storage_manager import StorageManager
+from pydantic import SecretStr
+
+from ....core.exceptions import StorageError
+from ....core.models import Document, StorageBackend, StorageConfig
+from ....storage.base import BaseStorage
+from ....storage.openwebui import OpenWebUIStorage
+from ....storage.r2r.storage import R2RStorage
+from ....storage.weaviate import WeaviateStorage
+from ..models import CollectionInfo, StorageCapabilities
+
+if TYPE_CHECKING:
+ from ....config.settings import Settings
-class _TuiLoggingContext(NamedTuple):
- """Container describing configured logging outputs for the TUI."""
+class StorageBackendProtocol(Protocol):
+ """Protocol defining storage backend interface."""
- queue: Queue[logging.LogRecord]
- formatter: logging.Formatter
- log_file: Path | None
+ async def initialize(self) -> None: ...
+ async def count(self, *, collection_name: str | None = None) -> int: ...
+ async def list_collections(self) -> list[str]: ...
+ async def search(
+ self,
+ query: str,
+ limit: int = 10,
+ threshold: float = 0.7,
+ *,
+ collection_name: str | None = None,
+ ) -> AsyncGenerator[Document, None]: ...
+ async def close(self) -> None: ...
-_logging_context: _TuiLoggingContext | None = None
+class MultiStorageAdapter(BaseStorage):
+ """Mirror writes to multiple storage backends."""
+ def __init__(self, storages: Sequence[BaseStorage]) -> None:
+ if not storages:
+ raise ValueError("MultiStorageAdapter requires at least one storage backend")
-def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext:
- """Configure logging so that messages do not break the TUI output."""
+ unique: list[BaseStorage] = []
+ seen_ids: set[int] = set()
+ for storage in storages:
+ storage_id = id(storage)
+ if storage_id in seen_ids:
+ continue
+ seen_ids.add(storage_id)
+ unique.append(storage)
- global _logging_context
- if _logging_context is not None:
- return _logging_context
+ self._storages: list[BaseStorage] = unique
+ self._primary: BaseStorage = unique[0]
+ super().__init__(self._primary.config)
- resolved_level = getattr(logging, log_level.upper(), logging.INFO)
- log_queue: Queue[logging.LogRecord] = Queue()
- formatter = logging.Formatter(
- fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
- )
+ async def initialize(self) -> None:
+ for storage in self._storages:
+ await storage.initialize()
- root_logger = logging.getLogger()
- root_logger.setLevel(resolved_level)
+ async def store(self, document: Document, *, collection_name: str | None = None) -> str:
+ # Store in primary backend first
+ primary_id: str = await self._primary.store(document, collection_name=collection_name)
- # Remove existing stream handlers to prevent console flicker inside the TUI
- for handler in list(root_logger.handlers):
- root_logger.removeHandler(handler)
+ # Replicate to secondary backends concurrently
+ if len(self._storages) > 1:
- queue_handler = QueueHandler(log_queue)
- queue_handler.setLevel(resolved_level)
- root_logger.addHandler(queue_handler)
+ async def replicate_to_backend(
+ storage: BaseStorage,
+ ) -> tuple[BaseStorage, bool, Exception | None]:
+ try:
+ await storage.store(document, collection_name=collection_name)
+ return storage, True, None
+ except Exception as exc:
+ return storage, False, exc
- log_file: Path | None = None
- try:
- log_dir = Path.cwd() / "logs"
- log_dir.mkdir(parents=True, exist_ok=True)
- log_file = log_dir / "tui.log"
- file_handler = RotatingFileHandler(
- log_file,
- maxBytes=2_000_000,
- backupCount=5,
- encoding="utf-8",
+ tasks = [replicate_to_backend(storage) for storage in self._storages[1:]]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ failures: list[str] = []
+ errors: list[Exception] = []
+
+ for result in results:
+ if isinstance(result, tuple):
+ storage, success, error = result
+ if not success and error is not None:
+ failures.append(self._format_backend_label(storage))
+ errors.append(error)
+ elif isinstance(result, Exception):
+ failures.append("unknown")
+ errors.append(result)
+
+ if failures:
+ backends = ", ".join(failures)
+ primary_error = errors[0] if errors else Exception("Unknown replication error")
+ raise StorageError(
+ f"Document stored in primary backend but replication failed for: {backends}"
+ ) from primary_error
+
+ return primary_id
+
+ async def store_batch(
+ self, documents: list[Document], *, collection_name: str | None = None
+ ) -> list[str]:
+ # Store in primary backend first
+ primary_ids: list[str] = await self._primary.store_batch(
+ documents, collection_name=collection_name
)
- file_handler.setLevel(resolved_level)
- file_handler.setFormatter(formatter)
- root_logger.addHandler(file_handler)
- except OSError as exc: # pragma: no cover - filesystem specific
- fallback = logging.getLogger(__name__)
- fallback.warning("Failed to configure file logging for TUI: %s", exc)
- _logging_context = _TuiLoggingContext(log_queue, formatter, log_file)
- return _logging_context
+ # Replicate to secondary backends concurrently
+ if len(self._storages) > 1:
+
+ async def replicate_batch_to_backend(
+ storage: BaseStorage,
+ ) -> tuple[BaseStorage, bool, Exception | None]:
+ try:
+ await storage.store_batch(documents, collection_name=collection_name)
+ return storage, True, None
+ except Exception as exc:
+ return storage, False, exc
+
+ tasks = [replicate_batch_to_backend(storage) for storage in self._storages[1:]]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ failures: list[str] = []
+ errors: list[Exception] = []
+
+ for result in results:
+ if isinstance(result, tuple):
+ storage, success, error = result
+ if not success and error is not None:
+ failures.append(self._format_backend_label(storage))
+ errors.append(error)
+ elif isinstance(result, Exception):
+ failures.append("unknown")
+ errors.append(result)
+
+ if failures:
+ backends = ", ".join(failures)
+ primary_error = (
+ errors[0] if errors else Exception("Unknown batch replication error")
+ )
+ raise StorageError(
+ f"Batch stored in primary backend but replication failed for: {backends}"
+ ) from primary_error
+
+ return primary_ids
+
+ async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
+ # Delete from primary backend first
+ primary_deleted: bool = await self._primary.delete(
+ document_id, collection_name=collection_name
+ )
+
+ # Delete from secondary backends concurrently
+ if len(self._storages) > 1:
+
+ async def delete_from_backend(
+ storage: BaseStorage,
+ ) -> tuple[BaseStorage, bool, Exception | None]:
+ try:
+ await storage.delete(document_id, collection_name=collection_name)
+ return storage, True, None
+ except Exception as exc:
+ return storage, False, exc
+
+ tasks = [delete_from_backend(storage) for storage in self._storages[1:]]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ failures: list[str] = []
+ errors: list[Exception] = []
+
+ for result in results:
+ if isinstance(result, tuple):
+ storage, success, error = result
+ if not success and error is not None:
+ failures.append(self._format_backend_label(storage))
+ errors.append(error)
+ elif isinstance(result, Exception):
+ failures.append("unknown")
+ errors.append(result)
+
+ if failures:
+ backends = ", ".join(failures)
+ primary_error = errors[0] if errors else Exception("Unknown deletion error")
+ raise StorageError(
+ f"Document deleted from primary backend but failed for: {backends}"
+ ) from primary_error
+
+ return primary_deleted
+
+ async def count(self, *, collection_name: str | None = None) -> int:
+ count_result: int = await self._primary.count(collection_name=collection_name)
+ return count_result
+
+ async def list_collections(self) -> list[str]:
+ list_fn = getattr(self._primary, "list_collections", None)
+ if list_fn is None:
+ return []
+ collections_result: list[str] = await list_fn()
+ return collections_result
+
+ async def search(
+ self,
+ query: str,
+ limit: int = 10,
+ threshold: float = 0.7,
+ *,
+ collection_name: str | None = None,
+ ) -> AsyncGenerator[Document, None]:
+ async for item in self._primary.search(
+ query,
+ limit=limit,
+ threshold=threshold,
+ collection_name=collection_name,
+ ):
+ yield item
+
+ async def close(self) -> None:
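+ # NOTE: a failing close() propagates and skips the remaining backends; callers
+ # that need best-effort cleanup should catch per-backend errors themselves.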
+ for storage in self._storages:
+ close_fn = getattr(storage, "close", None)
+ if close_fn is not None:
+ await close_fn()
+
+ def _format_backend_label(self, storage: BaseStorage) -> str:
+ backend = getattr(storage.config, "backend", None)
+ if isinstance(backend, StorageBackend):
+ backend_value: str = backend.value
+ return backend_value
+ class_name: str = storage.__class__.__name__
+ return class_name
-LOGGER: Logger = logging.getLogger(__name__)
+class StorageManager:
+ """Centralized manager for all storage backend operations."""
+ def __init__(self, settings: Settings) -> None:
+ """Initialize storage manager with application settings."""
+ self.settings: Settings = settings
+ self.backends: dict[StorageBackend, BaseStorage] = {}
+ self.capabilities: dict[StorageBackend, StorageCapabilities] = {}
+ self._initialized: bool = False
-async def run_textual_tui() -> None:
- """Run the enhanced modern TUI with better error handling and initialization."""
- settings = get_settings()
- configure_prefect(settings)
+ async def initialize_all_backends(self) -> dict[StorageBackend, bool]:
+ """Initialize all available storage backends with timeout protection."""
+ results: dict[StorageBackend, bool] = {}
- logging_context = _configure_tui_logging(log_level=settings.log_level)
+ async def init_backend(
+ backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]
+ ) -> bool:
+ """Initialize a single backend with timeout."""
+ try:
+ storage = storage_class(config)
+ await asyncio.wait_for(storage.initialize(), timeout=30.0)
+ self.backends[backend_type] = storage
+ if backend_type == StorageBackend.WEAVIATE:
+ self.capabilities[backend_type] = StorageCapabilities.VECTOR_SEARCH
+ elif backend_type == StorageBackend.OPEN_WEBUI:
+ self.capabilities[backend_type] = StorageCapabilities.KNOWLEDGE_BASE
+ elif backend_type == StorageBackend.R2R:
+ self.capabilities[backend_type] = StorageCapabilities.FULL_FEATURED
+ return True
+ except Exception:
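+ # Failures (including timeouts from wait_for, which already subclass Exception)
+ # are intentionally swallowed here and surfaced only as False in the results mapping.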
+ return False
- LOGGER.info("Initializing collection management TUI")
- LOGGER.info("Scanning available storage backends")
+ # Initialize backends concurrently with timeout protection
+ tasks: list[tuple[StorageBackend, Coroutine[None, None, bool]]] = []
- # Initialize storage manager
- storage_manager = StorageManager(settings)
- backend_status = await storage_manager.initialize_all_backends()
-
- # Report initialization results
- for backend, success in backend_status.items():
- if success:
- LOGGER.info("%s connected successfully", backend.value)
+ # Try Weaviate
+ if self.settings.weaviate_endpoint:
+ config = StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint=self.settings.weaviate_endpoint,
+ api_key=SecretStr(self.settings.weaviate_api_key)
+ if self.settings.weaviate_api_key
+ else None,
+ collection_name="default",
+ )
+ tasks.append(
+ (
+ StorageBackend.WEAVIATE,
+ init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage),
+ )
+ )
else:
- LOGGER.warning("%s connection failed", backend.value)
+ results[StorageBackend.WEAVIATE] = False
- available_backends = storage_manager.get_available_backends()
- if not available_backends:
- LOGGER.error("Could not connect to any storage backend")
- LOGGER.info("Please check your configuration and try again")
- LOGGER.info("Supported backends: Weaviate, OpenWebUI, R2R")
- return
+ # Try OpenWebUI
+ if self.settings.openwebui_endpoint and self.settings.openwebui_api_key:
+ config = StorageConfig(
+ backend=StorageBackend.OPEN_WEBUI,
+ endpoint=self.settings.openwebui_endpoint,
+ api_key=SecretStr(self.settings.openwebui_api_key)
+ if self.settings.openwebui_api_key
+ else None,
+ collection_name="default",
+ )
+ tasks.append(
+ (
+ StorageBackend.OPEN_WEBUI,
+ init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage),
+ )
+ )
+ else:
+ results[StorageBackend.OPEN_WEBUI] = False
- LOGGER.info(
- "Launching TUI with %d backend(s): %s",
- len(available_backends),
- ", ".join(backend.value for backend in available_backends),
- )
+ # Try R2R
+ if self.settings.r2r_endpoint:
+ config = StorageConfig(
+ backend=StorageBackend.R2R,
+ endpoint=self.settings.r2r_endpoint,
+ api_key=SecretStr(self.settings.r2r_api_key) if self.settings.r2r_api_key else None,
+ collection_name="default",
+ )
+ tasks.append((StorageBackend.R2R, init_backend(StorageBackend.R2R, config, R2RStorage)))
+ else:
+ results[StorageBackend.R2R] = False
- # Get individual storage instances for backward compatibility
- from ....storage.openwebui import OpenWebUIStorage
- from ....storage.weaviate import WeaviateStorage
+ # Execute initialization tasks concurrently
+ if tasks:
+ backend_types, task_coroutines = zip(*tasks, strict=False)
+ task_results: Sequence[bool | BaseException] = await asyncio.gather(
+ *task_coroutines, return_exceptions=True
+ )
- weaviate_backend = storage_manager.get_backend(StorageBackend.WEAVIATE)
- openwebui_backend = storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
- r2r_backend = storage_manager.get_backend(StorageBackend.R2R)
+ for backend_type, task_result in zip(backend_types, task_results, strict=False):
+ results[backend_type] = task_result if isinstance(task_result, bool) else False
+ self._initialized = True
+ return results
- # Type-safe casting to specific storage types
- weaviate = weaviate_backend if isinstance(weaviate_backend, WeaviateStorage) else None
- openwebui = openwebui_backend if isinstance(openwebui_backend, OpenWebUIStorage) else None
+ def get_backend(self, backend_type: StorageBackend) -> BaseStorage | None:
+ """Get storage backend by type."""
+ return self.backends.get(backend_type)
- # Import here to avoid circular import
- from ..app import CollectionManagementApp
- app = CollectionManagementApp(
- storage_manager,
- weaviate,
- openwebui,
- r2r_backend,
- log_queue=logging_context.queue,
- log_formatter=logging_context.formatter,
- log_file=logging_context.log_file,
- )
- try:
- await app.run_async()
- finally:
- LOGGER.info("Shutting down storage connections")
- await storage_manager.close_all()
- LOGGER.info("All storage connections closed gracefully")
+ def build_multi_storage_adapter(
+ self, backends: Sequence[StorageBackend]
+ ) -> MultiStorageAdapter:
+ storages: list[BaseStorage] = []
+ seen: set[StorageBackend] = set()
+ for backend in backends:
+ backend_enum = (
+ backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
+ )
+ if backend_enum in seen:
+ continue
+ seen.add(backend_enum)
+ storage = self.backends.get(backend_enum)
+ if storage is None:
+ raise ValueError(f"Storage backend {backend_enum.value} is not initialized")
+ storages.append(storage)
+ return MultiStorageAdapter(storages)
+ def get_available_backends(self) -> list[StorageBackend]:
+ """Get list of successfully initialized backends."""
+ return list(self.backends.keys())
-def dashboard() -> None:
- """Launch the modern collection dashboard."""
- asyncio.run(run_textual_tui())
+ def has_capability(self, backend: StorageBackend, capability: StorageCapabilities) -> bool:
+ """Check if backend has specific capability."""
+ backend_caps = self.capabilities.get(backend, StorageCapabilities.BASIC)
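+ # Assumes StorageCapabilities values are ordered by increasing feature level
+ # (e.g. BASIC < VECTOR_SEARCH < ... < FULL_FEATURED).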
+ return capability.value <= backend_caps.value
+
+ async def get_all_collections(self) -> list[CollectionInfo]:
+ """Get collections from all available backends, merging collections with same name."""
+ collection_map: dict[str, CollectionInfo] = {}
+
+ for backend_type, storage in self.backends.items():
+ try:
+ backend_collections = await storage.list_collections()
+ for collection_name in backend_collections:
+ # Validate collection name
+ if not collection_name or not isinstance(collection_name, str):
+ continue
+
+ try:
+ count = await storage.count(collection_name=collection_name)
+ # Validate count is non-negative
+ count = max(count, 0)
+ except StorageError as e:
+ # Storage-specific errors - log and use 0 count
+ import logging
+
+ logging.warning(
+ f"Failed to get count for {collection_name} on {backend_type.value}: {e}"
+ )
+ count = 0
+ except Exception as e:
+ # Unexpected errors - log and skip this collection from this backend
+ import logging
+
+ logging.warning(
+ f"Unexpected error counting {collection_name} on {backend_type.value}: {e}"
+ )
+ continue
+
+ size_mb = count * 0.01 # Rough estimate: 10KB per document
+
+ if collection_name in collection_map:
+ # Merge with existing collection
+ existing = collection_map[collection_name]
+ existing_backends = existing["backend"]
+ backend_value = backend_type.value
+
+ if isinstance(existing_backends, str):
+ existing["backend"] = [existing_backends, backend_value]
+ elif isinstance(existing_backends, list):
+ # Prevent duplicates
+ if backend_value not in existing_backends:
+ existing_backends.append(backend_value)
+
+ # Aggregate counts and sizes
+ existing["count"] += count
+ existing["size_mb"] += size_mb
+ else:
+ # Create new collection entry
+ collection_info: CollectionInfo = {
+ "name": collection_name,
+ "type": self._get_collection_type(collection_name, backend_type),
+ "count": count,
+ "backend": backend_type.value,
+ "status": "active",
+ "last_updated": "2024-01-01T00:00:00Z",
+ "size_mb": size_mb,
+ }
+ collection_map[collection_name] = collection_info
+ except Exception:
+ continue
+
+ return list(collection_map.values())
+
+ def _get_collection_type(self, collection_name: str, backend: StorageBackend) -> str:
+ """Determine collection type based on name and backend."""
+ # Prioritize definitive backend type first
+ if backend == StorageBackend.R2R:
+ return "r2r"
+ elif backend == StorageBackend.WEAVIATE:
+ return "weaviate"
+ elif backend == StorageBackend.OPEN_WEBUI:
+ return "openwebui"
+
+ # Fallback to name-based guessing if backend is not specific
+ name_lower = collection_name.lower()
+ if "web" in name_lower or "doc" in name_lower:
+ return "documentation"
+ elif "repo" in name_lower or "code" in name_lower:
+ return "repository"
+ else:
+ return "general"
+
+ async def search_across_backends(
+ self,
+ query: str,
+ limit: int = 10,
+ backends: list[StorageBackend] | None = None,
+ ) -> dict[StorageBackend, list[Document]]:
+ """Search across multiple backends and return grouped results."""
+ if backends is None:
+ backends = self.get_available_backends()
+
+ results: dict[StorageBackend, list[Document]] = {}
+
+ async def search_backend(backend_type: StorageBackend) -> None:
+ storage = self.backends.get(backend_type)
+ if storage:
+ try:
+ documents: list[Document] = []
+ async for doc in storage.search(query, limit=limit):
+ documents.append(doc)
+ results[backend_type] = documents
+ except Exception:
+ results[backend_type] = []
+
+ # Run searches in parallel
+ tasks = [search_backend(backend) for backend in backends]
+ await asyncio.gather(*tasks, return_exceptions=True)
+
+ return results
+
+ def get_r2r_storage(self) -> R2RStorage | None:
+ """Get R2R storage instance if available."""
+ storage = self.backends.get(StorageBackend.R2R)
+ return storage if isinstance(storage, R2RStorage) else None
+
+ async def get_backend_status(
+ self,
+ ) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
+ """Get comprehensive status for all backends."""
+ status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}
+
+ for backend_type, storage in self.backends.items():
+ try:
+ collections = await storage.list_collections()
+ total_docs = 0
+ for collection in collections:
+ total_docs += await storage.count(collection_name=collection)
+
+ backend_status: dict[str, str | int | bool | StorageCapabilities] = {
+ "available": True,
+ "collections": len(collections),
+ "total_documents": total_docs,
+ "capabilities": self.capabilities.get(backend_type, StorageCapabilities.BASIC),
+ "endpoint": getattr(storage.config, "endpoint", "unknown"),
+ }
+ status[backend_type] = backend_status
+ except Exception as e:
+ status[backend_type] = {
+ "available": False,
+ "error": str(e),
+ "capabilities": StorageCapabilities.NONE,
+ }
+
+ return status
+
+ async def close_all(self) -> None:
+ """Close all storage connections."""
+ for storage in self.backends.values():
+ try:
+ await storage.close()
+ except Exception:
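+ # Best-effort shutdown: individual close() failures are ignored so the
+ # remaining backends still get closed.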
+ pass
+
+ self.backends.clear()
+ self.capabilities.clear()
+ self._initialized = False
+
+ @property
+ def is_initialized(self) -> bool:
+ """Check if storage manager is initialized."""
+ return self._initialized
+
+ def supports_advanced_features(self, backend: StorageBackend) -> bool:
+ """Check if backend supports advanced features like chunks and entities."""
+ return self.has_capability(backend, StorageCapabilities.FULL_FEATURED)
@@ -3542,9 +3332,7 @@ class ScrapeOptionsForm(Widget):
formats.append("screenshot")
options: dict[str, object] = {
"formats": formats,
- "only_main_content": self.query_one(
- "#only_main_content", Switch
- ).value,
+ "only_main_content": self.query_one("#only_main_content", Switch).value,
}
include_tags_input = self.query_one("#include_tags", Input).value
if include_tags_input.strip():
@@ -3579,11 +3367,15 @@ class ScrapeOptionsForm(Widget):
if include_tags := options.get("include_tags", []):
include_list = include_tags if isinstance(include_tags, list) else []
- self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list)
+ self.query_one("#include_tags", Input).value = ", ".join(
+ str(tag) for tag in include_list
+ )
if exclude_tags := options.get("exclude_tags", []):
exclude_list = exclude_tags if isinstance(exclude_tags, list) else []
- self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list)
+ self.query_one("#exclude_tags", Input).value = ", ".join(
+ str(tag) for tag in exclude_list
+ )
# Set performance
wait_for = options.get("wait_for")
@@ -3887,7 +3679,9 @@ class ExtractOptionsForm(Widget):
"content": "string",
"tags": ["string"]
}"""
- prompt_widget.text = "Extract article title, author, publication date, main content, and associated tags"
+ prompt_widget.text = (
+ "Extract article title, author, publication date, main content, and associated tags"
+ )
elif event.button.id == "preset_product":
schema_widget.text = """{
@@ -3897,7 +3691,9 @@ class ExtractOptionsForm(Widget):
"category": "string",
"availability": "string"
}"""
- prompt_widget.text = "Extract product name, price, description, category, and availability status"
+ prompt_widget.text = (
+ "Extract product name, price, description, category, and availability status"
+ )
elif event.button.id == "preset_contact":
schema_widget.text = """{
@@ -3907,7 +3703,9 @@ class ExtractOptionsForm(Widget):
"company": "string",
"position": "string"
}"""
- prompt_widget.text = "Extract contact information including name, email, phone, company, and position"
+ prompt_widget.text = (
+ "Extract contact information including name, email, phone, company, and position"
+ )
elif event.button.id == "preset_data":
schema_widget.text = """{
@@ -4160,10 +3958,17 @@ class ChunkViewer(Widget):
"id": str(chunk_data.get("id", "")),
"document_id": self.document_id,
"content": str(chunk_data.get("text", "")),
- "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)),
- "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)),
+ "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(
+ chunk_data.get("start_index", 0)
+ ),
+ "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(
+ chunk_data.get("end_index", 0)
+ ),
"metadata": (
- dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {}
+ dict(metadata_val)
+ if (metadata_val := chunk_data.get("metadata"))
+ and isinstance(metadata_val, dict)
+ else {}
),
}
self.chunks.append(chunk_info)
@@ -4336,10 +4141,10 @@ class EntityGraph(Widget):
"""Show detailed information about an entity."""
details_widget = self.query_one("#entity_details", Static)
- details_text = f"""**Entity:** {entity['name']}
-**Type:** {entity['type']}
-**Confidence:** {entity['confidence']:.2%}
-**ID:** {entity['id']}
+ details_text = f"""**Entity:** {entity["name"]}
+**Type:** {entity["type"]}
+**Confidence:** {entity["confidence"]:.2%}
+**ID:** {entity["id"]}
**Metadata:**
"""
@@ -4610,7 +4415,6 @@ else: # pragma: no cover - optional dependency fallback
R2RStorage = BaseStorage
-
class CollectionManagementApp(App[None]):
"""Enhanced modern Textual application with comprehensive keyboard navigation."""
@@ -4953,7 +4757,9 @@ class ResponsiveGrid(Container):
markup: bool = True,
) -> None:
"""Initialize responsive grid."""
- super().__init__(*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
+ super().__init__(
+ *children, name=name, id=id, classes=classes, disabled=disabled, markup=markup
+ )
self._columns: int = columns
self._auto_fit: bool = auto_fit
self._compact: bool = compact
@@ -5223,7 +5029,9 @@ class CardLayout(ResponsiveGrid):
) -> None:
"""Initialize card layout with default settings for cards."""
# Default to auto-fit cards with minimum width
- super().__init__(auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
+ super().__init__(
+ auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup
+ )
class SplitPane(Container):
@@ -5292,7 +5100,9 @@ class SplitPane(Container):
if self._vertical:
cast(Widget, self).add_class("vertical")
- pane_classes = ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
+ pane_classes = (
+ ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
+ )
yield Container(self._left_content, classes=pane_classes[0])
yield Static("", classes="splitter")
@@ -5309,12 +5119,1457 @@ class SplitPane(Container):
widget.query_one(".right-pane").styles.width = f"{(1 - self._split_ratio) * 100}%"
+
+"""CLI interface for ingestion pipeline."""
+
+import asyncio
+from typing import Annotated
+
+import typer
+from pydantic import SecretStr
+from rich.console import Console
+from rich.panel import Panel
+from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
+from rich.table import Table
+
+from ..config import configure_prefect, get_settings
+from ..core.models import (
+ IngestionResult,
+ IngestionSource,
+ StorageBackend,
+ StorageConfig,
+)
+from ..flows.ingestion import create_ingestion_flow
+from ..flows.scheduler import create_scheduled_deployment, serve_deployments
+
+app = typer.Typer(
+ name="ingest",
+ help="🚀 Modern Document Ingestion Pipeline - Advanced web and repository processing",
+ rich_markup_mode="rich",
+ add_completion=False,
+)
+console = Console()
+
+
+@app.callback()
+def main(
+ version: Annotated[
+ bool, typer.Option("--version", "-v", help="Show version information")
+ ] = False,
+) -> None:
+ """
+ 🚀 Modern Document Ingestion Pipeline
+
+ [bold cyan]Advanced document processing and management platform[/bold cyan]
+
+ Features:
+ • 🌐 Web scraping and crawling with Firecrawl
+ • 📦 Repository ingestion with Repomix
+ • 🗄️ Multiple storage backends (Weaviate, OpenWebUI, R2R)
+ • 📊 Modern TUI for collection management
+ • ⚡ Async processing with Prefect orchestration
+ • 🎨 Rich CLI with enhanced visuals
+ """
+ settings = get_settings()
+ configure_prefect(settings)
+
+ if version:
+ console.print(
+ Panel(
+ (
+ "[bold magenta]Ingest Pipeline v0.1.0[/bold magenta]\n"
+ "[dim]Modern Document Ingestion & Management System[/dim]"
+ ),
+ title="🚀 Version Info",
+ border_style="magenta",
+ )
+ )
+ raise typer.Exit()
+
+
+@app.command()
+def ingest(
+ source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")],
+ source_type: Annotated[
+ IngestionSource, typer.Option("--type", "-t", help="Type of source")
+ ] = IngestionSource.WEB,
+ storage: Annotated[
+ StorageBackend, typer.Option("--storage", "-s", help="Storage backend")
+ ] = StorageBackend.WEAVIATE,
+ collection: Annotated[
+ str | None,
+ typer.Option(
+ "--collection", "-c", help="Target collection name (auto-generated if not specified)"
+ ),
+ ] = None,
+ validate: Annotated[
+ bool, typer.Option("--validate/--no-validate", help="Validate source before ingesting")
+ ] = True,
+) -> None:
+ """
+ 🚀 Run a one-time ingestion job with enhanced progress tracking.
+
+ This command processes documents from various sources and stores them in
+ your chosen backend with full progress visualization.
+ """
+ # Enhanced startup message
+ console.print(
+ Panel(
+ (
+ f"[bold cyan]🚀 Starting Modern Ingestion[/bold cyan]\n\n"
+ f"[yellow]Source:[/yellow] {source_url}\n"
+ f"[yellow]Type:[/yellow] {source_type.value.title()}\n"
+ f"[yellow]Storage:[/yellow] {storage.value.replace('_', ' ').title()}\n"
+ f"[yellow]Collection:[/yellow] {collection or '[dim]Auto-generated[/dim]'}"
+ ),
+ title="🔥 Ingestion Configuration",
+ border_style="cyan",
+ )
+ )
+
+ async def run_with_progress() -> IngestionResult:
+ with Progress(
+ SpinnerColumn(),
+ TextColumn("[progress.description]{task.description}"),
+ BarColumn(),
+ TaskProgressColumn(),
+ console=console,
+ ) as progress:
+ task = progress.add_task("🔄 Processing documents...", total=100)
+
+ # Simulate progress updates during ingestion
+ progress.update(task, advance=20, description="🔗 Connecting to services...")
+ await asyncio.sleep(0.5)
+
+ progress.update(task, advance=30, description="📄 Fetching documents...")
+ result = await run_ingestion(
+ url=source_url,
+ source_type=source_type,
+ storage_backend=storage,
+ collection_name=collection,
+ validate_first=validate,
+ )
+
+ progress.update(task, advance=50, description="✅ Ingestion complete!")
+ return result
+
+ # Use asyncio.run() with proper event loop handling
+ try:
+ result = asyncio.run(run_with_progress())
+ except RuntimeError as e:
+ if "asyncio.run() cannot be called from a running event loop" in str(e):
+ # If we're already in an event loop (e.g., in Jupyter), use nest_asyncio
+ try:
+ import nest_asyncio
+
+ nest_asyncio.apply()
+ result = asyncio.run(run_with_progress())
+ except ImportError:
+ # Fallback: get the current loop and run the coroutine
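+ # Caveat: run_until_complete() will itself raise if this loop is already running;
+ # in that situation nest_asyncio (above) is the only workable path here.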
+ loop = asyncio.get_event_loop()
+ result = loop.run_until_complete(run_with_progress())
+ else:
+ raise
+
+ # Enhanced results display
+ status_color = "green" if result.status.value == "completed" else "red"
+
+ # Create results table with enhanced styling
+ table = Table(
+ title="📊 Ingestion Results",
+ title_style="bold magenta",
+ border_style="cyan",
+ header_style="bold blue",
+ )
+ table.add_column("📋 Metric", style="cyan", no_wrap=True)
+ table.add_column("📈 Value", style=status_color, justify="right")
+
+ # Add enhanced status icon
+ status_icon = "✅" if result.status.value == "completed" else "❌"
+ table.add_row("Status", f"{status_icon} {result.status.value.title()}")
+
+ table.add_row("Documents Processed", f"📄 {result.documents_processed:,}")
+ table.add_row("Documents Failed", f"⚠️ {result.documents_failed:,}")
+ table.add_row("Duration", f"⏱️ {result.duration_seconds:.2f}s")
+
+ if result.error_messages:
+ error_text = "\n".join(f"❌ {error}" for error in result.error_messages[:3])
+ if len(result.error_messages) > 3:
+ error_text += f"\n... and {len(result.error_messages) - 3} more errors"
+ table.add_row("Errors", error_text)
+
+ console.print(table)
+
+ # Success celebration or error guidance
+ if result.status.value == "completed" and result.documents_processed > 0:
+ console.print(
+ Panel(
+ (
+ f"🎉 [bold green]Success![/bold green] {result.documents_processed} documents ingested\n\n"
+ f"💡 [dim]Try '[bold cyan]ingest modern[/bold cyan]' to explore your collections![/dim]"
+ ),
+ title="✨ Ingestion Complete",
+ border_style="green",
+ )
+ )
+ elif result.error_messages:
+ console.print(
+ Panel(
+ (
+ "❌ [bold red]Ingestion encountered errors[/bold red]\n\n"
+ "💡 [dim]Check your configuration and try again[/dim]"
+ ),
+ title="⚠️ Issues Detected",
+ border_style="red",
+ )
+ )
+
+
+@app.command()
+def schedule(
+ name: Annotated[str, typer.Argument(help="Deployment name")],
+ source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")],
+ source_type: Annotated[
+ IngestionSource, typer.Option("--type", "-t", help="Type of source")
+ ] = IngestionSource.WEB,
+ storage: Annotated[
+ StorageBackend, typer.Option("--storage", "-s", help="Storage backend")
+ ] = StorageBackend.WEAVIATE,
+ cron: Annotated[
+ str | None, typer.Option("--cron", "-c", help="Cron expression for scheduling")
+ ] = None,
+ interval: Annotated[int, typer.Option("--interval", "-i", help="Interval in minutes")] = 60,
+ serve_now: Annotated[
+ bool, typer.Option("--serve/--no-serve", help="Start serving immediately")
+ ] = False,
+) -> None:
+ """
+ Create a scheduled deployment for recurring ingestion.
+ """
+ console.print(f"[bold blue]Creating deployment: {name}[/bold blue]")
+
+ deployment = create_scheduled_deployment(
+ name=name,
+ source_url=source_url,
+ source_type=source_type,
+ storage_backend=storage,
+ schedule_type="cron" if cron else "interval",
+ cron_expression=cron,
+ interval_minutes=interval,
+ )
+
+ console.print(f"[green]✓ Deployment '{name}' created[/green]")
+
+ if serve_now:
+ console.print("[yellow]Starting deployment server...[/yellow]")
+ serve_deployments([deployment])
+
+
+@app.command()
+def serve(
+ config_file: Annotated[
+ str | None, typer.Option("--config", "-c", help="Path to deployments config file")
+ ] = None,
+ ui: Annotated[
+ str | None, typer.Option("--ui", help="Launch user interface (options: tui, web)")
+ ] = None,
+) -> None:
+ """
+ 🚀 Serve configured deployments with optional UI interface.
+
+ Launch the deployment server to run scheduled ingestion jobs,
+ optionally with a modern Terminal User Interface (TUI) or web interface.
+ """
+ # Handle UI mode first
+ if ui == "tui":
+ console.print(
+ Panel(
+ (
+ "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n"
+ "[yellow]Features:[/yellow]\n"
+ "• 📊 Interactive collection management\n"
+ "• ⌨️ Enhanced keyboard navigation\n"
+ "• 🎨 Modern design with focus indicators\n"
+ "• 📄 Document browsing and search\n"
+ "• 🔄 Real-time status updates"
+ ),
+ title="🎉 TUI Mode",
+ border_style="cyan",
+ )
+ )
+ from .tui import dashboard
+
+ dashboard()
+ return
+ elif ui == "web":
+ console.print("[red]Web UI not yet implemented. Use --ui tui for Terminal UI.[/red]")
+ return
+ elif ui:
+ console.print(f"[red]Unknown UI option: {ui}[/red]")
+ console.print("[yellow]Available options: tui, web[/yellow]")
+ return
+
+ # Normal deployment server mode
+ if config_file:
+ # Load deployments from config
+ console.print(f"[yellow]Loading deployments from {config_file}[/yellow]")
+ # Implementation would load YAML/JSON config
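+ # NOTE: config-file loading is not implemented yet; this branch currently only
+ # prints a message and does not create any deployments.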
+ else:
+ # Create example deployments
+ deployments = [
+ create_scheduled_deployment(
+ name="docs-daily",
+ source_url="https://docs.example.com",
+ source_type="documentation",
+ storage_backend="weaviate",
+ schedule_type="cron",
+ cron_expression="0 2 * * *", # Daily at 2 AM
+ ),
+ create_scheduled_deployment(
+ name="repo-hourly",
+ source_url="https://github.com/example/repo",
+ source_type="repository",
+ storage_backend="open_webui",
+ schedule_type="interval",
+ interval_minutes=60,
+ ),
+ ]
+
+ console.print(
+ "[bold green]Starting deployment server with example deployments[/bold green]"
+ )
+ serve_deployments(deployments)
+
+
+@app.command()
+def tui() -> None:
+ """
+ 🚀 Launch the enhanced Terminal User Interface.
+
+ Quick shortcut for 'serve --ui tui' with modern keyboard navigation,
+ interactive collection management, and real-time status updates.
+ """
+ console.print(
+ Panel(
+ (
+ "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n"
+ "[yellow]Features:[/yellow]\n"
+ "• 📊 Interactive collection management\n"
+ "• ⌨️ Enhanced keyboard navigation\n"
+ "• 🎨 Modern design with focus indicators\n"
+ "• 📄 Document browsing and search\n"
+ "• 🔄 Real-time status updates"
+ ),
+ title="🎉 TUI Mode",
+ border_style="cyan",
+ )
+ )
+ from .tui import dashboard
+
+ dashboard()
+
+
+@app.command()
+def config() -> None:
+ """
+ 📋 Display current configuration with enhanced formatting.
+
+ Shows all configured endpoints, models, and settings in a beautiful
+ table format with status indicators.
+ """
+ settings = get_settings()
+
+ console.print(
+ Panel(
+ (
+ "[bold cyan]⚙️ System Configuration[/bold cyan]\n"
+ "[dim]Current pipeline settings and endpoints[/dim]"
+ ),
+ title="🔧 Configuration",
+ border_style="cyan",
+ )
+ )
+
+ # Enhanced configuration table
+ table = Table(
+ title="📊 Configuration Details",
+ title_style="bold magenta",
+ border_style="blue",
+ header_style="bold cyan",
+ show_lines=True,
+ )
+ table.add_column("🏷️ Setting", style="cyan", no_wrap=True, width=25)
+ table.add_column("🎯 Value", style="yellow", overflow="fold")
+ table.add_column("📊 Status", style="green", width=12, justify="center")
+
+ # Add configuration rows with status indicators
+ def get_status_indicator(value: str | None) -> str:
+ return "✅ Set" if value else "❌ Missing"
+
+ table.add_row("🤖 LLM Endpoint", str(settings.llm_endpoint), "✅ Active")
+ table.add_row("🔥 Firecrawl Endpoint", str(settings.firecrawl_endpoint), "✅ Active")
+ table.add_row(
+ "🗄️ Weaviate Endpoint",
+ str(settings.weaviate_endpoint),
+ get_status_indicator(str(settings.weaviate_api_key) if settings.weaviate_api_key else None),
+ )
+ table.add_row(
+ "🌐 OpenWebUI Endpoint",
+ str(settings.openwebui_endpoint),
+ get_status_indicator(settings.openwebui_api_key),
+ )
+ table.add_row("🧠 Embedding Model", settings.embedding_model, "✅ Set")
+ table.add_row("💾 Default Storage", settings.default_storage_backend.title(), "✅ Set")
+ table.add_row("📦 Default Batch Size", f"{settings.default_batch_size:,}", "✅ Set")
+ table.add_row("⚡ Max Concurrent Tasks", f"{settings.max_concurrent_tasks}", "✅ Set")
+
+ console.print(table)
+
+ # Additional helpful information
+ console.print(
+ Panel(
+ (
+ "💡 [bold cyan]Quick Tips[/bold cyan]\n\n"
+ "• Use '[bold]ingest list-collections[/bold]' to view all collections\n"
+ "• Use '[bold]ingest search[/bold]' to search content\n"
+ "• Configure API keys in your [yellow].env[/yellow] file\n"
+ "• Default collection names are auto-generated from URLs"
+ ),
+ title="🚀 Usage Tips",
+ border_style="green",
+ )
+ )
+
+
+@app.command()
+def list_collections() -> None:
+ """
+ 📋 List all collections across storage backends.
+ """
+ console.print("[bold cyan]📚 Collection Overview[/bold cyan]")
+ asyncio.run(run_list_collections())
+
+
+@app.command()
+def search(
+ query: Annotated[str, typer.Argument(help="Search query")],
+ collection: Annotated[
+ str | None, typer.Option("--collection", "-c", help="Target collection")
+ ] = None,
+ backend: Annotated[
+ StorageBackend, typer.Option("--backend", "-b", help="Storage backend")
+ ] = StorageBackend.WEAVIATE,
+ limit: Annotated[int, typer.Option("--limit", "-l", help="Result limit")] = 10,
+) -> None:
+ """
+ 🔍 Search across collections.
+ """
+ console.print(f"[bold cyan]🔍 Searching for: {query}[/bold cyan]")
+ asyncio.run(run_search(query, collection, backend.value, limit))
+
+
+@app.command(name="blocks")
+def blocks_command() -> None:
+ """🧩 List and manage Prefect Blocks."""
+ console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
+ console.print(
+ "Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks"
+ )
+ console.print("Use 'prefect block ls' to list available blocks")
+
+
+@app.command(name="variables")
+def variables_command() -> None:
+ """📊 Manage Prefect Variables."""
+ console.print("[bold cyan]📊 Prefect Variables Management[/bold cyan]")
+ console.print("Use 'prefect variable set VARIABLE_NAME value' to set variables")
+ console.print("Use 'prefect variable ls' to list variables")
+
+
+async def run_ingestion(
+ url: str,
+ source_type: IngestionSource,
+ storage_backend: StorageBackend,
+ collection_name: str | None = None,
+ validate_first: bool = True,
+) -> IngestionResult:
+ """
+ Run ingestion with support for targeted collections.
+ """
+ # Auto-generate collection name if not provided
+ if not collection_name:
+ from urllib.parse import urlparse
+
+ parsed = urlparse(url)
+ domain = parsed.netloc.replace(".", "_").replace("-", "_")
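+ # e.g. https://docs.example.com with a web source becomes roughly
+ # "docs_example_com_web" (the exact suffix depends on the IngestionSource value).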
+ collection_name = f"{domain}_{source_type.value}"
+
+ return await create_ingestion_flow(
+ source_url=url,
+ source_type=source_type,
+ storage_backend=storage_backend,
+ collection_name=collection_name,
+ validate_first=validate_first,
+ )
+
+
+async def run_list_collections() -> None:
+ """
+ List collections across storage backends.
+ """
+ from ..config import get_settings
+ from ..core.models import StorageBackend
+ from ..storage.openwebui import OpenWebUIStorage
+ from ..storage.weaviate import WeaviateStorage
+
+ settings = get_settings()
+
+ console.print("🔍 [bold cyan]Scanning storage backends...[/bold cyan]")
+
+ # Try to connect to Weaviate
+ weaviate_collections: list[tuple[str, int]] = []
+ try:
+ weaviate_config = StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint=settings.weaviate_endpoint,
+ api_key=SecretStr(settings.weaviate_api_key)
+ if settings.weaviate_api_key is not None
+ else None,
+ collection_name="default",
+ )
+ weaviate = WeaviateStorage(weaviate_config)
+ await weaviate.initialize()
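+ # NOTE: this client is never closed here; calling `await weaviate.close()` when done
+ # would avoid leaked connections (the same applies to the OpenWebUI client below).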
+
+ overview = await weaviate.describe_collections()
+ for item in overview:
+ name = str(item.get("name", "Unknown"))
+ count_val = item.get("count", 0)
+ count = int(count_val) if isinstance(count_val, (int, str)) else 0
+ weaviate_collections.append((name, count))
+ except Exception as e:
+ console.print(f"❌ [red]Weaviate connection failed: {e}[/red]")
+
+ # Try to connect to OpenWebUI
+ openwebui_collections: list[tuple[str, int]] = []
+ try:
+ openwebui_config = StorageConfig(
+ backend=StorageBackend.OPEN_WEBUI,
+ endpoint=settings.openwebui_endpoint,
+ api_key=SecretStr(settings.openwebui_api_key)
+ if settings.openwebui_api_key is not None
+ else None,
+ collection_name="default",
+ )
+ openwebui = OpenWebUIStorage(openwebui_config)
+ await openwebui.initialize()
+
+ overview = await openwebui.describe_collections()
+ for item in overview:
+ name = str(item.get("name", "Unknown"))
+ count_val = item.get("count", 0)
+ count = int(count_val) if isinstance(count_val, (int, str)) else 0
+ openwebui_collections.append((name, count))
+ except Exception as e:
+ console.print(f"❌ [red]OpenWebUI connection failed: {e}[/red]")
+
+ # Display results
+ if weaviate_collections or openwebui_collections:
+ # Create results table
+ from rich.table import Table
+
+ table = Table(
+ title="📚 Collection Overview",
+ title_style="bold magenta",
+ border_style="cyan",
+ header_style="bold blue",
+ )
+ table.add_column("🏷️ Collection", style="cyan", no_wrap=True)
+ table.add_column("📊 Backend", style="yellow")
+ table.add_column("📄 Documents", style="green", justify="right")
+
+ # Add Weaviate collections
+ for name, count in weaviate_collections:
+ table.add_row(name, "🗄️ Weaviate", f"{count:,}")
+
+ # Add OpenWebUI collections
+ for name, count in openwebui_collections:
+ table.add_row(name, "🌐 OpenWebUI", f"{count:,}")
+
+ console.print(table)
+ else:
+ console.print("❌ [yellow]No collections found in any backend[/yellow]")
+
+
+async def run_search(query: str, collection: str | None, backend: str, limit: int) -> None:
+ """
+ Search across collections.
+ """
+ from ..config import get_settings
+ from ..core.models import StorageBackend
+ from ..storage.weaviate import WeaviateStorage
+
+ settings = get_settings()
+
+ console.print(f"🔍 Searching for: '[bold cyan]{query}[/bold cyan]'")
+ if collection:
+ console.print(f"📚 Target collection: [yellow]{collection}[/yellow]")
+ console.print(f"💾 Backend: [blue]{backend}[/blue]")
+
+ results = []
+
+ try:
+ if backend == "weaviate":
+ weaviate_config = StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint=settings.weaviate_endpoint,
+ api_key=SecretStr(settings.weaviate_api_key)
+ if settings.weaviate_api_key is not None
+ else None,
+ collection_name=collection or "default",
+ )
+ weaviate = WeaviateStorage(weaviate_config)
+ await weaviate.initialize()
+
+ results_generator = weaviate.search(query, limit=limit)
+ async for doc in results_generator:
+ results.append(
+ {
+ "title": getattr(doc, "title", "Untitled"),
+ "content": getattr(doc, "content", ""),
+ "score": getattr(doc, "score", 0.0),
+ "backend": "🗄️ Weaviate",
+ }
+ )
+
+ elif backend == "open_webui":
+ console.print("❌ [red]OpenWebUI search not yet implemented[/red]")
+ return
+
+ except Exception as e:
+ console.print(f"❌ [red]Search failed: {e}[/red]")
+ return
+
+ # Display results
+ if results:
+ from rich.table import Table
+
+ table = Table(
+ title=f"🔍 Search Results for '{query}'",
+ title_style="bold magenta",
+ border_style="green",
+ header_style="bold blue",
+ )
+ table.add_column("📄 Title", style="cyan", max_width=40)
+ table.add_column("📝 Preview", style="white", max_width=60)
+ table.add_column("📊 Score", style="yellow", justify="right")
+
+ for result in results[:limit]:
+ title = str(result["title"])
+ title_display = title[:40] + "..." if len(title) > 40 else title
+
+ content = str(result["content"])
+ content_display = content[:60] + "..." if len(content) > 60 else content
+
+ score = f"{result['score']:.3f}"
+
+ table.add_row(title_display, content_display, score)
+
+ console.print(table)
+ console.print(f"\n✅ [green]Found {len(results)} results[/green]")
+ else:
+ console.print("❌ [yellow]No results found[/yellow]")
+
+
+if __name__ == "__main__":
+ app()
+
+
+
+"""Configuration management utilities."""
+
+from __future__ import annotations
+
+from contextlib import ExitStack
+
+from prefect.settings import Setting, temporary_settings
+
+
+# Import Prefect settings with version compatibility - avoid static analysis issues
+def _setup_prefect_settings() -> tuple[object, object, object]:
+ """Setup Prefect settings with proper fallbacks."""
+ try:
+ import prefect.settings as ps
+
+ # Try to get the settings directly
+ api_key = getattr(ps, "PREFECT_API_KEY", None)
+ api_url = getattr(ps, "PREFECT_API_URL", None)
+ work_pool = getattr(ps, "PREFECT_DEFAULT_WORK_POOL_NAME", None)
+
+ if api_key is not None:
+ return api_key, api_url, work_pool
+
+ # Fallback to registry-based approach
+ registry = getattr(ps, "PREFECT_SETTING_REGISTRY", None)
+ if registry is not None:
+ Setting = getattr(ps, "Setting", None)
+ if Setting is not None:
+ api_key = registry.get("PREFECT_API_KEY") or Setting(
+ "PREFECT_API_KEY", type_=str, default=None
+ )
+ api_url = registry.get("PREFECT_API_URL") or Setting(
+ "PREFECT_API_URL", type_=str, default=None
+ )
+ work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting(
+ "PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None
+ )
+ return api_key, api_url, work_pool
+
+ except ImportError:
+ pass
+
+ # Ultimate fallback
+ return None, None, None
+
+
+PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()
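+# These may be None (or placeholder objects) on Prefect versions where the settings are
+# unavailable; configure_prefect() guards against that with isinstance(..., Setting) checks.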
+
+# Import after Prefect settings setup to avoid circular dependencies
+from .settings import Settings, get_settings # noqa: E402
+
+__all__ = ["Settings", "get_settings", "configure_prefect"]
+
+_prefect_settings_stack: ExitStack | None = None
+
+
+def configure_prefect(settings: Settings) -> None:
+ """Apply Prefect settings from the application configuration."""
+ global _prefect_settings_stack
+
+ overrides: dict[Setting, str] = {}
+
+ if (
+ settings.prefect_api_url is not None
+ and PREFECT_API_URL is not None
+ and isinstance(PREFECT_API_URL, Setting)
+ ):
+ overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
+ if (
+ settings.prefect_api_key
+ and PREFECT_API_KEY is not None
+ and isinstance(PREFECT_API_KEY, Setting)
+ ):
+ overrides[PREFECT_API_KEY] = settings.prefect_api_key
+ if (
+ settings.prefect_work_pool
+ and PREFECT_DEFAULT_WORK_POOL_NAME is not None
+ and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting)
+ ):
+ overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool
+
+ if not overrides:
+ return
+
+ filtered_overrides = {
+ setting: value for setting, value in overrides.items() if setting.value() != value
+ }
+
+ if not filtered_overrides:
+ return
+
+ new_stack = ExitStack()
+ new_stack.enter_context(temporary_settings(updates=filtered_overrides))
+
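+ # Presumably the new overrides are entered before any previous stack is closed so
+ # that Prefect settings remain applied throughout the swap.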
+ if _prefect_settings_stack is not None:
+ _prefect_settings_stack.close()
+
+ _prefect_settings_stack = new_stack
+
+
+
+"""Base ingestor interface."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING
+
+from ..core.models import Document, IngestionJob
+
+if TYPE_CHECKING:
+ from ..storage.base import BaseStorage
+
+
+class BaseIngestor(ABC):
+ """Abstract base class for all ingestors."""
+
+ @abstractmethod
+ def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]:
+ """
+ Ingest data from a source.
+
+ Args:
+ job: The ingestion job configuration
+
+ Yields:
+ Documents from the source
+ """
+ ... # pragma: no cover
+
+ @abstractmethod
+ async def validate_source(self, source_url: str) -> bool:
+ """
+ Validate if the source is accessible.
+
+ Args:
+ source_url: URL or path to the source
+
+ Returns:
+ True if source is valid and accessible
+ """
+ pass # pragma: no cover
+
+ @abstractmethod
+ async def estimate_size(self, source_url: str) -> int:
+ """
+ Estimate the number of documents in the source.
+
+ Args:
+ source_url: URL or path to the source
+
+ Returns:
+ Estimated number of documents
+ """
+ pass # pragma: no cover
+
+ async def ingest_with_dedup(
+ self,
+ job: IngestionJob,
+ storage_client: BaseStorage,
+ *,
+ collection_name: str | None = None,
+ stale_after_days: int = 30,
+ ) -> AsyncGenerator[Document, None]:
+ """
+ Ingest documents with duplicate detection (optional optimization).
+
+ Default implementation falls back to regular ingestion.
+ Subclasses can override to provide optimized deduplication.
+
+ Args:
+ job: The ingestion job configuration
+ storage_client: Storage client to check for existing documents
+ collection_name: Collection to check for duplicates
+ stale_after_days: Consider documents stale after this many days
+
+ Yields:
+ Documents from the source (with deduplication if implemented)
+ """
+ # Default implementation: fall back to regular ingestion
+ async for document in self.ingest(job):
+ yield document
+
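+# Illustrative (hypothetical) subclass sketch:
+#
+#     class WebIngestor(BaseIngestor):
+#         async def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]:
+#             yield Document(...)  # fetch and yield documents for the job
+#
+#         async def validate_source(self, source_url: str) -> bool:
+#             return source_url.startswith("http")
+#
+#         async def estimate_size(self, source_url: str) -> int:
+#             return 1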
+
+
+"""Document management screen with enhanced navigation."""
+
+from datetime import datetime
+
+from textual.app import ComposeResult
+from textual.binding import Binding
+from textual.containers import Container, Horizontal, ScrollableContainer
+from textual.screen import ModalScreen, Screen
+from textual.widgets import Button, Footer, Header, Label, LoadingIndicator, Markdown, Static
+from typing_extensions import override
+
+from ....storage.base import BaseStorage
+from ..models import CollectionInfo, DocumentInfo
+from ..widgets import EnhancedDataTable
+
+
+class DocumentManagementScreen(Screen[None]):
+ """Screen for managing documents within a collection with enhanced keyboard navigation."""
+
+ collection: CollectionInfo
+ storage: BaseStorage | None
+ documents: list[DocumentInfo]
+ selected_docs: set[str]
+ current_offset: int
+ page_size: int
+
+ BINDINGS = [
+ Binding("escape", "app.pop_screen", "Back"),
+ Binding("r", "refresh", "Refresh"),
+ Binding("v", "view_document", "View"),
+ Binding("delete", "delete_selected", "Delete Selected"),
+ Binding("a", "select_all", "Select All"),
+ Binding("ctrl+a", "select_all", "Select All"),
+ Binding("n", "select_none", "Clear Selection"),
+ Binding("ctrl+shift+a", "select_none", "Clear Selection"),
+ Binding("space", "toggle_selection", "Toggle Selection"),
+ Binding("ctrl+d", "delete_selected", "Delete Selected"),
+ Binding("pageup", "prev_page", "Previous Page"),
+ Binding("pagedown", "next_page", "Next Page"),
+ Binding("home", "first_page", "First Page"),
+ Binding("end", "last_page", "Last Page"),
+ ]
+
+ def __init__(self, collection: CollectionInfo, storage: BaseStorage | None):
+ super().__init__()
+ self.collection = collection
+ self.storage = storage
+ self.documents: list[DocumentInfo] = []
+ self.selected_docs: set[str] = set()
+ self.current_offset = 0
+ self.page_size = 50
+
+ @override
+ def compose(self) -> ComposeResult:
+ yield Header()
+ yield Container(
+ Static(f"📄 Document Management: {self.collection['name']}", classes="title"),
+ Static(
+ f"Total Documents: {self.collection['count']:,} | Use Space to select, Delete to remove",
+ classes="subtitle",
+ ),
+ Label(f"Page size: {self.page_size} documents"),
+ EnhancedDataTable(id="documents_table", classes="enhanced-table"),
+ Horizontal(
+ Button("🔄 Refresh", id="refresh_docs_btn", variant="primary"),
+ Button("🗑️ Delete Selected", id="delete_selected_btn", variant="error"),
+ Button("✅ Select All", id="select_all_btn", variant="default"),
+ Button("❌ Clear Selection", id="clear_selection_btn", variant="default"),
+ Button("⬅️ Previous Page", id="prev_page_btn", variant="default"),
+ Button("➡️ Next Page", id="next_page_btn", variant="default"),
+ classes="button_bar",
+ ),
+ Label("", id="selection_status"),
+ Static("", id="page_info", classes="status-text"),
+ LoadingIndicator(id="loading"),
+ classes="main_container",
+ )
+ yield Footer()
+
+ async def on_mount(self) -> None:
+ """Initialize the screen."""
+ self.query_one("#loading").display = False
+
+ # Setup documents table with enhanced columns
+ table = self.query_one("#documents_table", EnhancedDataTable)
+ table.add_columns(
+ "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID"
+ )
+
+ # Set up message handling for table events
+ table.can_focus = True
+
+ await self.load_documents()
+
+ async def load_documents(self) -> None:
+ """Load documents from the collection."""
+ loading = self.query_one("#loading")
+ loading.display = True
+
+ try:
+ if self.storage:
+ # Try to load documents using the storage backend
+ try:
+ raw_docs = await self.storage.list_documents(
+ limit=self.page_size,
+ offset=self.current_offset,
+ collection_name=self.collection["name"],
+ )
+ # Cast to proper type with type checking
+ self.documents = [
+ DocumentInfo(
+ id=str(doc.get("id", f"doc_{i}")),
+ title=str(doc.get("title", "Untitled Document")),
+ source_url=str(doc.get("source_url", "")),
+ description=str(doc.get("description", "")),
+ content_type=str(doc.get("content_type", "text/plain")),
+ content_preview=str(doc.get("content_preview", "")),
+ word_count=(
+ lambda wc_val: int(wc_val)
+ if isinstance(wc_val, (int, str)) and str(wc_val).isdigit()
+ else 0
+ )(doc.get("word_count", 0)),
+ timestamp=str(doc.get("timestamp", "")),
+ )
+ for i, doc in enumerate(raw_docs)
+ ]
+ except NotImplementedError:
+ # For storage backends that don't support document listing, show a message
+ self.notify(
+ f"Document listing not supported for {self.storage.__class__.__name__}",
+ severity="information",
+ )
+ self.documents = []
+
+ await self.update_table()
+ self.update_selection_status()
+ self.update_page_info()
+
+ except Exception as e:
+ self.notify(f"Error loading documents: {e}", severity="error", markup=False)
+ finally:
+ loading.display = False
+
+ async def update_table(self) -> None:
+ """Update the documents table with enhanced metadata display."""
+ table = self.query_one("#documents_table", EnhancedDataTable)
+ table.clear(columns=True)
+
+ # Add enhanced columns with more metadata
+ table.add_columns(
+ "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID"
+ )
+
+ # Add rows with enhanced metadata
+ for doc in self.documents:
+ selected = "✓" if doc["id"] in self.selected_docs else ""
+
+ # Get additional metadata from the raw docs
+ description = str(doc.get("description") or "").strip()[:40]
+ if not description:
+ description = "[dim]No description[/dim]"
+ elif len(str(doc.get("description") or "")) > 40:
+ description += "..."
+
+ # Format content type with appropriate icon
+ content_type = doc.get("content_type", "text/plain")
+ if "markdown" in content_type.lower():
+ type_display = "📝 md"
+ elif "html" in content_type.lower():
+ type_display = "🌐 html"
+ elif "text" in content_type.lower():
+ type_display = "📄 txt"
+ else:
+ type_display = f"📄 {content_type.split('/')[-1][:5]}"
+
+ # Format timestamp to be more readable
+ timestamp = doc.get("timestamp", "")
+ if timestamp:
+ try:
+ # Parse ISO format timestamp
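+ # fromisoformat() on Python < 3.11 does not accept a trailing "Z", hence the
+ # explicit replacement with "+00:00".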
+ dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
+ timestamp = dt.strftime("%m/%d %H:%M")
+ except Exception:
+ timestamp = str(timestamp)[:16] # Fallback
+ table.add_row(
+ selected,
+ doc.get("title", "Untitled")[:40],
+ doc.get("source_url", "")[:35],
+ description,
+ type_display,
+ str(doc.get("word_count", 0)),
+ timestamp,
+ doc["id"][:8] + "...", # Show truncated ID
+ )
+
+ def update_selection_status(self) -> None:
+ """Update the selection status label."""
+ status_label = self.query_one("#selection_status", Label)
+ total_selected = len(self.selected_docs)
+ status_label.update(f"Selected: {total_selected} documents")
+
+ def update_page_info(self) -> None:
+ """Update the page information."""
+ page_info = self.query_one("#page_info", Static)
+ total_docs = self.collection["count"]
+ start = self.current_offset + 1
+ end = min(self.current_offset + len(self.documents), total_docs)
+ page_num = (self.current_offset // self.page_size) + 1
+ total_pages = max((total_docs + self.page_size - 1) // self.page_size, 1)
+
+ page_info.update(
+ f"Showing {start:,}-{end:,} of {total_docs:,} documents (Page {page_num} of {total_pages})"
+ )
+
+ def get_current_document(self) -> DocumentInfo | None:
+ """Get the currently selected document."""
+ table = self.query_one("#documents_table", EnhancedDataTable)
+ try:
+ if 0 <= table.cursor_coordinate.row < len(self.documents):
+ return self.documents[table.cursor_coordinate.row]
+ except (AttributeError, IndexError):
+ pass
+ return None
+
+ # Action methods
+ def action_refresh(self) -> None:
+ """Refresh the document list."""
+ self.run_worker(self.load_documents())
+
+ def action_toggle_selection(self) -> None:
+ """Toggle selection of current row."""
+ if doc := self.get_current_document():
+ doc_id = doc["id"]
+ if doc_id in self.selected_docs:
+ self.selected_docs.remove(doc_id)
+ else:
+ self.selected_docs.add(doc_id)
+
+ self.run_worker(self.update_table())
+ self.update_selection_status()
+
+ def action_select_all(self) -> None:
+ """Select all documents on current page."""
+ for doc in self.documents:
+ self.selected_docs.add(doc["id"])
+ self.run_worker(self.update_table())
+ self.update_selection_status()
+
+ def action_select_none(self) -> None:
+ """Clear all selections."""
+ self.selected_docs.clear()
+ self.run_worker(self.update_table())
+ self.update_selection_status()
+
+ def action_delete_selected(self) -> None:
+ """Delete selected documents."""
+ if self.selected_docs:
+ from .dialogs import ConfirmDocumentDeleteScreen
+
+ self.app.push_screen(
+ ConfirmDocumentDeleteScreen(list(self.selected_docs), self.collection, self)
+ )
+ else:
+ self.notify("No documents selected", severity="warning")
+
+ def action_next_page(self) -> None:
+ """Go to next page."""
+ if self.current_offset + self.page_size < self.collection["count"]:
+ self.current_offset += self.page_size
+ self.run_worker(self.load_documents())
+
+ def action_prev_page(self) -> None:
+ """Go to previous page."""
+ if self.current_offset >= self.page_size:
+ self.current_offset -= self.page_size
+ self.run_worker(self.load_documents())
+
+ def action_first_page(self) -> None:
+ """Go to first page."""
+ if self.current_offset > 0:
+ self.current_offset = 0
+ self.run_worker(self.load_documents())
+
+ def action_last_page(self) -> None:
+ """Go to last page."""
+ total_docs = self.collection["count"]
+ last_offset = max(((total_docs - 1) // self.page_size) * self.page_size, 0)
+ if self.current_offset != last_offset:
+ self.current_offset = last_offset
+ self.run_worker(self.load_documents())
+
+ def on_button_pressed(self, event: Button.Pressed) -> None:
+ """Handle button presses."""
+ if event.button.id == "refresh_docs_btn":
+ self.action_refresh()
+ elif event.button.id == "delete_selected_btn":
+ self.action_delete_selected()
+ elif event.button.id == "select_all_btn":
+ self.action_select_all()
+ elif event.button.id == "clear_selection_btn":
+ self.action_select_none()
+ elif event.button.id == "next_page_btn":
+ self.action_next_page()
+ elif event.button.id == "prev_page_btn":
+ self.action_prev_page()
+
+ def on_enhanced_data_table_row_toggled(self, event: EnhancedDataTable.RowToggled) -> None:
+ """Handle row toggle from enhanced table."""
+ if 0 <= event.row_index < len(self.documents):
+ doc = self.documents[event.row_index]
+ doc_id = doc["id"]
+
+ if doc_id in self.selected_docs:
+ self.selected_docs.remove(doc_id)
+ else:
+ self.selected_docs.add(doc_id)
+
+ self.run_worker(self.update_table())
+ self.update_selection_status()
+
+ def on_enhanced_data_table_select_all(self, event: EnhancedDataTable.SelectAll) -> None:
+ """Handle select all from enhanced table."""
+ self.action_select_all()
+
+ def on_enhanced_data_table_clear_selection(
+ self, event: EnhancedDataTable.ClearSelection
+ ) -> None:
+ """Handle clear selection from enhanced table."""
+ self.action_select_none()
+
+ def action_view_document(self) -> None:
+ """View the content of the currently selected document."""
+ if doc := self.get_current_document():
+ if self.storage:
+ self.app.push_screen(
+ DocumentContentModal(doc, self.storage, self.collection["name"])
+ )
+ else:
+ self.notify("No storage backend available", severity="error")
+ else:
+ self.notify("No document selected", severity="warning")
+
+
+class DocumentContentModal(ModalScreen[None]):
+ """Modal screen for viewing document content."""
+
+ DEFAULT_CSS = """
+ DocumentContentModal {
+ align: center middle;
+ }
+
+ DocumentContentModal > Container {
+ width: 90%;
+ height: 85%;
+ background: $surface;
+ border: thick $primary;
+ }
+
+ DocumentContentModal .modal-header {
+ background: $primary;
+ color: $text;
+ padding: 1;
+ dock: top;
+ height: 3;
+ }
+
+ DocumentContentModal .modal-content {
+ padding: 1;
+ height: 1fr;
+ }
+ """
+
+ BINDINGS = [
+ Binding("escape", "app.pop_screen", "Close"),
+ Binding("q", "app.pop_screen", "Close"),
+ ]
+
+ def __init__(self, document: DocumentInfo, storage: BaseStorage, collection_name: str):
+ super().__init__()
+ self.document = document
+ self.storage = storage
+ self.collection_name = collection_name
+
+ def compose(self) -> ComposeResult:
+ yield Container(
+ Static(
+ f"📄 Document: {self.document['title'][:60]}{'...' if len(self.document['title']) > 60 else ''}",
+ classes="modal-header",
+ ),
+ ScrollableContainer(
+ Markdown("Loading document content...", id="document_content"),
+ LoadingIndicator(id="content_loading"),
+ classes="modal-content",
+ ),
+ )
+
+ async def on_mount(self) -> None:
+ """Load and display the document content."""
+ content_widget = self.query_one("#document_content", Markdown)
+ loading = self.query_one("#content_loading")
+
+ try:
+ # Get full document content
+ doc_content = await self.storage.retrieve(
+ self.document["id"], collection_name=self.collection_name
+ )
+
+            # Format content for display with a shared metadata header
+            header = (
+                f"# {self.document['title']}\n\n"
+                f"**Source:** {self.document.get('source_url', 'N/A')}\n"
+                f"**Type:** {self.document.get('content_type', 'text/plain')}\n"
+                f"**Words:** {self.document.get('word_count', 0):,}\n"
+                f"**Timestamp:** {self.document.get('timestamp', 'N/A')}\n"
+                "\n---\n\n"
+            )
+            if isinstance(doc_content, str):
+                formatted_content = header + doc_content
+            else:
+                formatted_content = header + "*Content format not supported for display*"
+
+ content_widget.update(formatted_content)
+
+ except Exception as e:
+ content_widget.update(
+ f"# Error Loading Document\n\nFailed to load document content: {e}"
+ )
+ finally:
+ loading.display = False
+
+
+
+"""TUI runner functions and initialization."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from logging import Logger
+from logging.handlers import QueueHandler, RotatingFileHandler
+from pathlib import Path
+from queue import Queue
+from typing import NamedTuple
+
+import platformdirs
+
+from ....config import configure_prefect, get_settings
+from .storage_manager import StorageManager
+
+
+class _TuiLoggingContext(NamedTuple):
+ """Container describing configured logging outputs for the TUI."""
+
+ queue: Queue[logging.LogRecord]
+ formatter: logging.Formatter
+ log_file: Path | None
+
+
+_logging_context: _TuiLoggingContext | None = None
+
+
+def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext:
+ """Configure logging so that messages do not break the TUI output."""
+
+ global _logging_context
+ if _logging_context is not None:
+ return _logging_context
+
+ resolved_level = getattr(logging, log_level.upper(), logging.INFO)
+ log_queue: Queue[logging.LogRecord] = Queue()
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ root_logger = logging.getLogger()
+ root_logger.setLevel(resolved_level)
+
+    # Remove all existing handlers so nothing writes directly to the console and breaks the TUI
+ for handler in list(root_logger.handlers):
+ root_logger.removeHandler(handler)
+
+ queue_handler = QueueHandler(log_queue)
+ queue_handler.setLevel(resolved_level)
+ root_logger.addHandler(queue_handler)
+
+ log_file: Path | None = None
+ try:
+ # Try current directory first for development
+ log_dir = Path.cwd() / "logs"
+ log_dir.mkdir(parents=True, exist_ok=True)
+ log_file = log_dir / "tui.log"
+ except OSError:
+ # Fall back to user log directory
+ try:
+ log_dir = Path(platformdirs.user_log_dir("ingest-pipeline", "ingest-pipeline"))
+ log_dir.mkdir(parents=True, exist_ok=True)
+ log_file = log_dir / "tui.log"
+ except OSError as exc:
+ fallback = logging.getLogger(__name__)
+ fallback.warning("Failed to create log directory, file logging disabled: %s", exc)
+ log_file = None
+
+ if log_file:
+ try:
+ file_handler = RotatingFileHandler(
+ log_file,
+ maxBytes=2_000_000,
+ backupCount=5,
+ encoding="utf-8",
+ )
+ file_handler.setLevel(resolved_level)
+ file_handler.setFormatter(formatter)
+ root_logger.addHandler(file_handler)
+ except OSError as exc:
+ fallback = logging.getLogger(__name__)
+ fallback.warning("Failed to configure file logging for TUI: %s", exc)
+ log_file = None
+
+ _logging_context = _TuiLoggingContext(log_queue, formatter, log_file)
+ return _logging_context
+
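+
+# Illustrative sketch only (not part of the runner module): one way a Textual widget
+# could drain the queue configured above. The `RichLog` target widget and the helper
+# name are assumptions for demonstration; the real app may consume the queue differently.
+def _drain_log_queue_example(log_widget: object, context: _TuiLoggingContext) -> None:  # pragma: no cover
+    """Render queued log records inside the TUI instead of writing to stdout."""
+    from textual.widgets import RichLog
+
+    if not isinstance(log_widget, RichLog):
+        return
+    while not context.queue.empty():
+        record = context.queue.get_nowait()
+        log_widget.write(context.formatter.format(record))
+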
+
+LOGGER: Logger = logging.getLogger(__name__)
+
+
+async def run_textual_tui() -> None:
+ """Run the enhanced modern TUI with better error handling and initialization."""
+ settings = get_settings()
+ configure_prefect(settings)
+
+ logging_context = _configure_tui_logging(log_level=settings.log_level)
+
+ LOGGER.info("Initializing collection management TUI")
+ LOGGER.info("Scanning available storage backends")
+
+ # Create storage manager without initialization - let TUI handle it asynchronously
+ storage_manager = StorageManager(settings)
+
+ LOGGER.info("Launching TUI - storage backends will initialize in background")
+
+ # Import here to avoid circular import
+ from ..app import CollectionManagementApp
+
+ app = CollectionManagementApp(
+ storage_manager,
+ None, # weaviate - will be available after initialization
+ None, # openwebui - will be available after initialization
+ None, # r2r_backend - will be available after initialization
+ log_queue=logging_context.queue,
+ log_formatter=logging_context.formatter,
+ log_file=logging_context.log_file,
+ )
+ try:
+ await app.run_async()
+ finally:
+ LOGGER.info("Shutting down storage connections")
+ await storage_manager.close_all()
+ LOGGER.info("All storage connections closed gracefully")
+
+
+def dashboard() -> None:
+ """Launch the modern collection dashboard."""
+ asyncio.run(run_textual_tui())
+
+
"""Comprehensive theming system for TUI applications with WCAG AA accessibility compliance."""
from dataclasses import dataclass
from enum import Enum
-from typing import Any
+from typing import Protocol
+
+from textual.app import App
+
+# Type alias for Textual apps with unknown return type
+TextualApp = App[object]
+
+
+class AppProtocol(Protocol):
+ """Protocol for apps that support CSS and refresh."""
+
+ CSS: str
+
+ def refresh(self) -> None:
+ """Refresh the app."""
+ ...
class ThemeType(Enum):
@@ -5493,8 +6748,8 @@ class ThemeManager:
"""Manages theme selection and CSS generation."""
def __init__(self, default_theme: ThemeType = ThemeType.DARK):
- self.current_theme = default_theme
- self._themes = {
+ self.current_theme: ThemeType = default_theme
+ self._themes: dict[ThemeType, ColorPalette] = {
ThemeType.DARK: ThemeRegistry.get_enhanced_dark(),
ThemeType.LIGHT: ThemeRegistry.get_light(),
ThemeType.HIGH_CONTRAST: ThemeRegistry.get_high_contrast(),
@@ -6418,30 +7673,28 @@ def get_css_for_theme(theme_type: ThemeType) -> str:
return css
-def apply_theme_to_app(app: object, theme_type: ThemeType) -> None:
+def apply_theme_to_app(app: TextualApp | AppProtocol, theme_type: ThemeType) -> None:
"""Apply a theme to a Textual app instance."""
try:
- css = set_theme(theme_type)
- if hasattr(app, "stylesheet"):
- app.stylesheet.clear()
- app.stylesheet.parse(css)
- elif hasattr(app, "CSS"):
- setattr(app, "CSS", css)
- elif hasattr(app, "refresh"):
- # Fallback: try to refresh the app with new CSS
+ # Note: CSS class variable cannot be changed at runtime
+ # This function would need to be called during app initialization
+ # or implement a different approach for dynamic theming
+ _ = set_theme(theme_type) # Keep for future implementation
+ if hasattr(app, "refresh"):
app.refresh()
except Exception as e:
# Graceful fallback - log but don't crash the UI
import logging
+
logging.debug(f"Failed to apply theme to app: {e}")
class ThemeSwitcher:
"""Helper class for managing theme switching in TUI applications."""
- def __init__(self, app: object | None = None) -> None:
- self.app = app
- self.theme_history = [ThemeType.DARK]
+ def __init__(self, app: TextualApp | AppProtocol | None = None) -> None:
+ self.app: TextualApp | AppProtocol | None = app
+ self.theme_history: list[ThemeType] = [ThemeType.DARK]
def switch_theme(self, theme_type: ThemeType) -> str:
"""Switch to a new theme and apply it to the app if available."""
@@ -6469,7 +7722,7 @@ class ThemeSwitcher:
next_theme = themes[(current_index + 1) % len(themes)]
return self.switch_theme(next_theme)
- def get_theme_info(self) -> dict[str, Any]:
+ def get_theme_info(self) -> dict[str, str | list[str] | dict[str, str]]:
"""Get information about the current theme."""
palette = get_theme_palette()
return {
@@ -6486,11 +7739,11 @@ class ThemeSwitcher:
# Responsive breakpoints for dynamic layout adaptation
RESPONSIVE_BREAKPOINTS = {
- "xs": 40, # Extra small terminals
- "sm": 60, # Small terminals
- "md": 100, # Medium terminals
- "lg": 140, # Large terminals
- "xl": 180, # Extra large terminals
+ "xs": 40, # Extra small terminals
+ "sm": 60, # Small terminals
+ "md": 100, # Medium terminals
+ "lg": 140, # Large terminals
+ "xl": 180, # Extra large terminals
}
@@ -7027,744 +8280,772 @@ def apply_responsive_theme() -> str:
return f"{custom_properties}\n{base_css}\n{responsive_css}"
-
-"""CLI interface for ingestion pipeline."""
-
-import asyncio
-from typing import Annotated
-
-import typer
-from pydantic import SecretStr
-from rich.console import Console
-from rich.panel import Panel
-from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
-from rich.table import Table
-
-from ..config import configure_prefect, get_settings
-from ..core.models import (
- IngestionResult,
- IngestionSource,
- StorageBackend,
- StorageConfig,
-)
-from ..flows.ingestion import create_ingestion_flow
-from ..flows.scheduler import create_scheduled_deployment, serve_deployments
-
-app = typer.Typer(
- name="ingest",
- help="🚀 Modern Document Ingestion Pipeline - Advanced web and repository processing",
- rich_markup_mode="rich",
- add_completion=False,
-)
-console = Console()
-
-
-@app.callback()
-def main(
- version: Annotated[
- bool, typer.Option("--version", "-v", help="Show version information")
- ] = False,
-) -> None:
- """
- 🚀 Modern Document Ingestion Pipeline
-
- [bold cyan]Advanced document processing and management platform[/bold cyan]
-
- Features:
- • 🌐 Web scraping and crawling with Firecrawl
- • 📦 Repository ingestion with Repomix
- • 🗄️ Multiple storage backends (Weaviate, OpenWebUI, R2R)
- • 📊 Modern TUI for collection management
- • ⚡ Async processing with Prefect orchestration
- • 🎨 Rich CLI with enhanced visuals
- """
- settings = get_settings()
- configure_prefect(settings)
-
- if version:
- console.print(
- Panel(
- (
- "[bold magenta]Ingest Pipeline v0.1.0[/bold magenta]\n"
- "[dim]Modern Document Ingestion & Management System[/dim]"
- ),
- title="🚀 Version Info",
- border_style="magenta",
- )
- )
- raise typer.Exit()
-
-
-@app.command()
-def ingest(
- source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")],
- source_type: Annotated[
- IngestionSource, typer.Option("--type", "-t", help="Type of source")
- ] = IngestionSource.WEB,
- storage: Annotated[
- StorageBackend, typer.Option("--storage", "-s", help="Storage backend")
- ] = StorageBackend.WEAVIATE,
- collection: Annotated[
- str | None,
- typer.Option(
- "--collection", "-c", help="Target collection name (auto-generated if not specified)"
- ),
- ] = None,
- validate: Annotated[
- bool, typer.Option("--validate/--no-validate", help="Validate source before ingesting")
- ] = True,
-) -> None:
- """
- 🚀 Run a one-time ingestion job with enhanced progress tracking.
-
- This command processes documents from various sources and stores them in
- your chosen backend with full progress visualization.
- """
- # Enhanced startup message
- console.print(
- Panel(
- (
- f"[bold cyan]🚀 Starting Modern Ingestion[/bold cyan]\n\n"
- f"[yellow]Source:[/yellow] {source_url}\n"
- f"[yellow]Type:[/yellow] {source_type.value.title()}\n"
- f"[yellow]Storage:[/yellow] {storage.value.replace('_', ' ').title()}\n"
- f"[yellow]Collection:[/yellow] {collection or '[dim]Auto-generated[/dim]'}"
- ),
- title="🔥 Ingestion Configuration",
- border_style="cyan",
- )
- )
-
- async def run_with_progress() -> IngestionResult:
- with Progress(
- SpinnerColumn(),
- TextColumn("[progress.description]{task.description}"),
- BarColumn(),
- TaskProgressColumn(),
- console=console,
- ) as progress:
- task = progress.add_task("🔄 Processing documents...", total=100)
-
- # Simulate progress updates during ingestion
- progress.update(task, advance=20, description="🔗 Connecting to services...")
- await asyncio.sleep(0.5)
-
- progress.update(task, advance=30, description="📄 Fetching documents...")
- result = await run_ingestion(
- url=source_url,
- source_type=source_type,
- storage_backend=storage,
- collection_name=collection,
- validate_first=validate,
- )
-
- progress.update(task, advance=50, description="✅ Ingestion complete!")
- return result
-
- # Use asyncio.run() with proper event loop handling
- try:
- result = asyncio.run(run_with_progress())
- except RuntimeError as e:
- if "asyncio.run() cannot be called from a running event loop" in str(e):
- # If we're already in an event loop (e.g., in Jupyter), use nest_asyncio
- try:
- import nest_asyncio
- nest_asyncio.apply()
- result = asyncio.run(run_with_progress())
- except ImportError:
- # Fallback: get the current loop and run the coroutine
- loop = asyncio.get_event_loop()
- result = loop.run_until_complete(run_with_progress())
- else:
- raise
-
- # Enhanced results display
- status_color = "green" if result.status.value == "completed" else "red"
-
- # Create results table with enhanced styling
- table = Table(
- title="📊 Ingestion Results",
- title_style="bold magenta",
- border_style="cyan",
- header_style="bold blue",
- )
- table.add_column("📋 Metric", style="cyan", no_wrap=True)
- table.add_column("📈 Value", style=status_color, justify="right")
-
- # Add enhanced status icon
- status_icon = "✅" if result.status.value == "completed" else "❌"
- table.add_row("Status", f"{status_icon} {result.status.value.title()}")
-
- table.add_row("Documents Processed", f"📄 {result.documents_processed:,}")
- table.add_row("Documents Failed", f"⚠️ {result.documents_failed:,}")
- table.add_row("Duration", f"⏱️ {result.duration_seconds:.2f}s")
-
- if result.error_messages:
- error_text = "\n".join(f"❌ {error}" for error in result.error_messages[:3])
- if len(result.error_messages) > 3:
- error_text += f"\n... and {len(result.error_messages) - 3} more errors"
- table.add_row("Errors", error_text)
-
- console.print(table)
-
- # Success celebration or error guidance
- if result.status.value == "completed" and result.documents_processed > 0:
- console.print(
- Panel(
- (
- f"🎉 [bold green]Success![/bold green] {result.documents_processed} documents ingested\n\n"
- f"💡 [dim]Try '[bold cyan]ingest modern[/bold cyan]' to explore your collections![/dim]"
- ),
- title="✨ Ingestion Complete",
- border_style="green",
- )
- )
- elif result.error_messages:
- console.print(
- Panel(
- (
- "❌ [bold red]Ingestion encountered errors[/bold red]\n\n"
- "💡 [dim]Check your configuration and try again[/dim]"
- ),
- title="⚠️ Issues Detected",
- border_style="red",
- )
- )
-
-
-@app.command()
-def schedule(
- name: Annotated[str, typer.Argument(help="Deployment name")],
- source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")],
- source_type: Annotated[
- IngestionSource, typer.Option("--type", "-t", help="Type of source")
- ] = IngestionSource.WEB,
- storage: Annotated[
- StorageBackend, typer.Option("--storage", "-s", help="Storage backend")
- ] = StorageBackend.WEAVIATE,
- cron: Annotated[
- str | None, typer.Option("--cron", "-c", help="Cron expression for scheduling")
- ] = None,
- interval: Annotated[int, typer.Option("--interval", "-i", help="Interval in minutes")] = 60,
- serve_now: Annotated[
- bool, typer.Option("--serve/--no-serve", help="Start serving immediately")
- ] = False,
-) -> None:
- """
- Create a scheduled deployment for recurring ingestion.
- """
- console.print(f"[bold blue]Creating deployment: {name}[/bold blue]")
-
- deployment = create_scheduled_deployment(
- name=name,
- source_url=source_url,
- source_type=source_type,
- storage_backend=storage,
- schedule_type="cron" if cron else "interval",
- cron_expression=cron,
- interval_minutes=interval,
- )
-
- console.print(f"[green]✓ Deployment '{name}' created[/green]")
-
- if serve_now:
- console.print("[yellow]Starting deployment server...[/yellow]")
- serve_deployments([deployment])
-
-
-@app.command()
-def serve(
- config_file: Annotated[
- str | None, typer.Option("--config", "-c", help="Path to deployments config file")
- ] = None,
- ui: Annotated[
- str | None, typer.Option("--ui", help="Launch user interface (options: tui, web)")
- ] = None,
-) -> None:
- """
- 🚀 Serve configured deployments with optional UI interface.
-
- Launch the deployment server to run scheduled ingestion jobs,
- optionally with a modern Terminal User Interface (TUI) or web interface.
- """
- # Handle UI mode first
- if ui == "tui":
- console.print(
- Panel(
- (
- "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n"
- "[yellow]Features:[/yellow]\n"
- "• 📊 Interactive collection management\n"
- "• ⌨️ Enhanced keyboard navigation\n"
- "• 🎨 Modern design with focus indicators\n"
- "• 📄 Document browsing and search\n"
- "• 🔄 Real-time status updates"
- ),
- title="🎉 TUI Mode",
- border_style="cyan",
- )
- )
- from .tui import dashboard
-
- dashboard()
- return
- elif ui == "web":
- console.print("[red]Web UI not yet implemented. Use --ui tui for Terminal UI.[/red]")
- return
- elif ui:
- console.print(f"[red]Unknown UI option: {ui}[/red]")
- console.print("[yellow]Available options: tui, web[/yellow]")
- return
-
- # Normal deployment server mode
- if config_file:
- # Load deployments from config
- console.print(f"[yellow]Loading deployments from {config_file}[/yellow]")
- # Implementation would load YAML/JSON config
- else:
- # Create example deployments
- deployments = [
- create_scheduled_deployment(
- name="docs-daily",
- source_url="https://docs.example.com",
- source_type="documentation",
- storage_backend="weaviate",
- schedule_type="cron",
- cron_expression="0 2 * * *", # Daily at 2 AM
- ),
- create_scheduled_deployment(
- name="repo-hourly",
- source_url="https://github.com/example/repo",
- source_type="repository",
- storage_backend="open_webui",
- schedule_type="interval",
- interval_minutes=60,
- ),
- ]
-
- console.print(
- "[bold green]Starting deployment server with example deployments[/bold green]"
- )
- serve_deployments(deployments)
-
-
-@app.command()
-def tui() -> None:
- """
- 🚀 Launch the enhanced Terminal User Interface.
-
- Quick shortcut for 'serve --ui tui' with modern keyboard navigation,
- interactive collection management, and real-time status updates.
- """
- console.print(
- Panel(
- (
- "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n"
- "[yellow]Features:[/yellow]\n"
- "• 📊 Interactive collection management\n"
- "• ⌨️ Enhanced keyboard navigation\n"
- "• 🎨 Modern design with focus indicators\n"
- "• 📄 Document browsing and search\n"
- "• 🔄 Real-time status updates"
- ),
- title="🎉 TUI Mode",
- border_style="cyan",
- )
- )
- from .tui import dashboard
-
- dashboard()
-
-
-@app.command()
-def config() -> None:
- """
- 📋 Display current configuration with enhanced formatting.
-
- Shows all configured endpoints, models, and settings in a beautiful
- table format with status indicators.
- """
- settings = get_settings()
-
- console.print(
- Panel(
- (
- "[bold cyan]⚙️ System Configuration[/bold cyan]\n"
- "[dim]Current pipeline settings and endpoints[/dim]"
- ),
- title="🔧 Configuration",
- border_style="cyan",
- )
- )
-
- # Enhanced configuration table
- table = Table(
- title="📊 Configuration Details",
- title_style="bold magenta",
- border_style="blue",
- header_style="bold cyan",
- show_lines=True,
- )
- table.add_column("🏷️ Setting", style="cyan", no_wrap=True, width=25)
- table.add_column("🎯 Value", style="yellow", overflow="fold")
- table.add_column("📊 Status", style="green", width=12, justify="center")
-
- # Add configuration rows with status indicators
- def get_status_indicator(value: str | None) -> str:
- return "✅ Set" if value else "❌ Missing"
-
- table.add_row("🤖 LLM Endpoint", str(settings.llm_endpoint), "✅ Active")
- table.add_row("🔥 Firecrawl Endpoint", str(settings.firecrawl_endpoint), "✅ Active")
- table.add_row(
- "🗄️ Weaviate Endpoint",
- str(settings.weaviate_endpoint),
- get_status_indicator(str(settings.weaviate_api_key) if settings.weaviate_api_key else None),
- )
- table.add_row(
- "🌐 OpenWebUI Endpoint",
- str(settings.openwebui_endpoint),
- get_status_indicator(settings.openwebui_api_key),
- )
- table.add_row("🧠 Embedding Model", settings.embedding_model, "✅ Set")
- table.add_row("💾 Default Storage", settings.default_storage_backend.title(), "✅ Set")
- table.add_row("📦 Default Batch Size", f"{settings.default_batch_size:,}", "✅ Set")
- table.add_row("⚡ Max Concurrent Tasks", f"{settings.max_concurrent_tasks}", "✅ Set")
-
- console.print(table)
-
- # Additional helpful information
- console.print(
- Panel(
- (
- "💡 [bold cyan]Quick Tips[/bold cyan]\n\n"
- "• Use '[bold]ingest list-collections[/bold]' to view all collections\n"
- "• Use '[bold]ingest search[/bold]' to search content\n"
- "• Configure API keys in your [yellow].env[/yellow] file\n"
- "• Default collection names are auto-generated from URLs"
- ),
- title="🚀 Usage Tips",
- border_style="green",
- )
- )
-
-
-@app.command()
-def list_collections() -> None:
- """
- 📋 List all collections across storage backends.
- """
- console.print("[bold cyan]📚 Collection Overview[/bold cyan]")
- asyncio.run(run_list_collections())
-
-
-@app.command()
-def search(
- query: Annotated[str, typer.Argument(help="Search query")],
- collection: Annotated[
- str | None, typer.Option("--collection", "-c", help="Target collection")
- ] = None,
- backend: Annotated[
- StorageBackend, typer.Option("--backend", "-b", help="Storage backend")
- ] = StorageBackend.WEAVIATE,
- limit: Annotated[int, typer.Option("--limit", "-l", help="Result limit")] = 10,
-) -> None:
- """
- 🔍 Search across collections.
- """
- console.print(f"[bold cyan]🔍 Searching for: {query}[/bold cyan]")
- asyncio.run(run_search(query, collection, backend.value, limit))
-
-
-@app.command(name="blocks")
-def blocks_command() -> None:
- """🧩 List and manage Prefect Blocks."""
- console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
- console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks")
- console.print("Use 'prefect block ls' to list available blocks")
-
-
-@app.command(name="variables")
-def variables_command() -> None:
- """📊 Manage Prefect Variables."""
- console.print("[bold cyan]📊 Prefect Variables Management[/bold cyan]")
- console.print("Use 'prefect variable set VARIABLE_NAME value' to set variables")
- console.print("Use 'prefect variable ls' to list variables")
-
-
-async def run_ingestion(
- url: str,
- source_type: IngestionSource,
- storage_backend: StorageBackend,
- collection_name: str | None = None,
- validate_first: bool = True,
-) -> IngestionResult:
- """
- Run ingestion with support for targeted collections.
- """
- # Auto-generate collection name if not provided
- if not collection_name:
- from urllib.parse import urlparse
-
- parsed = urlparse(url)
- domain = parsed.netloc.replace(".", "_").replace("-", "_")
- collection_name = f"{domain}_{source_type.value}"
-
- return await create_ingestion_flow(
- source_url=url,
- source_type=source_type,
- storage_backend=storage_backend,
- collection_name=collection_name,
- validate_first=validate_first,
- )
-
-
-async def run_list_collections() -> None:
- """
- List collections across storage backends.
- """
- from ..config import get_settings
- from ..core.models import StorageBackend
- from ..storage.openwebui import OpenWebUIStorage
- from ..storage.weaviate import WeaviateStorage
-
- settings = get_settings()
-
- console.print("🔍 [bold cyan]Scanning storage backends...[/bold cyan]")
-
- # Try to connect to Weaviate
- weaviate_collections: list[tuple[str, int]] = []
- try:
- weaviate_config = StorageConfig(
- backend=StorageBackend.WEAVIATE,
- endpoint=settings.weaviate_endpoint,
- api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
- collection_name="default",
- )
- weaviate = WeaviateStorage(weaviate_config)
- await weaviate.initialize()
-
- overview = await weaviate.describe_collections()
- for item in overview:
- name = str(item.get("name", "Unknown"))
- count_val = item.get("count", 0)
- count = int(count_val) if isinstance(count_val, (int, str)) else 0
- weaviate_collections.append((name, count))
- except Exception as e:
- console.print(f"❌ [red]Weaviate connection failed: {e}[/red]")
-
- # Try to connect to OpenWebUI
- openwebui_collections: list[tuple[str, int]] = []
- try:
- openwebui_config = StorageConfig(
- backend=StorageBackend.OPEN_WEBUI,
- endpoint=settings.openwebui_endpoint,
- api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None,
- collection_name="default",
- )
- openwebui = OpenWebUIStorage(openwebui_config)
- await openwebui.initialize()
-
- overview = await openwebui.describe_collections()
- for item in overview:
- name = str(item.get("name", "Unknown"))
- count_val = item.get("count", 0)
- count = int(count_val) if isinstance(count_val, (int, str)) else 0
- openwebui_collections.append((name, count))
- except Exception as e:
- console.print(f"❌ [red]OpenWebUI connection failed: {e}[/red]")
-
- # Display results
- if weaviate_collections or openwebui_collections:
- # Create results table
- from rich.table import Table
-
- table = Table(
- title="📚 Collection Overview",
- title_style="bold magenta",
- border_style="cyan",
- header_style="bold blue",
- )
- table.add_column("🏷️ Collection", style="cyan", no_wrap=True)
- table.add_column("📊 Backend", style="yellow")
- table.add_column("📄 Documents", style="green", justify="right")
-
- # Add Weaviate collections
- for name, count in weaviate_collections:
- table.add_row(name, "🗄️ Weaviate", f"{count:,}")
-
- # Add OpenWebUI collections
- for name, count in openwebui_collections:
- table.add_row(name, "🌐 OpenWebUI", f"{count:,}")
-
- console.print(table)
- else:
- console.print("❌ [yellow]No collections found in any backend[/yellow]")
-
-
-async def run_search(query: str, collection: str | None, backend: str, limit: int) -> None:
- """
- Search across collections.
- """
- from ..config import get_settings
- from ..core.models import StorageBackend
- from ..storage.weaviate import WeaviateStorage
-
- settings = get_settings()
-
- console.print(f"🔍 Searching for: '[bold cyan]{query}[/bold cyan]'")
- if collection:
- console.print(f"📚 Target collection: [yellow]{collection}[/yellow]")
- console.print(f"💾 Backend: [blue]{backend}[/blue]")
-
- results = []
-
- try:
- if backend == "weaviate":
- weaviate_config = StorageConfig(
- backend=StorageBackend.WEAVIATE,
- endpoint=settings.weaviate_endpoint,
- api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
- collection_name=collection or "default",
- )
- weaviate = WeaviateStorage(weaviate_config)
- await weaviate.initialize()
-
- results_generator = weaviate.search(query, limit=limit)
- async for doc in results_generator:
- results.append(
- {
- "title": getattr(doc, "title", "Untitled"),
- "content": getattr(doc, "content", ""),
- "score": getattr(doc, "score", 0.0),
- "backend": "🗄️ Weaviate",
- }
- )
-
- elif backend == "open_webui":
- console.print("❌ [red]OpenWebUI search not yet implemented[/red]")
- return
-
- except Exception as e:
- console.print(f"❌ [red]Search failed: {e}[/red]")
- return
-
- # Display results
- if results:
- from rich.table import Table
-
- table = Table(
- title=f"🔍 Search Results for '{query}'",
- title_style="bold magenta",
- border_style="green",
- header_style="bold blue",
- )
- table.add_column("📄 Title", style="cyan", max_width=40)
- table.add_column("📝 Preview", style="white", max_width=60)
- table.add_column("📊 Score", style="yellow", justify="right")
-
- for result in results[:limit]:
- title = str(result["title"])
- title_display = title[:40] + "..." if len(title) > 40 else title
-
- content = str(result["content"])
- content_display = content[:60] + "..." if len(content) > 60 else content
-
- score = f"{result['score']:.3f}"
-
- table.add_row(title_display, content_display, score)
-
- console.print(table)
- console.print(f"\n✅ [green]Found {len(results)} results[/green]")
- else:
- console.print("❌ [yellow]No results found[/yellow]")
-
-
-if __name__ == "__main__":
- app()
-
-
-
-"""Configuration management utilities."""
+
+"""Prefect flow for ingestion pipeline."""
from __future__ import annotations
-from contextlib import ExitStack
+from collections.abc import Callable
+from datetime import UTC, datetime
+from typing import TYPE_CHECKING, Literal, TypeAlias, assert_never, cast
-from prefect.settings import Setting, temporary_settings
+from prefect import flow, get_run_logger, task
+from prefect.blocks.core import Block
+from prefect.variables import Variable
+from pydantic import SecretStr
+
+from ..config.settings import Settings
+from ..core.exceptions import IngestionError
+from ..core.models import (
+ Document,
+ FirecrawlConfig,
+ IngestionJob,
+ IngestionResult,
+ IngestionSource,
+ IngestionStatus,
+ RepomixConfig,
+ StorageBackend,
+ StorageConfig,
+)
+from ..ingestors import BaseIngestor, FirecrawlIngestor, FirecrawlPage, RepomixIngestor
+from ..storage import OpenWebUIStorage, WeaviateStorage
+from ..storage import R2RStorage as RuntimeR2RStorage
+from ..storage.base import BaseStorage
+from ..utils.metadata_tagger import MetadataTagger
+
+SourceTypeLiteral = Literal["web", "repository", "documentation"]
+StorageBackendLiteral = Literal["weaviate", "open_webui", "r2r"]
+SourceTypeLike: TypeAlias = IngestionSource | SourceTypeLiteral
+StorageBackendLike: TypeAlias = StorageBackend | StorageBackendLiteral
-# Import Prefect settings with version compatibility - avoid static analysis issues
-def _setup_prefect_settings() -> tuple[object, object, object]:
- """Setup Prefect settings with proper fallbacks."""
+def _safe_cache_key(prefix: str, params: dict[str, object], key: str) -> str:
+    """Create a stable, type-safe cache key from task parameters."""
+    import hashlib
+
+    # hash() is randomized per process; use a deterministic digest so cache keys match across runs.
+    value = params.get(key, "")
+    return f"{prefix}_{hashlib.sha256(str(value).encode('utf-8')).hexdigest()[:16]}"
+
+
+if TYPE_CHECKING:
+ from ..storage.r2r.storage import R2RStorage as R2RStorageType
+else:
+ R2RStorageType = BaseStorage
+
+
+@task(name="validate_source", retries=2, retry_delay_seconds=10, tags=["validation"])
+async def validate_source_task(source_url: str, source_type: IngestionSource) -> bool:
+ """
+ Validate that a source is accessible.
+
+ Args:
+ source_url: URL or path to source
+ source_type: Type of source
+
+ Returns:
+ True if valid
+ """
+ if source_type == IngestionSource.WEB:
+ ingestor = FirecrawlIngestor()
+ elif source_type == IngestionSource.REPOSITORY:
+ ingestor = RepomixIngestor()
+ else:
+ raise ValueError(f"Unsupported source type: {source_type}")
+
+ result = await ingestor.validate_source(source_url)
+ return bool(result)
+
+
+@task(name="initialize_storage", retries=3, retry_delay_seconds=5, tags=["storage"])
+async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
+ """
+ Initialize storage backend.
+
+ Args:
+ config: Storage configuration block or block name
+
+ Returns:
+ Initialized storage adapter
+ """
+ # Load block if string provided
+ if isinstance(config, str):
+ # Use Block.aload with type slug for better type inference
+ loaded_block = await Block.aload(f"storage-config/{config}")
+ config = cast(StorageConfig, loaded_block)
+
+ if config.backend == StorageBackend.WEAVIATE:
+ storage = WeaviateStorage(config)
+ elif config.backend == StorageBackend.OPEN_WEBUI:
+ storage = OpenWebUIStorage(config)
+ elif config.backend == StorageBackend.R2R:
+ if RuntimeR2RStorage is None:
+ raise ValueError("R2R storage not available. Check dependencies.")
+ storage = RuntimeR2RStorage(config)
+ else:
+ raise ValueError(f"Unsupported backend: {config.backend}")
+
+ await storage.initialize()
+ return storage
+
+
+@task(
+ name="map_firecrawl_site",
+ retries=2,
+ retry_delay_seconds=15,
+ tags=["firecrawl", "map"],
+ cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"),
+)
+async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
+ """Map a site using Firecrawl and return discovered URLs."""
+ # Load block if string provided
+ if isinstance(config, str):
+ # Use Block.aload with type slug for better type inference
+ loaded_block = await Block.aload(f"firecrawl-config/{config}")
+ config = cast(FirecrawlConfig, loaded_block)
+
+ ingestor = FirecrawlIngestor(config)
+ mapped = await ingestor.map_site(source_url)
+ return mapped or [source_url]
+
+
+@task(
+ name="filter_existing_documents",
+ retries=1,
+ retry_delay_seconds=5,
+ tags=["dedup"],
+ cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls"),
+) # Cache based on URL list
+async def filter_existing_documents_task(
+ urls: list[str],
+ storage_client: BaseStorage,
+ stale_after_days: int = 30,
+ *,
+ collection_name: str | None = None,
+) -> list[str]:
+ """Filter URLs to only those that need scraping (missing or stale in storage)."""
+ import asyncio
+
+ logger = get_run_logger()
+
+ # Use semaphore to limit concurrent existence checks
+ semaphore = asyncio.Semaphore(20)
+
+ async def check_url_exists(url: str) -> tuple[str, bool]:
+ async with semaphore:
+ try:
+ document_id = str(FirecrawlIngestor.compute_document_id(url))
+ exists = await storage_client.check_exists(
+ document_id, collection_name=collection_name, stale_after_days=stale_after_days
+ )
+ return url, exists
+ except Exception as e:
+ logger.warning("Error checking existence for URL %s: %s", url, e)
+                # Assume the document does not exist on error so the URL still gets scraped
+ return url, False
+
+ # Check all URLs in parallel - use return_exceptions=True for partial failure handling
+ results = await asyncio.gather(*[check_url_exists(url) for url in urls], return_exceptions=True)
+
+ # Collect URLs that need scraping, handling any exceptions
+ eligible = []
+ for result in results:
+ if isinstance(result, Exception):
+ logger.error("Unexpected error in parallel existence check: %s", result)
+ continue
+ # Type narrowing: result is now known to be tuple[str, bool]
+ if isinstance(result, tuple) and len(result) == 2:
+ url, exists = result
+ if not exists:
+ eligible.append(url)
+
+ skipped = len(urls) - len(eligible)
+ if skipped > 0:
+ logger.info("Skipping %s up-to-date documents in %s", skipped, storage_client.display_name)
+
+ return eligible
+
+
+@task(
+ name="scrape_firecrawl_batch", retries=2, retry_delay_seconds=20, tags=["firecrawl", "scrape"]
+)
+async def scrape_firecrawl_batch_task(
+ batch_urls: list[str], config: FirecrawlConfig
+) -> list[FirecrawlPage]:
+ """Scrape a batch of URLs via Firecrawl."""
+ ingestor = FirecrawlIngestor(config)
+ result: list[FirecrawlPage] = await ingestor.scrape_pages(batch_urls)
+ return result
+
+
+@task(name="annotate_firecrawl_metadata", retries=1, retry_delay_seconds=10, tags=["metadata"])
+async def annotate_firecrawl_metadata_task(
+ pages: list[FirecrawlPage], job: IngestionJob
+) -> list[Document]:
+ """Annotate scraped pages with standardized metadata."""
+ if not pages:
+ return []
+
+ ingestor = FirecrawlIngestor()
+ documents = [ingestor.create_document(page, job) for page in pages]
+
try:
- import prefect.settings as ps
+ from ..config import get_settings
- # Try to get the settings directly
- api_key = getattr(ps, "PREFECT_API_KEY", None)
- api_url = getattr(ps, "PREFECT_API_URL", None)
- work_pool = getattr(ps, "PREFECT_DEFAULT_WORK_POOL_NAME", None)
-
- if api_key is not None:
- return api_key, api_url, work_pool
-
- # Fallback to registry-based approach
- registry = getattr(ps, "PREFECT_SETTING_REGISTRY", None)
- if registry is not None:
- Setting = getattr(ps, "Setting", None)
- if Setting is not None:
- api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None)
- api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None)
- work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None)
- return api_key, api_url, work_pool
-
- except ImportError:
- pass
-
- # Ultimate fallback
- return None, None, None
-
-PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()
-
-# Import after Prefect settings setup to avoid circular dependencies
-from .settings import Settings, get_settings # noqa: E402
-
-__all__ = ["Settings", "get_settings", "configure_prefect"]
-
-_prefect_settings_stack: ExitStack | None = None
+ settings = get_settings()
+ async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
+ tagged_documents: list[Document] = await tagger.tag_batch(documents)
+ return tagged_documents
+ except IngestionError as exc: # pragma: no cover - logging side effect
+ logger = get_run_logger()
+ logger.warning("Metadata tagging failed: %s", exc)
+ return documents
+ except Exception as exc: # pragma: no cover - defensive
+ logger = get_run_logger()
+ logger.warning("Metadata tagging unavailable, using base metadata: %s", exc)
+ return documents
-def configure_prefect(settings: Settings) -> None:
- """Apply Prefect settings from the application configuration."""
- global _prefect_settings_stack
+@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"])
+async def upsert_r2r_documents_task(
+ storage_client: R2RStorageType,
+ documents: list[Document],
+ collection_name: str | None,
+) -> tuple[int, int]:
+ """Upsert documents into R2R storage."""
+ if not documents:
+ return 0, 0
- overrides: dict[Setting, str] = {}
+ stored_ids: list[str] = await storage_client.store_batch(
+ documents, collection_name=collection_name
+ )
+ processed = len(stored_ids)
+ failed = len(documents) - processed
- if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting):
- overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
- if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting):
- overrides[PREFECT_API_KEY] = settings.prefect_api_key
- if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting):
- overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool
+ if failed:
+ logger = get_run_logger()
+ logger.warning("Failed to upsert %s documents to R2R", failed)
- if not overrides:
- return
+ return processed, failed
- filtered_overrides = {
- setting: value
- for setting, value in overrides.items()
- if setting.value() != value
+
+@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"])
+async def ingest_documents_task(
+ job: IngestionJob,
+ collection_name: str | None = None,
+ batch_size: int | None = None,
+ storage_client: BaseStorage | None = None,
+ storage_block_name: str | None = None,
+ ingestor_config_block_name: str | None = None,
+ progress_callback: Callable[[int, str], None] | None = None,
+) -> tuple[int, int]:
+ """
+ Ingest documents from source with optional pre-initialized storage client.
+
+ Args:
+ job: Ingestion job configuration
+ collection_name: Target collection name
+ batch_size: Number of documents per batch (uses Variable if None)
+ storage_client: Optional pre-initialized storage client
+ storage_block_name: Optional storage block name to load
+ ingestor_config_block_name: Optional ingestor config block name to load
+ progress_callback: Optional callback for progress updates
+
+ Returns:
+ Tuple of (processed_count, failed_count)
+ """
+ if progress_callback:
+ progress_callback(35, "Creating ingestor and storage clients...")
+
+ # Use Variable for batch size if not provided
+ if batch_size is None:
+ try:
+ batch_size_var = await Variable.aget("default_batch_size", default="50")
+ # Convert Variable result to int, handling various types
+ if isinstance(batch_size_var, int):
+ batch_size = batch_size_var
+ elif isinstance(batch_size_var, (str, float)):
+ batch_size = int(float(str(batch_size_var)))
+ else:
+ batch_size = 50
+ except Exception:
+ batch_size = 50
+
+ ingestor = await _create_ingestor(job, ingestor_config_block_name)
+ storage = storage_client or await _create_storage(job, collection_name, storage_block_name)
+
+ if progress_callback:
+ progress_callback(40, "Starting document processing...")
+
+ return await _process_documents(
+ ingestor, storage, job, batch_size, collection_name, progress_callback
+ )
+
+
+async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
+ """Create appropriate ingestor based on job source type."""
+ if job.source_type == IngestionSource.WEB:
+ if config_block_name:
+ # Use Block.aload with type slug for better type inference
+ loaded_block = await Block.aload(f"firecrawl-config/{config_block_name}")
+ config = cast(FirecrawlConfig, loaded_block)
+ else:
+ # Fallback to default configuration
+ config = FirecrawlConfig()
+ return FirecrawlIngestor(config)
+ elif job.source_type == IngestionSource.REPOSITORY:
+ if config_block_name:
+ # Use Block.aload with type slug for better type inference
+ loaded_block = await Block.aload(f"repomix-config/{config_block_name}")
+ config = cast(RepomixConfig, loaded_block)
+ else:
+ # Fallback to default configuration
+ config = RepomixConfig()
+ return RepomixIngestor(config)
+ else:
+ raise ValueError(f"Unsupported source: {job.source_type}")
+
+
+async def _create_storage(
+ job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None
+) -> BaseStorage:
+ """Create and initialize storage client."""
+ if collection_name is None:
+ # Use variable for default collection prefix
+ prefix = await Variable.aget("default_collection_prefix", default="docs")
+ collection_name = f"{prefix}_{job.source_type.value}"
+
+ if storage_block_name:
+ # Load storage config from block
+ loaded_block = await Block.aload(f"storage-config/{storage_block_name}")
+ storage_config = cast(StorageConfig, loaded_block)
+ # Override collection name if provided
+ storage_config.collection_name = collection_name
+ else:
+ # Fallback to building config from settings
+ from ..config import get_settings
+
+ settings = get_settings()
+ storage_config = _build_storage_config(job, settings, collection_name)
+
+ storage = _instantiate_storage(job.storage_backend, storage_config)
+ await storage.initialize()
+ return storage
+
+
+def _build_storage_config(
+ job: IngestionJob, settings: Settings, collection_name: str
+) -> StorageConfig:
+ """Build storage configuration from job and settings."""
+ storage_endpoints = {
+ StorageBackend.WEAVIATE: settings.weaviate_endpoint,
+ StorageBackend.OPEN_WEBUI: settings.openwebui_endpoint,
+ StorageBackend.R2R: settings.get_storage_endpoint("r2r"),
+ }
+ storage_api_keys: dict[StorageBackend, str | None] = {
+ StorageBackend.WEAVIATE: settings.get_api_key("weaviate"),
+ StorageBackend.OPEN_WEBUI: settings.get_api_key("openwebui"),
+ StorageBackend.R2R: None, # R2R is self-hosted, no API key needed
}
- if not filtered_overrides:
- return
+ api_key_raw: str | None = storage_api_keys[job.storage_backend]
+ api_key: SecretStr | None = SecretStr(api_key_raw) if api_key_raw is not None else None
- new_stack = ExitStack()
- new_stack.enter_context(temporary_settings(updates=filtered_overrides))
+ return StorageConfig(
+ backend=job.storage_backend,
+ endpoint=storage_endpoints[job.storage_backend],
+ api_key=api_key,
+ collection_name=collection_name,
+ )
- if _prefect_settings_stack is not None:
- _prefect_settings_stack.close()
- _prefect_settings_stack = new_stack
+def _instantiate_storage(backend: StorageBackend, config: StorageConfig) -> BaseStorage:
+ """Instantiate storage based on backend type."""
+ if backend == StorageBackend.WEAVIATE:
+ return WeaviateStorage(config)
+ elif backend == StorageBackend.OPEN_WEBUI:
+ return OpenWebUIStorage(config)
+ elif backend == StorageBackend.R2R:
+ if RuntimeR2RStorage is None:
+ raise ValueError("R2R storage not available. Check dependencies.")
+ return RuntimeR2RStorage(config)
+
+ assert_never(backend)
+
+
+def _chunk_urls(urls: list[str], chunk_size: int) -> list[list[str]]:
+ """Group URLs into fixed-size chunks for batch processing."""
+
+ if chunk_size <= 0:
+ raise ValueError("chunk_size must be greater than zero")
+
+ return [urls[i : i + chunk_size] for i in range(0, len(urls), chunk_size)]
+
+
+def _deduplicate_urls(urls: list[str]) -> list[str]:
+ """Return the URLs with order preserved and duplicates removed."""
+
+ seen: set[str] = set()
+ unique: list[str] = []
+ for url in urls:
+ if url not in seen:
+ seen.add(url)
+ unique.append(url)
+ return unique
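+
+
+# Hedged, illustrative sketch (not part of the pipeline): composing the two helpers
+# above to prepare scraping batches; the URLs below are placeholders.
+def _example_prepare_batches() -> list[list[str]]:  # pragma: no cover - illustrative only
+    urls = _deduplicate_urls(
+        ["https://a.example/x", "https://a.example/x", "https://a.example/y"]
+    )
+    # One duplicate removed, leaving two unique URLs in a single chunk of size 2.
+    return _chunk_urls(urls, chunk_size=2)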
+
+
+async def _process_documents(
+ ingestor: BaseIngestor,
+ storage: BaseStorage,
+ job: IngestionJob,
+ batch_size: int,
+ collection_name: str | None,
+ progress_callback: Callable[[int, str], None] | None = None,
+) -> tuple[int, int]:
+ """Process documents in batches."""
+ processed = 0
+ failed = 0
+ batch: list[Document] = []
+ total_documents = 0
+ batch_count = 0
+
+ if progress_callback:
+ progress_callback(45, "Ingesting documents from source...")
+
+ # Use smart ingestion with deduplication if storage supports it
+ if hasattr(storage, "check_exists"):
+ try:
+ # Try to use the smart ingestion method
+ document_generator = ingestor.ingest_with_dedup(
+ job, storage, collection_name=collection_name
+ )
+ except Exception:
+            # Fall back to regular ingestion if the dedup-aware generator cannot be created;
+            # note that errors raised while iterating the generator are not caught here.
+ document_generator = ingestor.ingest(job)
+ else:
+ document_generator = ingestor.ingest(job)
+
+ async for document in document_generator:
+ batch.append(document)
+ total_documents += 1
+
+ if len(batch) >= batch_size:
+ batch_count += 1
+ if progress_callback:
+ progress_callback(
+ 45 + min(35, (batch_count * 10)),
+ f"Processing batch {batch_count} ({total_documents} documents so far)...",
+ )
+
+ batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
+ processed += batch_processed
+ failed += batch_failed
+ batch = []
+
+ # Process remaining batch
+ if batch:
+ batch_count += 1
+ if progress_callback:
+ progress_callback(80, f"Processing final batch ({total_documents} total documents)...")
+
+ batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
+ processed += batch_processed
+ failed += batch_failed
+
+ if progress_callback:
+ progress_callback(85, f"Completed processing {total_documents} documents")
+
+ return processed, failed
+
+
+async def _store_batch(
+ storage: BaseStorage,
+ batch: list[Document],
+ collection_name: str | None,
+) -> tuple[int, int]:
+ """Store a batch of documents and return processed/failed counts."""
+ try:
+ # Apply metadata tagging for backends that benefit from it
+ processed_batch = batch
+ if hasattr(storage, "config") and storage.config.backend in (
+ StorageBackend.R2R,
+ StorageBackend.WEAVIATE,
+ ):
+ try:
+ from ..config import get_settings
+
+ settings = get_settings()
+ async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
+ processed_batch = await tagger.tag_batch(batch)
+ except Exception as exc:
+ print(f"Metadata tagging failed, using original documents: {exc}")
+ processed_batch = batch
+
+ stored_ids = await storage.store_batch(processed_batch, collection_name=collection_name)
+ processed_count = len(stored_ids)
+ failed_count = len(processed_batch) - processed_count
+
+        print(f"Successfully stored {processed_count} of {len(processed_batch)} documents in batch")
+
+ return processed_count, failed_count
+ except Exception as e:
+        print(f"Batch storage failed for {len(batch)} documents: {e}")
+ return 0, len(batch)
+
+
+@flow(
+ name="firecrawl_to_r2r",
+ description="Ingest Firecrawl pages into R2R with metadata annotation",
+ persist_result=False,
+ log_prints=True,
+)
+async def firecrawl_to_r2r_flow(
+ job: IngestionJob,
+ collection_name: str | None = None,
+ progress_callback: Callable[[int, str], None] | None = None,
+) -> tuple[int, int]:
+ """Specialized flow for Firecrawl ingestion into R2R."""
+ logger = get_run_logger()
+ from ..config import get_settings
+
+ if progress_callback:
+ progress_callback(35, "Initializing Firecrawl and R2R storage...")
+
+ settings = get_settings()
+ firecrawl_config = FirecrawlConfig()
+ resolved_collection = collection_name or f"docs_{job.source_type.value}"
+
+ storage_config = _build_storage_config(job, settings, resolved_collection)
+ storage_client = await initialize_storage_task(storage_config)
+
+ if RuntimeR2RStorage is None or not isinstance(storage_client, RuntimeR2RStorage):
+ raise IngestionError("Firecrawl to R2R flow requires an R2R storage backend")
+
+ r2r_storage = cast("R2RStorageType", storage_client)
+
+ if progress_callback:
+ progress_callback(45, "Checking for existing content before mapping...")
+
+ # Smart mapping: try single URL first to avoid expensive map operation
+ base_url = str(job.source_url)
+ single_url_id = str(FirecrawlIngestor.compute_document_id(base_url))
+ base_exists = await r2r_storage.check_exists(
+ single_url_id, collection_name=resolved_collection, stale_after_days=30
+ )
+
+ if base_exists:
+ # Check if this is a recent single-page update
+ logger.info("Base URL %s exists and is fresh, skipping expensive mapping", base_url)
+ if progress_callback:
+ progress_callback(100, "Content is up to date, no processing needed")
+ return 0, 0
+
+ if progress_callback:
+ progress_callback(50, "Discovering pages with Firecrawl...")
+
+ discovered_urls = await map_firecrawl_site_task(base_url, firecrawl_config)
+ unique_urls = _deduplicate_urls(discovered_urls)
+ logger.info("Discovered %s unique URLs from Firecrawl map", len(unique_urls))
+
+ if progress_callback:
+ progress_callback(60, f"Found {len(unique_urls)} pages, filtering existing content...")
+
+ eligible_urls = await filter_existing_documents_task(
+ unique_urls, r2r_storage, collection_name=resolved_collection
+ )
+
+ if not eligible_urls:
+ logger.info("All Firecrawl pages are up to date for %s", job.source_url)
+ if progress_callback:
+ progress_callback(100, "All pages are up to date, no processing needed")
+ return 0, 0
+
+ if progress_callback:
+ progress_callback(70, f"Scraping {len(eligible_urls)} new/updated pages...")
+
+ batch_size = min(settings.default_batch_size, firecrawl_config.limit)
+ url_batches = _chunk_urls(eligible_urls, batch_size)
+ logger.info("Scraping %s batches of Firecrawl pages", len(url_batches))
+
+ # Use asyncio.gather for concurrent scraping
+ import asyncio
+
+ scrape_tasks = [scrape_firecrawl_batch_task(batch, firecrawl_config) for batch in url_batches]
+ batch_results = await asyncio.gather(*scrape_tasks)
+
+ scraped_pages: list[FirecrawlPage] = []
+ for batch_pages in batch_results:
+ scraped_pages.extend(batch_pages)
+
+ if progress_callback:
+ progress_callback(80, f"Processing {len(scraped_pages)} scraped pages...")
+
+ documents = await annotate_firecrawl_metadata_task(scraped_pages, job)
+
+ if not documents:
+ logger.warning("No documents produced after scraping for %s", job.source_url)
+ return 0, len(eligible_urls)
+
+ if progress_callback:
+ progress_callback(90, f"Storing {len(documents)} documents in R2R...")
+
+ processed, failed = await upsert_r2r_documents_task(r2r_storage, documents, resolved_collection)
+
+ logger.info("Upserted %s documents into R2R (%s failed)", processed, failed)
+
+ return processed, failed
+
+
+@task(name="update_job_status", tags=["tracking"])
+async def update_job_status_task(
+ job: IngestionJob,
+ status: IngestionStatus,
+ processed: int = 0,
+ _failed: int = 0,
+ error: str | None = None,
+) -> IngestionJob:
+ """
+ Update job status.
+
+ Args:
+ job: Ingestion job
+ status: New status
+ processed: Documents processed
+ _failed: Documents failed (currently unused)
+ error: Error message if any
+
+ Returns:
+ Updated job
+ """
+ job.status = status
+ job.updated_at = datetime.now(UTC)
+ job.document_count = processed
+
+ if status == IngestionStatus.COMPLETED:
+ job.completed_at = datetime.now(UTC)
+
+ if error:
+ job.error_message = error
+
+ return job
+
+
+@flow(
+ name="ingestion_pipeline",
+ description="Main ingestion pipeline for documents",
+ retries=1,
+ retry_delay_seconds=60,
+ persist_result=True,
+ log_prints=True,
+)
+async def create_ingestion_flow(
+ source_url: str,
+ source_type: SourceTypeLike,
+ storage_backend: StorageBackendLike = StorageBackend.WEAVIATE,
+ collection_name: str | None = None,
+ validate_first: bool = True,
+ progress_callback: Callable[[int, str], None] | None = None,
+) -> IngestionResult:
+ """
+ Main ingestion flow.
+
+ Args:
+ source_url: URL or path to source
+ source_type: Type of source
+        storage_backend: Storage backend to use
+        collection_name: Target collection name (auto-generated if not provided)
+        validate_first: Whether to validate the source first
+ progress_callback: Optional callback for progress updates
+
+ Returns:
+ Ingestion result
+ """
+ print(f"Starting ingestion from {source_url}")
+
+ source_enum = IngestionSource(source_type)
+ backend_enum = StorageBackend(storage_backend)
+
+ # Create job
+ job = IngestionJob(
+ source_url=source_url,
+ source_type=source_enum,
+ storage_backend=backend_enum,
+ status=IngestionStatus.PENDING,
+ )
+
+ start_time = datetime.now(UTC)
+ error_messages: list[str] = []
+ processed = 0
+ failed = 0
+
+ try:
+ # Validate source if requested
+ if validate_first:
+ if progress_callback:
+ progress_callback(10, "Validating source...")
+ print("Validating source...")
+ is_valid = await validate_source_task(source_url, job.source_type)
+
+ if not is_valid:
+ raise IngestionError(f"Source validation failed: {source_url}")
+
+ # Update status to in progress
+ if progress_callback:
+ progress_callback(20, "Initializing storage...")
+ job = await update_job_status_task(job, IngestionStatus.IN_PROGRESS)
+
+ # Run ingestion
+ if progress_callback:
+ progress_callback(30, "Starting document ingestion...")
+ print("Ingesting documents...")
+ if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R:
+ processed, failed = await firecrawl_to_r2r_flow(
+ job, collection_name, progress_callback=progress_callback
+ )
+ else:
+ processed, failed = await ingest_documents_task(
+ job, collection_name, progress_callback=progress_callback
+ )
+
+ if progress_callback:
+ progress_callback(90, "Finalizing ingestion...")
+
+ # Update final status
+ if failed > 0:
+ error_messages.append(f"{failed} documents failed to process")
+
+ # Set status based on results
+ if processed == 0 and failed > 0:
+ final_status = IngestionStatus.FAILED
+ elif failed > 0:
+ final_status = IngestionStatus.PARTIAL
+ else:
+ final_status = IngestionStatus.COMPLETED
+
+ job = await update_job_status_task(job, final_status, processed=processed, _failed=failed)
+
+ print(f"Ingestion completed: {processed} processed, {failed} failed")
+
+ except Exception as e:
+ print(f"Ingestion failed: {e}")
+ error_messages.append(str(e))
+
+ # Don't reset counts - keep whatever was processed before the error
+ job = await update_job_status_task(
+ job, IngestionStatus.FAILED, processed=processed, _failed=failed, error=str(e)
+ )
+
+ # Calculate duration
+ duration = (datetime.now(UTC) - start_time).total_seconds()
+
+ return IngestionResult(
+ job_id=job.id,
+ status=job.status,
+ documents_processed=processed,
+ documents_failed=failed,
+ duration_seconds=duration,
+ error_messages=error_messages,
+ )
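+
+
+# Hedged usage sketch (illustration only): invoking the flow ad hoc. The URL and the
+# collection name are placeholders, and the helper assumes no event loop is running yet.
+def _example_run_flow_once() -> IngestionResult:  # pragma: no cover - illustrative only
+    """Run a single ingestion outside of Prefect deployments."""
+    import asyncio
+
+    return asyncio.run(
+        create_ingestion_flow(
+            source_url="https://docs.example.com",
+            source_type=IngestionSource.WEB,
+            storage_backend=StorageBackend.WEAVIATE,
+            collection_name="docs_example",
+        )
+    )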
@@ -7773,8 +9054,8 @@ def configure_prefect(settings: Settings) -> None:
from datetime import timedelta
from typing import Literal, Protocol, cast
-from prefect import serve
from prefect.deployments.runner import RunnerDeployment
+from prefect.flows import serve as prefect_serve
from prefect.schedules import Cron, Interval
from prefect.variables import Variable
@@ -7852,7 +9133,7 @@ def create_scheduled_deployment(
tags = [source_enum.value, backend_enum.value]
# Create deployment parameters with block support
- parameters = {
+ parameters: dict[str, str | bool] = {
"source_url": source_url,
"source_type": source_enum.value,
"storage_backend": backend_enum.value,
@@ -7867,8 +9148,8 @@ def create_scheduled_deployment(
# Create deployment
# The flow decorator adds the to_deployment method at runtime
- to_deployment = create_ingestion_flow.to_deployment
- deployment = to_deployment(
+ flow_with_deployment = cast(FlowWithDeployment, create_ingestion_flow)
+ return flow_with_deployment.to_deployment(
name=name,
schedule=schedule,
parameters=parameters,
@@ -7876,8 +9157,6 @@ def create_scheduled_deployment(
description=f"Scheduled ingestion from {source_url}",
)
- return cast("RunnerDeployment", deployment)
-
def serve_deployments(deployments: list[RunnerDeployment]) -> None:
"""
@@ -7886,92 +9165,7 @@ def serve_deployments(deployments: list[RunnerDeployment]) -> None:
Args:
deployments: List of deployment configurations
"""
- serve(*deployments, limit=10)
-
-
-
-"""Base ingestor interface."""
-
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator
-from typing import TYPE_CHECKING
-
-from ..core.models import Document, IngestionJob
-
-if TYPE_CHECKING:
- from ..storage.base import BaseStorage
-
-
-class BaseIngestor(ABC):
- """Abstract base class for all ingestors."""
-
- @abstractmethod
- def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]:
- """
- Ingest data from a source.
-
- Args:
- job: The ingestion job configuration
-
- Yields:
- Documents from the source
- """
- ... # pragma: no cover
-
- @abstractmethod
- async def validate_source(self, source_url: str) -> bool:
- """
- Validate if the source is accessible.
-
- Args:
- source_url: URL or path to the source
-
- Returns:
- True if source is valid and accessible
- """
- pass # pragma: no cover
-
- @abstractmethod
- async def estimate_size(self, source_url: str) -> int:
- """
- Estimate the number of documents in the source.
-
- Args:
- source_url: URL or path to the source
-
- Returns:
- Estimated number of documents
- """
- pass # pragma: no cover
-
- async def ingest_with_dedup(
- self,
- job: IngestionJob,
- storage_client: BaseStorage,
- *,
- collection_name: str | None = None,
- stale_after_days: int = 30,
- ) -> AsyncGenerator[Document, None]:
- """
- Ingest documents with duplicate detection (optional optimization).
-
- Default implementation falls back to regular ingestion.
- Subclasses can override to provide optimized deduplication.
-
- Args:
- job: The ingestion job configuration
- storage_client: Storage client to check for existing documents
- collection_name: Collection to check for duplicates
- stale_after_days: Consider documents stale after this many days
-
- Yields:
- Documents from the source (with deduplication if implemented)
- """
- # Default implementation: fall back to regular ingestion
- async for document in self.ingest(job):
- yield document
+ prefect_serve(*deployments, limit=10)
@@ -7983,7 +9177,7 @@ import re
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass
from datetime import UTC, datetime
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, cast
from urllib.parse import urlparse
from uuid import NAMESPACE_URL, UUID, uuid5
@@ -8005,9 +9199,70 @@ if TYPE_CHECKING:
from ..storage.base import BaseStorage
+class FirecrawlMetadata(Protocol):
+ """Protocol for Firecrawl metadata objects."""
+
+ title: str | None
+ description: str | None
+ author: str | None
+ language: str | None
+ sitemap_last_modified: str | None
+ sourceURL: str | None
+ keywords: str | list[str] | None
+ robots: str | None
+ ogTitle: str | None
+ ogDescription: str | None
+ ogUrl: str | None
+ ogImage: str | None
+ twitterCard: str | None
+ twitterSite: str | None
+ twitterCreator: str | None
+ favicon: str | None
+ statusCode: int | None
+
+
+class FirecrawlResult(Protocol):
+ """Protocol for Firecrawl scrape result objects."""
+
+ metadata: FirecrawlMetadata | None
+ markdown: str | None
+
+
+class FirecrawlMapLink(Protocol):
+ """Protocol for Firecrawl map link objects."""
+
+ url: str
+
+
+class FirecrawlMapResult(Protocol):
+ """Protocol for Firecrawl map result objects."""
+
+ links: list[FirecrawlMapLink] | None
+
+
+class AsyncFirecrawlSession(Protocol):
+ """Protocol for AsyncFirecrawl session objects."""
+
+ async def close(self) -> None: ...
+
+
+class AsyncFirecrawlClient(Protocol):
+ """Protocol for AsyncFirecrawl client objects."""
+
+ _session: AsyncFirecrawlSession | None
+
+ async def close(self) -> None: ...
+
+ async def scrape(self, url: str, formats: list[str]) -> FirecrawlResult: ...
+
+ async def map(self, url: str, limit: int | None = None) -> "FirecrawlMapResult": ...
+
+
class FirecrawlError(IngestionError):
"""Base exception for Firecrawl-related errors."""
+ status_code: int | None
+
def __init__(self, message: str, status_code: int | None = None) -> None:
super().__init__(message)
self.status_code = status_code
@@ -8041,7 +9296,7 @@ async def retry_with_backoff(
except Exception as e:
if attempt == max_retries - 1:
raise e
- delay = 1.0 * (2**attempt)
+ delay: float = 1.0 * (2**attempt)
logging.warning(
f"Firecrawl operation failed (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {delay:.1f}s..."
)
@@ -8081,7 +9336,7 @@ class FirecrawlIngestor(BaseIngestor):
"""Ingestor for web and documentation sites using Firecrawl."""
config: FirecrawlConfig
- client: AsyncFirecrawl
+ client: AsyncFirecrawlClient
def __init__(self, config: FirecrawlConfig | None = None):
"""
@@ -8107,15 +9362,16 @@ class FirecrawlIngestor(BaseIngestor):
"http://localhost"
):
# Self-hosted instance - try with api_url if supported
- self.client = AsyncFirecrawl(
- api_key=api_key, api_url=str(settings.firecrawl_endpoint)
+ self.client = cast(
+ AsyncFirecrawlClient,
+ AsyncFirecrawl(api_key=api_key, api_url=str(settings.firecrawl_endpoint)),
)
else:
# Cloud instance - use standard initialization
- self.client = AsyncFirecrawl(api_key=api_key)
+ self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key))
except Exception:
# Fallback to standard initialization
- self.client = AsyncFirecrawl(api_key=api_key)
+ self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key))
@override
async def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]:
@@ -8172,9 +9428,7 @@ class FirecrawlIngestor(BaseIngestor):
for check_url in site_map:
document_id = str(self.compute_document_id(check_url))
exists = await storage_client.check_exists(
- document_id,
- collection_name=collection_name,
- stale_after_days=stale_after_days
+ document_id, collection_name=collection_name, stale_after_days=stale_after_days
)
if not exists:
eligible_urls.append(check_url)
@@ -8254,11 +9508,11 @@ class FirecrawlIngestor(BaseIngestor):
"""
try:
# Use SDK v2 map endpoint following official pattern
- result = await self.client.map(url=url, limit=self.config.limit)
+ result: FirecrawlMapResult = await self.client.map(url=url, limit=self.config.limit)
- if result and getattr(result, "links", None):
+ if result and result.links:
# Extract URLs from the result following official pattern
- return [getattr(link, "url", str(link)) for link in result.links]
+ return [link.url for link in result.links]
return []
except Exception as e:
# If map fails (might not be available in all versions), fall back to single URL
@@ -8301,43 +9555,57 @@ class FirecrawlIngestor(BaseIngestor):
try:
# Use SDK v2 scrape endpoint following official pattern with retry
async def scrape_operation() -> FirecrawlPage | None:
- result = await self.client.scrape(url, formats=self.config.formats)
+ result: FirecrawlResult = await self.client.scrape(url, formats=self.config.formats)
# Extract data from the result following official response handling
if result:
# The SDK returns a ScrapeData object with typed metadata
- metadata = getattr(result, "metadata", None)
+ metadata: FirecrawlMetadata | None = getattr(result, "metadata", None)
# Extract basic metadata
- title = getattr(metadata, "title", None) if metadata else None
- description = getattr(metadata, "description", None) if metadata else None
+ title: str | None = getattr(metadata, "title", None) if metadata else None
+ description: str | None = (
+ getattr(metadata, "description", None) if metadata else None
+ )
# Extract enhanced metadata if available
- author = getattr(metadata, "author", None) if metadata else None
- language = getattr(metadata, "language", None) if metadata else None
- sitemap_last_modified = (
+ author: str | None = getattr(metadata, "author", None) if metadata else None
+ language: str | None = getattr(metadata, "language", None) if metadata else None
+ sitemap_last_modified: str | None = (
getattr(metadata, "sitemap_last_modified", None) if metadata else None
)
- source_url = getattr(metadata, "sourceURL", None) if metadata else None
- keywords = getattr(metadata, "keywords", None) if metadata else None
- robots = getattr(metadata, "robots", None) if metadata else None
+ source_url: str | None = (
+ getattr(metadata, "sourceURL", None) if metadata else None
+ )
+ keywords: str | list[str] | None = (
+ getattr(metadata, "keywords", None) if metadata else None
+ )
+ robots: str | None = getattr(metadata, "robots", None) if metadata else None
# Open Graph metadata
- og_title = getattr(metadata, "ogTitle", None) if metadata else None
- og_description = getattr(metadata, "ogDescription", None) if metadata else None
- og_url = getattr(metadata, "ogUrl", None) if metadata else None
- og_image = getattr(metadata, "ogImage", None) if metadata else None
+ og_title: str | None = getattr(metadata, "ogTitle", None) if metadata else None
+ og_description: str | None = (
+ getattr(metadata, "ogDescription", None) if metadata else None
+ )
+ og_url: str | None = getattr(metadata, "ogUrl", None) if metadata else None
+ og_image: str | None = getattr(metadata, "ogImage", None) if metadata else None
# Twitter metadata
- twitter_card = getattr(metadata, "twitterCard", None) if metadata else None
- twitter_site = getattr(metadata, "twitterSite", None) if metadata else None
- twitter_creator = (
+ twitter_card: str | None = (
+ getattr(metadata, "twitterCard", None) if metadata else None
+ )
+ twitter_site: str | None = (
+ getattr(metadata, "twitterSite", None) if metadata else None
+ )
+ twitter_creator: str | None = (
getattr(metadata, "twitterCreator", None) if metadata else None
)
# Additional metadata
- favicon = getattr(metadata, "favicon", None) if metadata else None
- status_code = getattr(metadata, "statusCode", None) if metadata else None
+ favicon: str | None = getattr(metadata, "favicon", None) if metadata else None
+ status_code: int | None = (
+ getattr(metadata, "statusCode", None) if metadata else None
+ )
return FirecrawlPage(
url=url,
@@ -8350,7 +9618,7 @@ class FirecrawlIngestor(BaseIngestor):
source_url=source_url,
keywords=keywords.split(",")
if keywords and isinstance(keywords, str)
- else keywords,
+ else (keywords if isinstance(keywords, list) else None),
robots=robots,
og_title=og_title,
og_description=og_description,
@@ -8376,11 +9644,11 @@ class FirecrawlIngestor(BaseIngestor):
return uuid5(NAMESPACE_URL, source_url)
@staticmethod
- def _analyze_content_structure(content: str) -> dict[str, object]:
+ def _analyze_content_structure(content: str) -> dict[str, str | int | bool | list[str]]:
"""Analyze markdown content to extract structural information."""
# Extract heading hierarchy
heading_pattern = r"^(#{1,6})\s+(.+)$"
- headings = []
+ headings: list[str] = []
for match in re.finditer(heading_pattern, content, re.MULTILINE):
level = len(match.group(1))
text = match.group(2).strip()
@@ -8395,7 +9663,8 @@ class FirecrawlIngestor(BaseIngestor):
max_depth = 0
if headings:
for heading in headings:
- depth = (len(heading) - len(heading.lstrip())) // 2 + 1
+ heading_str: str = str(heading)
+ depth = (len(heading_str) - len(heading_str.lstrip())) // 2 + 1
max_depth = max(max_depth, depth)
return {
@@ -8509,10 +9778,16 @@ class FirecrawlIngestor(BaseIngestor):
"site_name": domain_info["site_name"],
# Document structure
"heading_hierarchy": (
- list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else []
+ list(hierarchy_val)
+ if (hierarchy_val := structure_info.get("heading_hierarchy"))
+ and isinstance(hierarchy_val, (list, tuple))
+ else []
),
"section_depth": (
- int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0
+ int(depth_val)
+ if (depth_val := structure_info.get("section_depth"))
+ and isinstance(depth_val, (int, str))
+ else 0
),
"has_code_blocks": bool(structure_info.get("has_code_blocks", False)),
"has_images": bool(structure_info.get("has_images", False)),
@@ -8547,7 +9822,11 @@ class FirecrawlIngestor(BaseIngestor):
await self.client.close()
except Exception as e:
logging.debug(f"Error closing Firecrawl client: {e}")
- elif hasattr(self.client, "_session") and hasattr(self.client._session, "close"):
+ elif (
+ hasattr(self.client, "_session")
+ and self.client._session
+ and hasattr(self.client._session, "close")
+ ):
try:
await self.client._session.close()
except Exception as e:
@@ -8572,13 +9851,17 @@ class FirecrawlIngestor(BaseIngestor):
import json
from datetime import UTC, datetime
-from typing import Protocol, TypedDict, cast
+from typing import Final, Protocol, TypedDict, cast
import httpx
+from ..config import get_settings
from ..core.exceptions import IngestionError
from ..core.models import Document
+JSON_CONTENT_TYPE: Final[str] = "application/json"
+AUTHORIZATION_HEADER: Final[str] = "Authorization"
+
class HttpResponse(Protocol):
"""Protocol for HTTP response."""
@@ -8590,28 +9873,35 @@ class HttpResponse(Protocol):
class AsyncHttpClient(Protocol):
"""Protocol for async HTTP client."""
- async def post(
- self,
- url: str,
- *,
- json: dict[str, object] | None = None
- ) -> HttpResponse: ...
+ async def post(self, url: str, *, json: dict[str, object] | None = None) -> HttpResponse: ...
async def aclose(self) -> None: ...
+ async def __aenter__(self) -> "AsyncHttpClient": ...
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: object | None,
+ ) -> None: ...
+
class LlmResponse(TypedDict):
"""Type for LLM API response structure."""
+
choices: list[dict[str, object]]
class LlmChoice(TypedDict):
"""Type for individual choice in LLM response."""
+
message: dict[str, object]
class LlmMessage(TypedDict):
"""Type for message in LLM choice."""
+
content: str
@@ -8636,8 +9926,11 @@ class MetadataTagger:
def __init__(
self,
- llm_endpoint: str = "http://llm.lab",
- model: str = "fireworks/glm-4p5-air",
+ llm_endpoint: str | None = None,
+ model: str | None = None,
+ api_key: str | None = None,
+ *,
+ timeout: float | None = None,
):
"""
Initialize metadata tagger.
@@ -8645,30 +9938,26 @@ class MetadataTagger:
Args:
llm_endpoint: LLM API endpoint
model: Model to use for tagging
+ api_key: Explicit API key override
+ timeout: Optional request timeout override in seconds
"""
- self.endpoint = llm_endpoint.rstrip('/')
- self.model = model
+ settings = get_settings()
+ endpoint_value = llm_endpoint or str(settings.llm_endpoint)
+ self.endpoint = endpoint_value.rstrip("/")
+ self.model = model or settings.metadata_model
- # Get API key from environment
- import os
- from pathlib import Path
+ resolved_timeout = timeout if timeout is not None else float(settings.request_timeout)
+ resolved_api_key = api_key or settings.get_llm_api_key() or ""
- from dotenv import load_dotenv
-
- # Load .env from the project root
- env_path = Path(__file__).parent.parent.parent / ".env"
- _ = load_dotenv(env_path)
-
- api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
-
- headers = {"Content-Type": "application/json"}
- if api_key:
- headers["Authorization"] = f"Bearer {api_key}"
+ headers: dict[str, str] = {"Content-Type": JSON_CONTENT_TYPE}
+ if resolved_api_key:
+ headers[AUTHORIZATION_HEADER] = f"Bearer {resolved_api_key}"
# Create client with proper typing - httpx.AsyncClient implements AsyncHttpClient protocol
- AsyncClientClass = getattr(httpx, "AsyncClient")
- raw_client = AsyncClientClass(timeout=60.0, headers=headers)
- self.client = cast(AsyncHttpClient, raw_client)
+ self.client = cast(
+ AsyncHttpClient,
+ httpx.AsyncClient(timeout=resolved_timeout, headers=headers),
+ )
async def tag_document(
self, document: Document, custom_instructions: str | None = None
@@ -8719,12 +10008,16 @@ class MetadataTagger:
# Build a proper DocumentMetadata instance with only valid keys
new_metadata: CoreDocumentMetadata = {
"source_url": str(updated_metadata.get("source_url", "")),
- "timestamp": (
- lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC)
- )(updated_metadata.get("timestamp", datetime.now(UTC))),
+ "timestamp": (lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC))(
+ updated_metadata.get("timestamp", datetime.now(UTC))
+ ),
"content_type": str(updated_metadata.get("content_type", "text/plain")),
- "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)),
- "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)),
+ "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(
+ updated_metadata.get("word_count", 0)
+ ),
+ "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(
+ updated_metadata.get("char_count", 0)
+ ),
}
# Add optional fields if they exist
@@ -8926,16 +10219,79 @@ Return a JSON object with the following structure:
"""Vectorizer utility for generating embeddings."""
+import asyncio
from types import TracebackType
-from typing import Self, cast
+from typing import Final, NotRequired, Self, TypedDict
import httpx
-from typings import EmbeddingResponse
-
+from ..config import get_settings
from ..core.exceptions import VectorizationError
from ..core.models import StorageConfig, VectorConfig
+JSON_CONTENT_TYPE: Final[str] = "application/json"
+AUTHORIZATION_HEADER: Final[str] = "Authorization"
+
+
+class EmbeddingData(TypedDict):
+ """Structure for embedding data from providers."""
+
+ embedding: list[float]
+ index: NotRequired[int]
+ object: NotRequired[str]
+
+
+class EmbeddingResponse(TypedDict):
+ """Embedding response format for multiple providers."""
+
+ data: list[EmbeddingData]
+ model: NotRequired[str]
+ object: NotRequired[str]
+ usage: NotRequired[dict[str, int]]
+ # Alternative formats
+ embedding: NotRequired[list[float]]
+ vector: NotRequired[list[float]]
+ embeddings: NotRequired[list[list[float]]]
+
+
+def _extract_embedding_from_response(response_data: dict[str, object]) -> list[float]:
+ """Extract embedding vector from provider response."""
+ # OpenAI/Ollama format: {"data": [{"embedding": [...]}]}
+ if "data" in response_data:
+ data_list = response_data["data"]
+ if isinstance(data_list, list) and data_list:
+ first_item = data_list[0]
+ if isinstance(first_item, dict) and "embedding" in first_item:
+ embedding = first_item["embedding"]
+ if isinstance(embedding, list) and all(
+ isinstance(x, (int, float)) for x in embedding
+ ):
+ return [float(x) for x in embedding]
+
+ # Direct embedding format: {"embedding": [...]}
+ if "embedding" in response_data:
+ embedding = response_data["embedding"]
+ if isinstance(embedding, list) and all(isinstance(x, (int, float)) for x in embedding):
+ return [float(x) for x in embedding]
+
+ # Vector format: {"vector": [...]}
+ if "vector" in response_data:
+ vector = response_data["vector"]
+ if isinstance(vector, list) and all(isinstance(x, (int, float)) for x in vector):
+ return [float(x) for x in vector]
+
+ # Embeddings array format: {"embeddings": [[...]]}
+ if "embeddings" in response_data:
+ embeddings = response_data["embeddings"]
+ if isinstance(embeddings, list) and embeddings:
+ first_embedding = embeddings[0]
+ if isinstance(first_embedding, list) and all(
+ isinstance(x, (int, float)) for x in first_embedding
+ ):
+ return [float(x) for x in first_embedding]
+
+ raise VectorizationError("Unrecognized embedding response format")
+
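# Worked examples of the payload shapes the helper above accepts (values are illustrative):
def _example_embedding_shapes() -> None:
    # OpenAI/Ollama style: {"data": [{"embedding": [...]}]}
    assert _extract_embedding_from_response({"data": [{"embedding": [0.1, 0.2]}]}) == [0.1, 0.2]
    # Direct embedding, vector, and embeddings-array styles
    assert _extract_embedding_from_response({"embedding": [0.1, 0.2]}) == [0.1, 0.2]
    assert _extract_embedding_from_response({"vector": [0.1, 0.2]}) == [0.1, 0.2]
    assert _extract_embedding_from_response({"embeddings": [[0.1, 0.2]]}) == [0.1, 0.2]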
class Vectorizer:
"""Handles text vectorization using LLM endpoints."""
@@ -8951,33 +10307,24 @@ class Vectorizer:
Args:
config: Configuration with embedding details
"""
+ settings = get_settings()
if isinstance(config, StorageConfig):
- # Extract vector config from storage config
- self.endpoint = "http://llm.lab"
- self.model = "ollama/bge-m3"
- self.dimension = 1024
+ # Extract vector config from global settings when storage config is provided
+ self.endpoint = str(settings.llm_endpoint).rstrip("/")
+ self.model = settings.embedding_model
+ self.dimension = settings.embedding_dimension
else:
- self.endpoint = str(config.embedding_endpoint)
+ self.endpoint = str(config.embedding_endpoint).rstrip("/")
self.model = config.model
self.dimension = config.dimension
- # Get API key from environment
- import os
- from pathlib import Path
+ resolved_api_key = settings.get_llm_api_key() or ""
+ headers: dict[str, str] = {"Content-Type": JSON_CONTENT_TYPE}
+ if resolved_api_key:
+ headers[AUTHORIZATION_HEADER] = f"Bearer {resolved_api_key}"
- from dotenv import load_dotenv
-
- # Load .env from the project root
- env_path = Path(__file__).parent.parent.parent / ".env"
- _ = load_dotenv(env_path)
-
- api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
-
- headers = {"Content-Type": "application/json"}
- if api_key:
- headers["Authorization"] = f"Bearer {api_key}"
-
- self.client: httpx.AsyncClient = httpx.AsyncClient(timeout=60.0, headers=headers)
+ timeout_seconds = float(settings.request_timeout)
+ self.client = httpx.AsyncClient(timeout=timeout_seconds, headers=headers)
async def vectorize(self, text: str) -> list[float]:
"""
@@ -9003,21 +10350,34 @@ class Vectorizer:
async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
"""
- Generate embeddings for multiple texts.
+ Generate embeddings for multiple texts in parallel.
Args:
texts: List of texts to vectorize
Returns:
List of embedding vectors
+
+ Raises:
+ VectorizationError: If any vectorization fails
"""
- vectors: list[list[float]] = []
- for text in texts:
- vector = await self.vectorize(text)
- vectors.append(vector)
+ if not texts:
+ return []
- return vectors
+ # Use semaphore to limit concurrent requests and prevent overwhelming the endpoint
+ semaphore = asyncio.Semaphore(20)
+
+ async def vectorize_with_semaphore(text: str) -> list[float]:
+ async with semaphore:
+ return await self.vectorize(text)
+
+ try:
+ # Execute all vectorization requests concurrently
+ vectors = await asyncio.gather(*[vectorize_with_semaphore(text) for text in texts])
+ return list(vectors)
+ except Exception as e:
+ raise VectorizationError(f"Batch vectorization failed: {e}") from e
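# Minimal usage sketch (assumes a VectorConfig pointing at a reachable embedding endpoint;
# illustrative only, not part of this module):
async def _example_vectorize(texts: list[str], config: VectorConfig) -> list[list[float]]:
    async with Vectorizer(config) as vectorizer:
        # Requests fan out concurrently, bounded by the internal semaphore of 20.
        return await vectorizer.vectorize_batch(texts)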
async def _ollama_embed(self, text: str) -> list[float]:
"""
@@ -9043,23 +10403,12 @@ class Vectorizer:
_ = response.raise_for_status()
response_json = response.json()
- # Response is expected to be dict[str, object] from our type stub
+ if not isinstance(response_json, dict):
+ raise VectorizationError("Invalid JSON response format")
- response_data = cast(EmbeddingResponse, cast(object, response_json))
+ # Extract embedding using type-safe helper
+ embedding = _extract_embedding_from_response(response_json)
- # Parse OpenAI-compatible response format
- embeddings_list = response_data.get("data", [])
- if not embeddings_list:
- raise VectorizationError("No embeddings returned")
-
- first_embedding = embeddings_list[0]
- embedding_raw = first_embedding.get("embedding")
- if not embedding_raw:
- raise VectorizationError("Invalid embedding format")
-
- # Convert to float list and validate
- embedding: list[float] = []
- embedding.extend(float(item) for item in embedding_raw)
# Ensure correct dimension
if len(embedding) != self.dimension:
raise VectorizationError(
@@ -9088,22 +10437,12 @@ class Vectorizer:
_ = response.raise_for_status()
response_json = response.json()
- # Response is expected to be dict[str, object] from our type stub
+ if not isinstance(response_json, dict):
+ raise VectorizationError("Invalid JSON response format")
- response_data = cast(EmbeddingResponse, cast(object, response_json))
+ # Extract embedding using type-safe helper
+ embedding = _extract_embedding_from_response(response_json)
- embeddings_list = response_data.get("data", [])
- if not embeddings_list:
- raise VectorizationError("No embeddings returned")
-
- first_embedding = embeddings_list[0]
- embedding_raw = first_embedding.get("embedding")
- if not embedding_raw:
- raise VectorizationError("Invalid embedding format")
-
- # Convert to float list and validate
- embedding: list[float] = []
- embedding.extend(float(item) for item in embedding_raw)
# Ensure correct dimension
if len(embedding) != self.dimension:
raise VectorizationError(
@@ -9116,6 +10455,14 @@ class Vectorizer:
"""Async context manager entry."""
return self
+ async def close(self) -> None:
+ """Close the HTTP client connection."""
+ try:
+ await self.client.aclose()
+ except Exception:
+ # Already closed or connection lost
+ pass
+
async def __aexit__(
self,
exc_type: type[BaseException] | None,
@@ -9123,14 +10470,13 @@ class Vectorizer:
exc_tb: TracebackType | None,
) -> None:
"""Async context manager exit."""
- await self.client.aclose()
+ await self.close()
"""Main dashboard screen with collections overview."""
import logging
-from datetime import datetime
from typing import TYPE_CHECKING, Final
from textual import work
@@ -9335,7 +10681,11 @@ class CollectionOverviewScreen(Screen[None]):
"""Calculate basic metrics from collections."""
self.total_collections = len(self.collections)
self.total_documents = sum(col["count"] for col in self.collections)
- self.active_backends = sum([bool(self.weaviate), bool(self.openwebui), bool(self.r2r)])
+ # Calculate active backends from storage manager if individual storages are None
+ if self.weaviate is None and self.openwebui is None and self.r2r is None:
+ self.active_backends = len(self.storage_manager.get_available_backends())
+ else:
+ self.active_backends = sum([bool(self.weaviate), bool(self.openwebui), bool(self.r2r)])
def _update_metrics_cards(self) -> None:
"""Update the metrics cards display."""
@@ -9482,76 +10832,6 @@ class CollectionOverviewScreen(Screen[None]):
self.is_loading = False
loading_indicator.display = False
- async def list_weaviate_collections(self) -> list[CollectionInfo]:
- """List Weaviate collections with enhanced metadata."""
- if not self.weaviate:
- return []
-
- try:
- overview = await self.weaviate.describe_collections()
- collections: list[CollectionInfo] = []
-
- for item in overview:
- count_raw = item.get("count", 0)
- count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0
- size_mb_raw = item.get("size_mb", 0.0)
- size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0
- collections.append(
- CollectionInfo(
- name=str(item.get("name", "Unknown")),
- type="weaviate",
- count=count_val,
- backend="🗄️ Weaviate",
- status="✓ Active",
- last_updated=datetime.now().strftime("%Y-%m-%d %H:%M"),
- size_mb=size_mb_val,
- )
- )
-
- return collections
- except Exception as e:
- self.notify(f"Error listing Weaviate collections: {e}", severity="error", markup=False)
- return []
-
- async def list_openwebui_collections(self) -> list[CollectionInfo]:
- """List OpenWebUI collections with enhanced metadata."""
- # Try to get OpenWebUI backend from storage manager if direct instance not available
- openwebui_backend = self.openwebui
- if not openwebui_backend:
- backend = self.storage_manager.get_backend(StorageBackend.OPEN_WEBUI)
- if not isinstance(backend, OpenWebUIStorage):
- return []
- openwebui_backend = backend
- if not openwebui_backend:
- return []
-
- try:
- overview = await openwebui_backend.describe_collections()
- collections: list[CollectionInfo] = []
-
- for item in overview:
- count_raw = item.get("count", 0)
- count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0
- size_mb_raw = item.get("size_mb", 0.0)
- size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0
- collection_name = str(item.get("name", "Unknown"))
- collections.append(
- CollectionInfo(
- name=collection_name,
- type="openwebui",
- count=count_val,
- backend="🌐 OpenWebUI",
- status="✓ Active",
- last_updated=datetime.now().strftime("%Y-%m-%d %H:%M"),
- size_mb=size_mb_val,
- )
- )
-
- return collections
- except Exception as e:
- self.notify(f"Error listing OpenWebUI collections: {e}", severity="error", markup=False)
- return []
-
async def update_collections_table(self) -> None:
"""Update the collections table with enhanced formatting."""
table = self.query_one("#collections_table", EnhancedDataTable)
@@ -9825,7 +11105,7 @@ Enjoy the enhanced interface! 🎉
from __future__ import annotations
from pathlib import Path
-from typing import TYPE_CHECKING, ClassVar
+from typing import TYPE_CHECKING
from textual.app import ComposeResult
from textual.binding import Binding
@@ -9837,6 +11117,7 @@ from typing_extensions import override
from ..models import CollectionInfo
if TYPE_CHECKING:
+ from ..app import CollectionManagementApp
from .dashboard import CollectionOverviewScreen
from .documents import DocumentManagementScreen
@@ -9847,7 +11128,12 @@ class ConfirmDeleteScreen(Screen[None]):
collection: CollectionInfo
parent_screen: CollectionOverviewScreen
- BINDINGS: list[Binding] = [
+ @property
+ def app(self) -> CollectionManagementApp: # type: ignore[override]
+ """Return the typed app instance."""
+ return super().app # type: ignore[return-value]
+
+ BINDINGS = [
Binding("escape", "app.pop_screen", "Cancel"),
Binding("y", "confirm_delete", "Yes"),
Binding("n", "app.pop_screen", "No"),
@@ -9898,7 +11184,10 @@ class ConfirmDeleteScreen(Screen[None]):
try:
if self.collection["type"] == "weaviate" and self.parent_screen.weaviate:
# Delete Weaviate collection
- if self.parent_screen.weaviate.client and self.parent_screen.weaviate.client.collections:
+ if (
+ self.parent_screen.weaviate.client
+ and self.parent_screen.weaviate.client.collections
+ ):
self.parent_screen.weaviate.client.collections.delete(self.collection["name"])
self.notify(
f"Deleted Weaviate collection: {self.collection['name']}",
@@ -9916,7 +11205,7 @@ class ConfirmDeleteScreen(Screen[None]):
return
# Check if the storage backend supports collection deletion
- if not hasattr(storage_backend, 'delete_collection'):
+ if not hasattr(storage_backend, "delete_collection"):
self.notify(
f"❌ Collection deletion not supported for {self.collection['type']} backend",
severity="error",
@@ -9929,10 +11218,13 @@ class ConfirmDeleteScreen(Screen[None]):
collection_name = str(self.collection["name"])
collection_type = str(self.collection["type"])
- self.notify(f"Deleting {collection_type} collection: {collection_name}...", severity="information")
+ self.notify(
+ f"Deleting {collection_type} collection: {collection_name}...",
+ severity="information",
+ )
# Use the standard delete_collection method for all backends
- if hasattr(storage_backend, 'delete_collection'):
+ if hasattr(storage_backend, "delete_collection"):
success = await storage_backend.delete_collection(collection_name)
else:
self.notify("❌ Backend does not support collection deletion", severity="error")
@@ -9954,12 +11246,15 @@ class ConfirmDeleteScreen(Screen[None]):
return
# Refresh parent screen after a short delay to ensure deletion is processed
- self.call_later(lambda _: self.parent_screen.refresh_collections(), 0.5) # 500ms delay
+ self.set_timer(0.5, self._refresh_parent_collections)  # refresh after 500ms
self.app.pop_screen()
except Exception as e:
self.notify(f"Failed to delete collection: {e}", severity="error", markup=False)
+ def _refresh_parent_collections(self) -> None:
+ """Helper method to refresh parent collections."""
+ self.parent_screen.refresh_collections()
class ConfirmDocumentDeleteScreen(Screen[None]):
@@ -9967,9 +11262,14 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
doc_ids: list[str]
collection: CollectionInfo
- parent_screen: "DocumentManagementScreen"
+ parent_screen: DocumentManagementScreen
- BINDINGS: list[Binding] = [
+ @property
+ def app(self) -> CollectionManagementApp: # type: ignore[override]
+ """Return the typed app instance."""
+ return super().app # type: ignore[return-value]
+
+ BINDINGS = [
Binding("escape", "app.pop_screen", "Cancel"),
Binding("y", "confirm_delete", "Yes"),
Binding("n", "app.pop_screen", "No"),
@@ -9980,7 +11280,7 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
self,
doc_ids: list[str],
collection: CollectionInfo,
- parent_screen: "DocumentManagementScreen",
+ parent_screen: DocumentManagementScreen,
):
super().__init__()
self.doc_ids = doc_ids
@@ -10030,11 +11330,11 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
try:
results: dict[str, bool] = {}
- if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage:
+ if hasattr(self.parent_screen, "storage") and self.parent_screen.storage:
# Delete documents via storage
# The storage should have delete_documents method for weaviate
storage = self.parent_screen.storage
- if hasattr(storage, 'delete_documents'):
+ if hasattr(storage, "delete_documents"):
results = await storage.delete_documents(
self.doc_ids,
collection_name=self.collection["name"],
@@ -10066,7 +11366,12 @@ class LogViewerScreen(ModalScreen[None]):
_log_widget: RichLog | None
_log_file: Path | None
- BINDINGS: list[Binding] = [
+ @property
+ def app(self) -> CollectionManagementApp: # type: ignore[override]
+ """Return the typed app instance."""
+ return super().app # type: ignore[return-value]
+
+ BINDINGS = [
Binding("escape", "close", "Close"),
Binding("ctrl+l", "close", "Close"),
Binding("s", "show_path", "Log File"),
@@ -10082,7 +11387,9 @@ class LogViewerScreen(ModalScreen[None]):
yield Header(show_clock=True)
yield Container(
Static("📜 Live Application Logs", classes="title"),
- Static("Logs update in real time. Press S to reveal the log file path.", classes="subtitle"),
+ Static(
+ "Logs update in real time. Press S to reveal the log file path.", classes="subtitle"
+ ),
RichLog(id="log_stream", classes="log-stream", wrap=True, highlight=False),
Static("", id="log_file_path", classes="subtitle"),
classes="main_container log-viewer-container",
@@ -10093,14 +11400,14 @@ class LogViewerScreen(ModalScreen[None]):
"""Attach this viewer to the parent application once mounted."""
self._log_widget = self.query_one(RichLog)
- if hasattr(self.app, 'attach_log_viewer'):
- self.app.attach_log_viewer(self)
+ if hasattr(self.app, "attach_log_viewer"):
+ self.app.attach_log_viewer(self) # type: ignore[arg-type]
def on_unmount(self) -> None:
"""Detach from the parent application when closed."""
- if hasattr(self.app, 'detach_log_viewer'):
- self.app.detach_log_viewer(self)
+ if hasattr(self.app, "detach_log_viewer"):
+ self.app.detach_log_viewer(self) # type: ignore[arg-type]
def _get_log_widget(self) -> RichLog:
if self._log_widget is None:
@@ -10142,14 +11449,16 @@ class LogViewerScreen(ModalScreen[None]):
if self._log_file is None:
self.notify("File logging is disabled for this session.", severity="warning")
else:
- self.notify(f"Log file available at: {self._log_file}", severity="information", markup=False)
+ self.notify(
+ f"Log file available at: {self._log_file}", severity="information", markup=False
+ )
"""Application settings and configuration."""
from functools import lru_cache
-from typing import Annotated, ClassVar, Literal
+from typing import Annotated, ClassVar, Final, Literal
from prefect.variables import Variable
from pydantic import Field, HttpUrl, model_validator
@@ -10168,6 +11477,8 @@ class Settings(BaseSettings):
# API Keys
firecrawl_api_key: str | None = None
+ llm_api_key: str | None = None
+ openai_api_key: str | None = None
openwebui_api_key: str | None = None
weaviate_api_key: str | None = None
r2r_api_key: str | None = None
@@ -10181,6 +11492,7 @@ class Settings(BaseSettings):
# Model Configuration
embedding_model: str = "ollama/bge-m3:latest"
+ metadata_model: str = "fireworks/glm-4p5-air"
embedding_dimension: int = 1024
# Ingestion Settings
@@ -10248,14 +11560,20 @@ class Settings(BaseSettings):
Returns:
API key or None
"""
- service_map = {
+ service_map: Final[dict[str, str | None]] = {
"firecrawl": self.firecrawl_api_key,
"openwebui": self.openwebui_api_key,
"weaviate": self.weaviate_api_key,
"r2r": self.r2r_api_key,
+ "llm": self.get_llm_api_key(),
+ "openai": self.openai_api_key,
}
return service_map.get(service)
+ def get_llm_api_key(self) -> str | None:
+ """Get API key for LLM services with OpenAI fallback."""
+ return self.llm_api_key or (self.openai_api_key or None)
+
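    # Example: settings.get_api_key("llm") now resolves through get_llm_api_key(), so
    # llm_api_key is preferred and openai_api_key is the fallback (None if neither is set).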
@model_validator(mode="after")
def validate_backend_configuration(self) -> "Settings":
"""Validate that required configuration is present for the default backend."""
@@ -10278,10 +11596,11 @@ class Settings(BaseSettings):
key_name, key_value = required_keys[backend]
if not key_value:
import warnings
+
warnings.warn(
f"{key_name} not set - authentication may fail for {backend} backend",
UserWarning,
- stacklevel=2
+ stacklevel=2,
)
return self
@@ -10304,16 +11623,24 @@ class PrefectVariableConfig:
def __init__(self) -> None:
self._settings: Settings = get_settings()
self._variable_names: list[str] = [
- "default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages",
- "default_storage_backend", "default_collection_prefix", "max_concurrent_tasks",
- "request_timeout", "default_schedule_interval"
+ "default_batch_size",
+ "max_file_size",
+ "max_crawl_depth",
+ "max_crawl_pages",
+ "default_storage_backend",
+ "default_collection_prefix",
+ "max_concurrent_tasks",
+ "request_timeout",
+ "default_schedule_interval",
]
def _get_fallback_value(self, name: str, default_value: object = None) -> object:
"""Get fallback value from settings or default."""
return default_value or getattr(self._settings, name, default_value)
- def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
+ def get_with_fallback(
+ self, name: str, default_value: str | int | float | None = None
+ ) -> str | int | float | None:
"""Get variable value with fallback synchronously."""
fallback = self._get_fallback_value(name, default_value)
# Ensure fallback is a type that Variable expects
@@ -10334,7 +11661,9 @@ class PrefectVariableConfig:
return fallback
return str(fallback) if fallback is not None else None
- async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
+ async def get_with_fallback_async(
+ self, name: str, default_value: str | int | float | None = None
+ ) -> str | int | float | None:
"""Get variable value with fallback asynchronously."""
fallback = self._get_fallback_value(name, default_value)
variable_fallback = str(fallback) if fallback is not None else None
@@ -10383,6 +11712,41 @@ from uuid import UUID, uuid4
from prefect.blocks.core import Block
from pydantic import BaseModel, Field, HttpUrl, SecretStr
+from ..config import get_settings
+
+
+def _default_embedding_model() -> str:
+ return str(get_settings().embedding_model)
+
+
+def _default_embedding_endpoint() -> HttpUrl:
+ endpoint = get_settings().llm_endpoint
+ return endpoint if isinstance(endpoint, HttpUrl) else HttpUrl(str(endpoint))
+
+
+def _default_embedding_dimension() -> int:
+ return int(get_settings().embedding_dimension)
+
+
+def _default_batch_size() -> int:
+ return int(get_settings().default_batch_size)
+
+
+def _default_collection_name() -> str:
+ return str(get_settings().default_collection_prefix)
+
+
+def _default_max_crawl_depth() -> int:
+ return int(get_settings().max_crawl_depth)
+
+
+def _default_max_crawl_pages() -> int:
+ return int(get_settings().max_crawl_pages)
+
+
+def _default_max_file_size() -> int:
+ return int(get_settings().max_file_size)
+
class IngestionStatus(str, Enum):
"""Status of an ingestion job."""
@@ -10414,36 +11778,39 @@ class IngestionSource(str, Enum):
class VectorConfig(BaseModel):
"""Configuration for vectorization."""
- model: str = Field(default="ollama/bge-m3:latest")
- embedding_endpoint: HttpUrl = Field(default=HttpUrl("http://llm.lab"))
- dimension: int = Field(default=1024)
- batch_size: Annotated[int, Field(gt=0, le=1000)] = 100
+ model: str = Field(default_factory=_default_embedding_model)
+ embedding_endpoint: HttpUrl = Field(default_factory=_default_embedding_endpoint)
+ dimension: int = Field(default_factory=_default_embedding_dimension)
+ batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size)
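# Sketch of the effect of the default factories (illustrative; assumes Settings supplies
# embedding_model, llm_endpoint, embedding_dimension, and default_batch_size):
def _example_vector_config_defaults() -> VectorConfig:
    # Defaults resolve lazily from get_settings() at instantiation time, so environment
    # changes apply to new VectorConfig() instances without code edits.
    return VectorConfig()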
class StorageConfig(Block):
"""Configuration for storage backend."""
- _block_type_name: ClassVar[str] = "Storage Configuration"
- _block_type_slug: ClassVar[str] = "storage-config"
- _description: ClassVar[str] = "Configures storage backend connections and settings for document ingestion"
+ _block_type_name: ClassVar[str | None] = "Storage Configuration"
+ _block_type_slug: ClassVar[str | None] = "storage-config"
+ _description: ClassVar[str | None] = (
+ "Configures storage backend connections and settings for document ingestion"
+ )
backend: StorageBackend
endpoint: HttpUrl
api_key: SecretStr | None = Field(default=None)
- collection_name: str = Field(default="documents")
- batch_size: Annotated[int, Field(gt=0, le=1000)] = 100
+ collection_name: str = Field(default_factory=_default_collection_name)
+ batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size)
+ grpc_port: int | None = Field(default=None, description="gRPC port for Weaviate connections")
class FirecrawlConfig(Block):
"""Configuration for Firecrawl ingestion (operational parameters only)."""
- _block_type_name: ClassVar[str] = "Firecrawl Configuration"
- _block_type_slug: ClassVar[str] = "firecrawl-config"
- _description: ClassVar[str] = "Configures Firecrawl web scraping and crawling parameters"
+ _block_type_name: ClassVar[str | None] = "Firecrawl Configuration"
+ _block_type_slug: ClassVar[str | None] = "firecrawl-config"
+ _description: ClassVar[str | None] = "Configures Firecrawl web scraping and crawling parameters"
formats: list[str] = Field(default_factory=lambda: ["markdown", "html"])
- max_depth: Annotated[int, Field(ge=1, le=20)] = 5
- limit: Annotated[int, Field(ge=1, le=1000)] = 100
+ max_depth: Annotated[int, Field(ge=1, le=20)] = Field(default_factory=_default_max_crawl_depth)
+ limit: Annotated[int, Field(ge=1, le=1000)] = Field(default_factory=_default_max_crawl_pages)
only_main_content: bool = Field(default=True)
include_subdomains: bool = Field(default=False)
@@ -10451,9 +11818,11 @@ class FirecrawlConfig(Block):
class RepomixConfig(Block):
"""Configuration for Repomix ingestion."""
- _block_type_name: ClassVar[str] = "Repomix Configuration"
- _block_type_slug: ClassVar[str] = "repomix-config"
- _description: ClassVar[str] = "Configures repository ingestion patterns and file processing settings"
+ _block_type_name: ClassVar[str | None] = "Repomix Configuration"
+ _block_type_slug: ClassVar[str | None] = "repomix-config"
+ _description: ClassVar[str | None] = (
+ "Configures repository ingestion patterns and file processing settings"
+ )
include_patterns: list[str] = Field(
default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"]
@@ -10461,16 +11830,18 @@ class RepomixConfig(Block):
exclude_patterns: list[str] = Field(
default_factory=lambda: ["**/node_modules/**", "**/__pycache__/**", "**/.git/**"]
)
- max_file_size: int = Field(default=1_000_000) # 1MB
+ max_file_size: int = Field(default_factory=_default_max_file_size)  # default comes from Settings.max_file_size
respect_gitignore: bool = Field(default=True)
class R2RConfig(Block):
"""Configuration for R2R ingestion."""
- _block_type_name: ClassVar[str] = "R2R Configuration"
- _block_type_slug: ClassVar[str] = "r2r-config"
- _description: ClassVar[str] = "Configures R2R-specific ingestion settings including chunking and graph enrichment"
+ _block_type_name: ClassVar[str | None] = "R2R Configuration"
+ _block_type_slug: ClassVar[str | None] = "r2r-config"
+ _description: ClassVar[str | None] = (
+ "Configures R2R-specific ingestion settings including chunking and graph enrichment"
+ )
chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000
chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200
@@ -10480,6 +11851,7 @@ class R2RConfig(Block):
class DocumentMetadataRequired(TypedDict):
"""Required metadata fields for a document."""
+
source_url: str
timestamp: datetime
content_type: str
@@ -10543,7 +11915,7 @@ class Document(BaseModel):
vector: list[float] | None = Field(default=None)
score: float | None = Field(default=None)
source: IngestionSource
- collection: str = Field(default="documents")
+ collection: str = Field(default_factory=_default_collection_name)
class IngestionJob(BaseModel):
@@ -10572,734 +11944,6 @@ class IngestionResult(BaseModel):
error_messages: list[str] = Field(default_factory=list)
-
-"""Prefect flow for ingestion pipeline."""
-
-from __future__ import annotations
-
-from collections.abc import Callable
-from datetime import UTC, datetime
-from typing import TYPE_CHECKING, Literal, TypeAlias, assert_never, cast
-
-from prefect import flow, get_run_logger, task
-from prefect.blocks.core import Block
-from prefect.variables import Variable
-from pydantic import SecretStr
-
-from ..config.settings import Settings
-from ..core.exceptions import IngestionError
-from ..core.models import (
- Document,
- FirecrawlConfig,
- IngestionJob,
- IngestionResult,
- IngestionSource,
- IngestionStatus,
- RepomixConfig,
- StorageBackend,
- StorageConfig,
-)
-from ..ingestors import BaseIngestor, FirecrawlIngestor, FirecrawlPage, RepomixIngestor
-from ..storage import OpenWebUIStorage, WeaviateStorage
-from ..storage import R2RStorage as RuntimeR2RStorage
-from ..storage.base import BaseStorage
-from ..utils.metadata_tagger import MetadataTagger
-
-SourceTypeLiteral = Literal["web", "repository", "documentation"]
-StorageBackendLiteral = Literal["weaviate", "open_webui", "r2r"]
-SourceTypeLike: TypeAlias = IngestionSource | SourceTypeLiteral
-StorageBackendLike: TypeAlias = StorageBackend | StorageBackendLiteral
-
-
-def _safe_cache_key(prefix: str, params: dict[str, object], key: str) -> str:
- """Create a type-safe cache key from task parameters."""
- value = params.get(key, "")
- return f"{prefix}_{hash(str(value))}"
-
-
-if TYPE_CHECKING:
- from ..storage.r2r.storage import R2RStorage as R2RStorageType
-else:
- R2RStorageType = BaseStorage
-
-
-@task(name="validate_source", retries=2, retry_delay_seconds=10, tags=["validation"])
-async def validate_source_task(source_url: str, source_type: IngestionSource) -> bool:
- """
- Validate that a source is accessible.
-
- Args:
- source_url: URL or path to source
- source_type: Type of source
-
- Returns:
- True if valid
- """
- if source_type == IngestionSource.WEB:
- ingestor = FirecrawlIngestor()
- elif source_type == IngestionSource.REPOSITORY:
- ingestor = RepomixIngestor()
- else:
- raise ValueError(f"Unsupported source type: {source_type}")
-
- result = await ingestor.validate_source(source_url)
- return bool(result)
-
-
-@task(name="initialize_storage", retries=3, retry_delay_seconds=5, tags=["storage"])
-async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
- """
- Initialize storage backend.
-
- Args:
- config: Storage configuration block or block name
-
- Returns:
- Initialized storage adapter
- """
- # Load block if string provided
- if isinstance(config, str):
- # Use Block.aload with type slug for better type inference
- loaded_block = await Block.aload(f"storage-config/{config}")
- config = cast(StorageConfig, loaded_block)
-
- if config.backend == StorageBackend.WEAVIATE:
- storage = WeaviateStorage(config)
- elif config.backend == StorageBackend.OPEN_WEBUI:
- storage = OpenWebUIStorage(config)
- elif config.backend == StorageBackend.R2R:
- if RuntimeR2RStorage is None:
- raise ValueError("R2R storage not available. Check dependencies.")
- storage = RuntimeR2RStorage(config)
- else:
- raise ValueError(f"Unsupported backend: {config.backend}")
-
- await storage.initialize()
- return storage
-
-
-@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"],
- cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"))
-async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
- """Map a site using Firecrawl and return discovered URLs."""
- # Load block if string provided
- if isinstance(config, str):
- # Use Block.aload with type slug for better type inference
- loaded_block = await Block.aload(f"firecrawl-config/{config}")
- config = cast(FirecrawlConfig, loaded_block)
-
- ingestor = FirecrawlIngestor(config)
- mapped = await ingestor.map_site(source_url)
- return mapped or [source_url]
-
-
-@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"],
- cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list
-async def filter_existing_documents_task(
- urls: list[str],
- storage_client: BaseStorage,
- stale_after_days: int = 30,
- *,
- collection_name: str | None = None,
-) -> list[str]:
- """Filter URLs to only those that need scraping (missing or stale in storage)."""
- logger = get_run_logger()
- eligible: list[str] = []
-
- for url in urls:
- document_id = str(FirecrawlIngestor.compute_document_id(url))
- exists = await storage_client.check_exists(
- document_id,
- collection_name=collection_name,
- stale_after_days=stale_after_days
- )
-
- if not exists:
- eligible.append(url)
-
- skipped = len(urls) - len(eligible)
- if skipped > 0:
- logger.info("Skipping %s up-to-date documents in %s", skipped, storage_client.display_name)
-
- return eligible
-
-
-@task(
- name="scrape_firecrawl_batch", retries=2, retry_delay_seconds=20, tags=["firecrawl", "scrape"]
-)
-async def scrape_firecrawl_batch_task(
- batch_urls: list[str], config: FirecrawlConfig
-) -> list[FirecrawlPage]:
- """Scrape a batch of URLs via Firecrawl."""
- ingestor = FirecrawlIngestor(config)
- result: list[FirecrawlPage] = await ingestor.scrape_pages(batch_urls)
- return result
-
-
-@task(name="annotate_firecrawl_metadata", retries=1, retry_delay_seconds=10, tags=["metadata"])
-async def annotate_firecrawl_metadata_task(
- pages: list[FirecrawlPage], job: IngestionJob
-) -> list[Document]:
- """Annotate scraped pages with standardized metadata."""
- if not pages:
- return []
-
- ingestor = FirecrawlIngestor()
- documents = [ingestor.create_document(page, job) for page in pages]
-
- try:
- from ..config import get_settings
-
- settings = get_settings()
- async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
- tagged_documents: list[Document] = await tagger.tag_batch(documents)
- return tagged_documents
- except IngestionError as exc: # pragma: no cover - logging side effect
- logger = get_run_logger()
- logger.warning("Metadata tagging failed: %s", exc)
- return documents
- except Exception as exc: # pragma: no cover - defensive
- logger = get_run_logger()
- logger.warning("Metadata tagging unavailable, using base metadata: %s", exc)
- return documents
-
-
-@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"])
-async def upsert_r2r_documents_task(
- storage_client: R2RStorageType,
- documents: list[Document],
- collection_name: str | None,
-) -> tuple[int, int]:
- """Upsert documents into R2R storage."""
- if not documents:
- return 0, 0
-
- stored_ids: list[str] = await storage_client.store_batch(
- documents, collection_name=collection_name
- )
- processed = len(stored_ids)
- failed = len(documents) - processed
-
- if failed:
- logger = get_run_logger()
- logger.warning("Failed to upsert %s documents to R2R", failed)
-
- return processed, failed
-
-
-@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"])
-async def ingest_documents_task(
- job: IngestionJob,
- collection_name: str | None = None,
- batch_size: int | None = None,
- storage_client: BaseStorage | None = None,
- storage_block_name: str | None = None,
- ingestor_config_block_name: str | None = None,
- progress_callback: Callable[[int, str], None] | None = None,
-) -> tuple[int, int]:
- """
- Ingest documents from source with optional pre-initialized storage client.
-
- Args:
- job: Ingestion job configuration
- collection_name: Target collection name
- batch_size: Number of documents per batch (uses Variable if None)
- storage_client: Optional pre-initialized storage client
- storage_block_name: Optional storage block name to load
- ingestor_config_block_name: Optional ingestor config block name to load
- progress_callback: Optional callback for progress updates
-
- Returns:
- Tuple of (processed_count, failed_count)
- """
- if progress_callback:
- progress_callback(35, "Creating ingestor and storage clients...")
-
- # Use Variable for batch size if not provided
- if batch_size is None:
- try:
- batch_size_var = await Variable.aget("default_batch_size", default="50")
- # Convert Variable result to int, handling various types
- if isinstance(batch_size_var, int):
- batch_size = batch_size_var
- elif isinstance(batch_size_var, (str, float)):
- batch_size = int(float(str(batch_size_var)))
- else:
- batch_size = 50
- except Exception:
- batch_size = 50
-
- ingestor = await _create_ingestor(job, ingestor_config_block_name)
- storage = storage_client or await _create_storage(job, collection_name, storage_block_name)
-
- if progress_callback:
- progress_callback(40, "Starting document processing...")
-
- return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback)
-
-
-async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
- """Create appropriate ingestor based on job source type."""
- if job.source_type == IngestionSource.WEB:
- if config_block_name:
- # Use Block.aload with type slug for better type inference
- loaded_block = await Block.aload(f"firecrawl-config/{config_block_name}")
- config = cast(FirecrawlConfig, loaded_block)
- else:
- # Fallback to default configuration
- config = FirecrawlConfig()
- return FirecrawlIngestor(config)
- elif job.source_type == IngestionSource.REPOSITORY:
- if config_block_name:
- # Use Block.aload with type slug for better type inference
- loaded_block = await Block.aload(f"repomix-config/{config_block_name}")
- config = cast(RepomixConfig, loaded_block)
- else:
- # Fallback to default configuration
- config = RepomixConfig()
- return RepomixIngestor(config)
- else:
- raise ValueError(f"Unsupported source: {job.source_type}")
-
-
-async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage:
- """Create and initialize storage client."""
- if collection_name is None:
- # Use variable for default collection prefix
- prefix = await Variable.aget("default_collection_prefix", default="docs")
- collection_name = f"{prefix}_{job.source_type.value}"
-
- if storage_block_name:
- # Load storage config from block
- loaded_block = await Block.aload(f"storage-config/{storage_block_name}")
- storage_config = cast(StorageConfig, loaded_block)
- # Override collection name if provided
- storage_config.collection_name = collection_name
- else:
- # Fallback to building config from settings
- from ..config import get_settings
- settings = get_settings()
- storage_config = _build_storage_config(job, settings, collection_name)
-
- storage = _instantiate_storage(job.storage_backend, storage_config)
- await storage.initialize()
- return storage
-
-
-def _build_storage_config(
- job: IngestionJob, settings: Settings, collection_name: str
-) -> StorageConfig:
- """Build storage configuration from job and settings."""
- storage_endpoints = {
- StorageBackend.WEAVIATE: settings.weaviate_endpoint,
- StorageBackend.OPEN_WEBUI: settings.openwebui_endpoint,
- StorageBackend.R2R: settings.get_storage_endpoint("r2r"),
- }
- storage_api_keys: dict[StorageBackend, str | None] = {
- StorageBackend.WEAVIATE: settings.get_api_key("weaviate"),
- StorageBackend.OPEN_WEBUI: settings.get_api_key("openwebui"),
- StorageBackend.R2R: None, # R2R is self-hosted, no API key needed
- }
-
- api_key_raw: str | None = storage_api_keys[job.storage_backend]
- api_key: SecretStr | None = SecretStr(api_key_raw) if api_key_raw is not None else None
-
- return StorageConfig(
- backend=job.storage_backend,
- endpoint=storage_endpoints[job.storage_backend],
- api_key=api_key,
- collection_name=collection_name,
- )
-
-
-def _instantiate_storage(backend: StorageBackend, config: StorageConfig) -> BaseStorage:
- """Instantiate storage based on backend type."""
- if backend == StorageBackend.WEAVIATE:
- return WeaviateStorage(config)
- elif backend == StorageBackend.OPEN_WEBUI:
- return OpenWebUIStorage(config)
- elif backend == StorageBackend.R2R:
- if RuntimeR2RStorage is None:
- raise ValueError("R2R storage not available. Check dependencies.")
- return RuntimeR2RStorage(config)
-
- assert_never(backend)
-
-
-def _chunk_urls(urls: list[str], chunk_size: int) -> list[list[str]]:
- """Group URLs into fixed-size chunks for batch processing."""
-
- if chunk_size <= 0:
- raise ValueError("chunk_size must be greater than zero")
-
- return [urls[i : i + chunk_size] for i in range(0, len(urls), chunk_size)]
-
-
-def _deduplicate_urls(urls: list[str]) -> list[str]:
- """Return the URLs with order preserved and duplicates removed."""
-
- seen: set[str] = set()
- unique: list[str] = []
- for url in urls:
- if url not in seen:
- seen.add(url)
- unique.append(url)
- return unique
-
-
-async def _process_documents(
- ingestor: BaseIngestor,
- storage: BaseStorage,
- job: IngestionJob,
- batch_size: int,
- collection_name: str | None,
- progress_callback: Callable[[int, str], None] | None = None,
-) -> tuple[int, int]:
- """Process documents in batches."""
- processed = 0
- failed = 0
- batch: list[Document] = []
- total_documents = 0
- batch_count = 0
-
- if progress_callback:
- progress_callback(45, "Ingesting documents from source...")
-
- # Use smart ingestion with deduplication if storage supports it
- if hasattr(storage, 'check_exists'):
- try:
- # Try to use the smart ingestion method
- document_generator = ingestor.ingest_with_dedup(
- job, storage, collection_name=collection_name
- )
- except Exception:
- # Fall back to regular ingestion if smart method fails
- document_generator = ingestor.ingest(job)
- else:
- document_generator = ingestor.ingest(job)
-
- async for document in document_generator:
- batch.append(document)
- total_documents += 1
-
- if len(batch) >= batch_size:
- batch_count += 1
- if progress_callback:
- progress_callback(
- 45 + min(35, (batch_count * 10)),
- f"Processing batch {batch_count} ({total_documents} documents so far)..."
- )
-
- batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
- processed += batch_processed
- failed += batch_failed
- batch = []
-
- # Process remaining batch
- if batch:
- batch_count += 1
- if progress_callback:
- progress_callback(80, f"Processing final batch ({total_documents} total documents)...")
-
- batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
- processed += batch_processed
- failed += batch_failed
-
- if progress_callback:
- progress_callback(85, f"Completed processing {total_documents} documents")
-
- return processed, failed
-
-
-async def _store_batch(
- storage: BaseStorage,
- batch: list[Document],
- collection_name: str | None,
-) -> tuple[int, int]:
- """Store a batch of documents and return processed/failed counts."""
- try:
- # Apply metadata tagging for backends that benefit from it
- processed_batch = batch
- if hasattr(storage, "config") and storage.config.backend in (
- StorageBackend.R2R,
- StorageBackend.WEAVIATE,
- ):
- try:
- from ..config import get_settings
-
- settings = get_settings()
- async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger:
- processed_batch = await tagger.tag_batch(batch)
- except Exception as exc:
- print(f"Metadata tagging failed, using original documents: {exc}")
- processed_batch = batch
-
- stored_ids = await storage.store_batch(processed_batch, collection_name=collection_name)
- processed_count = len(stored_ids)
- failed_count = len(processed_batch) - processed_count
-
- batch_type = (
- "final" if len(processed_batch) < 50 else ""
- ) # Assume standard batch size is 50
- print(f"Successfully stored {processed_count} documents in {batch_type} batch".strip())
-
- return processed_count, failed_count
- except Exception as e:
- batch_type = "Final" if len(batch) < 50 else "Batch"
- print(f"{batch_type} storage failed: {e}")
- return 0, len(batch)
-
-
-@flow(
- name="firecrawl_to_r2r",
- description="Ingest Firecrawl pages into R2R with metadata annotation",
- persist_result=False,
- log_prints=True,
-)
-async def firecrawl_to_r2r_flow(
- job: IngestionJob, collection_name: str | None = None, progress_callback: Callable[[int, str], None] | None = None
-) -> tuple[int, int]:
- """Specialized flow for Firecrawl ingestion into R2R."""
- logger = get_run_logger()
- from ..config import get_settings
-
- if progress_callback:
- progress_callback(35, "Initializing Firecrawl and R2R storage...")
-
- settings = get_settings()
- firecrawl_config = FirecrawlConfig()
- resolved_collection = collection_name or f"docs_{job.source_type.value}"
-
- storage_config = _build_storage_config(job, settings, resolved_collection)
- storage_client = await initialize_storage_task(storage_config)
-
- if RuntimeR2RStorage is None or not isinstance(storage_client, RuntimeR2RStorage):
- raise IngestionError("Firecrawl to R2R flow requires an R2R storage backend")
-
- r2r_storage = cast("R2RStorageType", storage_client)
-
- if progress_callback:
- progress_callback(45, "Checking for existing content before mapping...")
-
- # Smart mapping: try single URL first to avoid expensive map operation
- base_url = str(job.source_url)
- single_url_id = str(FirecrawlIngestor.compute_document_id(base_url))
- base_exists = await r2r_storage.check_exists(
- single_url_id, collection_name=resolved_collection, stale_after_days=30
- )
-
- if base_exists:
- # Check if this is a recent single-page update
- logger.info("Base URL %s exists and is fresh, skipping expensive mapping", base_url)
- if progress_callback:
- progress_callback(100, "Content is up to date, no processing needed")
- return 0, 0
-
- if progress_callback:
- progress_callback(50, "Discovering pages with Firecrawl...")
-
- discovered_urls = await map_firecrawl_site_task(base_url, firecrawl_config)
- unique_urls = _deduplicate_urls(discovered_urls)
- logger.info("Discovered %s unique URLs from Firecrawl map", len(unique_urls))
-
- if progress_callback:
- progress_callback(60, f"Found {len(unique_urls)} pages, filtering existing content...")
-
- eligible_urls = await filter_existing_documents_task(
- unique_urls, r2r_storage, collection_name=resolved_collection
- )
-
- if not eligible_urls:
- logger.info("All Firecrawl pages are up to date for %s", job.source_url)
- if progress_callback:
- progress_callback(100, "All pages are up to date, no processing needed")
- return 0, 0
-
- if progress_callback:
- progress_callback(70, f"Scraping {len(eligible_urls)} new/updated pages...")
-
- batch_size = min(settings.default_batch_size, firecrawl_config.limit)
- url_batches = _chunk_urls(eligible_urls, batch_size)
- logger.info("Scraping %s batches of Firecrawl pages", len(url_batches))
-
- # Use asyncio.gather for concurrent scraping
- import asyncio
- scrape_tasks = [
- scrape_firecrawl_batch_task(batch, firecrawl_config)
- for batch in url_batches
- ]
- batch_results = await asyncio.gather(*scrape_tasks)
-
- scraped_pages: list[FirecrawlPage] = []
- for batch_pages in batch_results:
- scraped_pages.extend(batch_pages)
-
- if progress_callback:
- progress_callback(80, f"Processing {len(scraped_pages)} scraped pages...")
-
- documents = await annotate_firecrawl_metadata_task(scraped_pages, job)
-
- if not documents:
- logger.warning("No documents produced after scraping for %s", job.source_url)
- return 0, len(eligible_urls)
-
- if progress_callback:
- progress_callback(90, f"Storing {len(documents)} documents in R2R...")
-
- processed, failed = await upsert_r2r_documents_task(r2r_storage, documents, resolved_collection)
-
- logger.info("Upserted %s documents into R2R (%s failed)", processed, failed)
-
- return processed, failed
-
-
-@task(name="update_job_status", tags=["tracking"])
-async def update_job_status_task(
- job: IngestionJob,
- status: IngestionStatus,
- processed: int = 0,
- _failed: int = 0,
- error: str | None = None,
-) -> IngestionJob:
- """
- Update job status.
-
- Args:
- job: Ingestion job
- status: New status
- processed: Documents processed
- _failed: Documents failed (currently unused)
- error: Error message if any
-
- Returns:
- Updated job
- """
- job.status = status
- job.updated_at = datetime.now(UTC)
- job.document_count = processed
-
- if status == IngestionStatus.COMPLETED:
- job.completed_at = datetime.now(UTC)
-
- if error:
- job.error_message = error
-
- return job
-
-
-@flow(
- name="ingestion_pipeline",
- description="Main ingestion pipeline for documents",
- retries=1,
- retry_delay_seconds=60,
- persist_result=True,
- log_prints=True,
-)
-async def create_ingestion_flow(
- source_url: str,
- source_type: SourceTypeLike,
- storage_backend: StorageBackendLike = StorageBackend.WEAVIATE,
- collection_name: str | None = None,
- validate_first: bool = True,
- progress_callback: Callable[[int, str], None] | None = None,
-) -> IngestionResult:
- """
- Main ingestion flow.
-
- Args:
- source_url: URL or path to source
- source_type: Type of source
- storage_backend: Storage backend to use
- validate_first: Whether to validate source first
- progress_callback: Optional callback for progress updates
-
- Returns:
- Ingestion result
- """
- print(f"Starting ingestion from {source_url}")
-
- source_enum = IngestionSource(source_type)
- backend_enum = StorageBackend(storage_backend)
-
- # Create job
- job = IngestionJob(
- source_url=source_url,
- source_type=source_enum,
- storage_backend=backend_enum,
- status=IngestionStatus.PENDING,
- )
-
- start_time = datetime.now(UTC)
- error_messages: list[str] = []
- processed = 0
- failed = 0
-
- try:
- # Validate source if requested
- if validate_first:
- if progress_callback:
- progress_callback(10, "Validating source...")
- print("Validating source...")
- is_valid = await validate_source_task(source_url, job.source_type)
-
- if not is_valid:
- raise IngestionError(f"Source validation failed: {source_url}")
-
- # Update status to in progress
- if progress_callback:
- progress_callback(20, "Initializing storage...")
- job = await update_job_status_task(job, IngestionStatus.IN_PROGRESS)
-
- # Run ingestion
- if progress_callback:
- progress_callback(30, "Starting document ingestion...")
- print("Ingesting documents...")
- if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R:
- processed, failed = await firecrawl_to_r2r_flow(job, collection_name, progress_callback=progress_callback)
- else:
- processed, failed = await ingest_documents_task(job, collection_name, progress_callback=progress_callback)
-
- if progress_callback:
- progress_callback(90, "Finalizing ingestion...")
-
- # Update final status
- if failed > 0:
- error_messages.append(f"{failed} documents failed to process")
-
- # Set status based on results
- if processed == 0 and failed > 0:
- final_status = IngestionStatus.FAILED
- elif failed > 0:
- final_status = IngestionStatus.PARTIAL
- else:
- final_status = IngestionStatus.COMPLETED
-
- job = await update_job_status_task(job, final_status, processed=processed, _failed=failed)
-
- print(f"Ingestion completed: {processed} processed, {failed} failed")
-
- except Exception as e:
- print(f"Ingestion failed: {e}")
- error_messages.append(str(e))
-
- # Don't reset counts - keep whatever was processed before the error
- job = await update_job_status_task(
- job, IngestionStatus.FAILED, processed=processed, _failed=failed, error=str(e)
- )
-
- # Calculate duration
- duration = (datetime.now(UTC) - start_time).total_seconds()
-
- return IngestionResult(
- job_id=job.id,
- status=job.status,
- documents_processed=processed,
- documents_failed=failed,
- duration_seconds=duration,
- error_messages=error_messages,
- )
-
-
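# Illustrative sketch: invoking create_ingestion_flow as shown above. Assumes the
# flow and its enums are importable from the flows module and that IngestionSource
# accepts the value "web"; the URL, collection name, and callback are placeholders.
import asyncio

async def _run_example_ingestion() -> None:
    result = await create_ingestion_flow(
        source_url="https://docs.example.com",
        source_type="web",
        storage_backend=StorageBackend.WEAVIATE,
        collection_name="docs_web",
        validate_first=True,
        progress_callback=lambda pct, msg: print(f"[{pct:3d}%] {msg}"),
    )
    print(result.status, result.documents_processed, result.documents_failed)

# asyncio.run(_run_example_ingestion())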
"""R2R storage implementation using the official R2R SDK."""
@@ -11307,21 +11951,23 @@ from __future__ import annotations
import asyncio
import contextlib
+import logging
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from datetime import UTC, datetime
-from typing import Self, TypeVar, cast
+from typing import Final, Self, TypeVar, cast
from uuid import UUID, uuid4
# Direct imports for runtime and type checking
-# Note: Some type checkers (basedpyright/Pyrefly) may report import issues
-# but these work correctly at runtime and with mypy
-from httpx import AsyncClient, HTTPStatusError
-from r2r import R2RAsyncClient, R2RException
+from httpx import AsyncClient, HTTPStatusError # type: ignore
+from r2r import R2RAsyncClient, R2RException # type: ignore
from typing_extensions import override
from ...core.exceptions import StorageError
from ...core.models import Document, DocumentMetadata, IngestionSource, StorageConfig
from ..base import BaseStorage
+from ..types import DocumentInfo
+
+LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
T = TypeVar("T")
@@ -11383,6 +12029,24 @@ class R2RStorage(BaseStorage):
self.client: R2RAsyncClient = R2RAsyncClient(self.endpoint)
self.default_collection_id: str | None = None
+ def _get_http_client_headers(self) -> dict[str, str]:
+ """Get consistent HTTP headers for direct API calls."""
+ headers = {"Content-Type": "application/json"}
+
+ # Add authentication headers if available
+ # Note: the R2R SDK may manage auth internally, so reuse its headers when available

+ if hasattr(self.client, "_get_headers"):
+ with contextlib.suppress(Exception):
+ sdk_headers = self.client._get_headers() # type: ignore[attr-defined]
+ if isinstance(sdk_headers, dict):
+ headers |= sdk_headers
+ return headers
+
+ def _create_http_client(self) -> AsyncClient:
+ """Create a properly configured HTTP client for direct API calls."""
+ headers = self._get_http_client_headers()
+ return AsyncClient(headers=headers, timeout=30.0)
+
@override
async def initialize(self) -> None:
"""Initialize R2R connection and ensure default collection exists."""
@@ -11399,7 +12063,7 @@ class R2RStorage(BaseStorage):
# Test connection using direct HTTP call to v3 API
endpoint = self.endpoint
- client = AsyncClient()
+ client = self._create_http_client()
try:
response = await client.get(f"{endpoint}/v3/collections")
response.raise_for_status()
@@ -11412,7 +12076,7 @@ class R2RStorage(BaseStorage):
async def _ensure_collection(self, collection_name: str) -> str:
"""Get or create collection by name."""
endpoint = self.endpoint
- client = AsyncClient()
+ client = self._create_http_client()
try:
# List collections and find by name
response = await client.get(f"{endpoint}/v3/collections")
@@ -11455,6 +12119,9 @@ class R2RStorage(BaseStorage):
finally:
await client.aclose()
+ # This should never be reached, but it satisfies the static analyzer
+ raise StorageError(f"Unexpected code path in _ensure_collection for '{collection_name}'")
+
@override
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
"""Store a single document."""
@@ -11464,20 +12131,46 @@ class R2RStorage(BaseStorage):
async def store_batch(
self, documents: list[Document], *, collection_name: str | None = None
) -> list[str]:
- """Store multiple documents."""
+ """Store multiple documents efficiently with connection reuse."""
collection_id = await self._resolve_collection_id(collection_name)
- print(
- f"Using collection ID: {collection_id} for collection: {collection_name or self.config.collection_name}"
+ LOGGER.info(
+ "Using collection ID: %s for collection: %s",
+ collection_id,
+ collection_name or self.config.collection_name,
)
- stored_ids: list[str] = []
- for document in documents:
- if not self._is_document_valid(document):
- continue
+ # Filter valid documents upfront
+ valid_documents = [doc for doc in documents if self._is_document_valid(doc)]
+ if not valid_documents:
+ return []
- stored_id = await self._store_single_document(document, collection_id)
- if stored_id:
- stored_ids.append(stored_id)
+ stored_ids: list[str] = []
+
+ # Use a single HTTP client for all requests
+ http_client = AsyncClient()
+ async with http_client: # type: ignore
+ # Process documents with controlled concurrency
+ semaphore = asyncio.Semaphore(5) # Limit concurrent uploads
+
+ async def store_single_with_client(document: Document) -> str | None:
+ async with semaphore:
+ return await self._store_single_document_with_client(
+ document, collection_id, http_client
+ )
+
+ # Execute all uploads concurrently
+ results = await asyncio.gather(
+ *[store_single_with_client(doc) for doc in valid_documents], return_exceptions=True
+ )
+
+ # Collect successful IDs
+ for result in results:
+ if isinstance(result, str):
+ stored_ids.append(result)
+ elif isinstance(result, Exception):
+ LOGGER.error("Document upload failed: %s", result)
return stored_ids
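# Illustrative sketch of the bounded-concurrency pattern used by store_batch above:
# a Semaphore caps in-flight work and gather(return_exceptions=True) keeps one failed
# upload from cancelling the rest. Names here are hypothetical.
import asyncio

async def _upload_all(items: list[str], limit: int = 5) -> list[str]:
    semaphore = asyncio.Semaphore(limit)

    async def _upload_one(item: str) -> str:
        async with semaphore:            # at most `limit` uploads run concurrently
            await asyncio.sleep(0.01)    # stand-in for the HTTP POST
            return item

    results = await asyncio.gather(
        *(_upload_one(item) for item in items), return_exceptions=True
    )
    return [r for r in results if isinstance(r, str)]  # keep successes, drop failures

# asyncio.run(_upload_all([f"doc-{i}" for i in range(20)]))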
@@ -11498,12 +12191,14 @@ class R2RStorage(BaseStorage):
requested_id = str(document.id)
if not document.content or not document.content.strip():
- print(f"Skipping document {requested_id}: empty content")
+ LOGGER.warning("Skipping document %s: empty content", requested_id)
return False
if len(document.content) > 1_000_000: # 1MB limit
- print(
- f"Skipping document {requested_id}: content too large ({len(document.content)} chars)"
+ LOGGER.warning(
+ "Skipping document %s: content too large (%d chars)",
+ requested_id,
+ len(document.content),
)
return False
@@ -11511,23 +12206,41 @@ class R2RStorage(BaseStorage):
async def _store_single_document(self, document: Document, collection_id: str) -> str | None:
"""Store a single document with retry logic."""
+ http_client = AsyncClient()
+ async with http_client: # type: ignore
+ return await self._store_single_document_with_client(
+ document, collection_id, http_client
+ )
+
+ async def _store_single_document_with_client(
+ self, document: Document, collection_id: str, http_client: AsyncClient
+ ) -> str | None:
+ """Store a single document with retry logic using provided HTTP client."""
requested_id = str(document.id)
- print(f"Creating document with ID: {requested_id}")
+ LOGGER.debug("Creating document with ID: %s", requested_id)
max_retries = 3
retry_delay = 1.0
for attempt in range(max_retries):
try:
- doc_response = await self._attempt_document_creation(document, collection_id)
+ doc_response = await self._attempt_document_creation_with_client(
+ document, collection_id, http_client
+ )
if doc_response:
- return self._process_document_response(doc_response, requested_id, collection_id)
+ return self._process_document_response(
+ doc_response, requested_id, collection_id
+ )
except (TimeoutError, OSError) as e:
- if not await self._should_retry_timeout(e, attempt, max_retries, requested_id, retry_delay):
+ if not await self._should_retry_timeout(
+ e, attempt, max_retries, requested_id, retry_delay
+ ):
break
retry_delay *= 2
except HTTPStatusError as e:
- if not await self._should_retry_http_error(e, attempt, max_retries, requested_id, retry_delay):
+ if not await self._should_retry_http_error(
+ e, attempt, max_retries, requested_id, retry_delay
+ ):
break
retry_delay *= 2
except Exception as exc:
@@ -11536,13 +12249,25 @@ class R2RStorage(BaseStorage):
return None
- async def _attempt_document_creation(self, document: Document, collection_id: str) -> dict[str, object] | None:
+ async def _attempt_document_creation(
+ self, document: Document, collection_id: str
+ ) -> dict[str, object] | None:
"""Attempt to create a document via HTTP API."""
+ http_client = AsyncClient()
+ async with http_client: # type: ignore
+ return await self._attempt_document_creation_with_client(
+ document, collection_id, http_client
+ )
+
+ async def _attempt_document_creation_with_client(
+ self, document: Document, collection_id: str, http_client: AsyncClient
+ ) -> dict[str, object] | None:
+ """Attempt to create a document via HTTP API using provided client."""
import json
requested_id = str(document.id)
metadata = self._build_metadata(document)
- print(f"Built metadata for document {requested_id}: {metadata}")
+ LOGGER.debug("Built metadata for document %s: %s", requested_id, metadata)
files = {
"raw_text": (None, document.content),
@@ -11553,86 +12278,121 @@ class R2RStorage(BaseStorage):
if collection_id:
files["collection_ids"] = (None, json.dumps([collection_id]))
- print(f"Creating document {requested_id} with collection_ids: [{collection_id}]")
+ LOGGER.debug(
+ "Creating document %s with collection_ids: [%s]", requested_id, collection_id
+ )
- print(f"Sending to R2R - files keys: {list(files.keys())}")
- print(f"Metadata JSON: {files['metadata'][1]}")
+ LOGGER.debug("Sending to R2R - files keys: %s", list(files.keys()))
+ LOGGER.debug("Metadata JSON: %s", files["metadata"][1])
- async with AsyncClient() as http_client:
- response = await http_client.post(f"{self.endpoint}/v3/documents", files=files)
+ response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) # type: ignore[call-arg]
- if response.status_code == 422:
- self._handle_validation_error(response, requested_id, metadata)
- return None
+ if response.status_code == 422:
+ self._handle_validation_error(response, requested_id, metadata)
+ return None
- response.raise_for_status()
- return response.json()
+ response.raise_for_status()
+ return response.json()
- def _handle_validation_error(self, response: object, requested_id: str, metadata: dict[str, object]) -> None:
+ def _handle_validation_error(
+ self, response: object, requested_id: str, metadata: dict[str, object]
+ ) -> None:
"""Handle validation errors from R2R API."""
try:
- error_detail = getattr(response, 'json', lambda: {})() if hasattr(response, 'json') else {}
- print(f"R2R validation error for document {requested_id}: {error_detail}")
- print(f"Document metadata sent: {metadata}")
- print(f"Response status: {getattr(response, 'status_code', 'unknown')}")
- print(f"Response headers: {dict(getattr(response, 'headers', {}))}")
+ error_detail = (
+ getattr(response, "json", lambda: {})() if hasattr(response, "json") else {}
+ )
+ LOGGER.error("R2R validation error for document %s: %s", requested_id, error_detail)
+ LOGGER.error("Document metadata sent: %s", metadata)
+ LOGGER.error("Response status: %s", getattr(response, "status_code", "unknown"))
+ LOGGER.error("Response headers: %s", dict(getattr(response, "headers", {})))
except Exception:
- print(f"R2R validation error for document {requested_id}: {getattr(response, 'text', 'unknown error')}")
- print(f"Document metadata sent: {metadata}")
+ LOGGER.error(
+ "R2R validation error for document %s: %s",
+ requested_id,
+ getattr(response, "text", "unknown error"),
+ )
+ LOGGER.error("Document metadata sent: %s", metadata)
- def _process_document_response(self, doc_response: dict[str, object], requested_id: str, collection_id: str) -> str:
+ def _process_document_response(
+ self, doc_response: dict[str, object], requested_id: str, collection_id: str
+ ) -> str:
"""Process successful document creation response."""
response_payload = doc_response.get("results", doc_response)
doc_id = _extract_id(response_payload, requested_id)
- print(f"R2R returned document ID: {doc_id}")
+ LOGGER.info("R2R returned document ID: %s", doc_id)
if doc_id != requested_id:
- print(f"Warning: Requested ID {requested_id} but got {doc_id}")
+ LOGGER.warning("Requested ID %s but got %s", requested_id, doc_id)
if collection_id:
- print(f"Document {doc_id} should be assigned to collection {collection_id} via creation API")
+ LOGGER.info(
+ "Document %s should be assigned to collection %s via creation API",
+ doc_id,
+ collection_id,
+ )
return doc_id
- async def _should_retry_timeout(self, error: Exception, attempt: int, max_retries: int, requested_id: str, retry_delay: float) -> bool:
+ async def _should_retry_timeout(
+ self,
+ error: Exception,
+ attempt: int,
+ max_retries: int,
+ requested_id: str,
+ retry_delay: float,
+ ) -> bool:
"""Determine if timeout error should be retried."""
if attempt >= max_retries - 1:
return False
- print(f"Timeout for document {requested_id}, retrying in {retry_delay}s...")
+ LOGGER.warning("Timeout for document %s, retrying in %ss...", requested_id, retry_delay)
await asyncio.sleep(retry_delay)
return True
- async def _should_retry_http_error(self, error: HTTPStatusError, attempt: int, max_retries: int, requested_id: str, retry_delay: float) -> bool:
+ async def _should_retry_http_error(
+ self,
+ error: HTTPStatusError,
+ attempt: int,
+ max_retries: int,
+ requested_id: str,
+ retry_delay: float,
+ ) -> bool:
"""Determine if HTTP error should be retried."""
- if error.response.status_code < 500 or attempt >= max_retries - 1:
+ status_code = error.response.status_code
+ if status_code < 500 or attempt >= max_retries - 1:
return False
- print(f"Server error {error.response.status_code} for document {requested_id}, retrying in {retry_delay}s...")
+ LOGGER.warning(
+ "Server error %s for document %s, retrying in %ss...",
+ status_code,
+ requested_id,
+ retry_delay,
+ )
await asyncio.sleep(retry_delay)
return True
def _log_document_error(self, document_id: object, exc: Exception) -> None:
"""Log document storage errors with specific categorization."""
- print(f"Failed to store document {document_id}: {exc}")
+ LOGGER.error("Failed to store document %s: %s", document_id, exc)
exc_str = str(exc)
if "422" in exc_str:
- print(" → Data validation issue - check document content and metadata format")
+ LOGGER.error(" → Data validation issue - check document content and metadata format")
elif "timeout" in exc_str.lower():
- print(" → Network timeout - R2R may be overloaded")
+ LOGGER.error(" → Network timeout - R2R may be overloaded")
elif "500" in exc_str:
- print(" → Server error - R2R internal issue")
+ LOGGER.error(" → Server error - R2R internal issue")
else:
import traceback
+
traceback.print_exc()
def _build_metadata(self, document: Document) -> dict[str, object]:
"""Convert document metadata to enriched R2R format."""
metadata = document.metadata
-
# Core required fields
result: dict[str, object] = {
"source_url": metadata["source_url"],
@@ -11768,7 +12528,9 @@ class R2RStorage(BaseStorage):
except ValueError:
return uuid4()
- def _build_core_metadata(self, metadata_map: dict[str, object], timestamp: datetime) -> DocumentMetadata:
+ def _build_core_metadata(
+ self, metadata_map: dict[str, object], timestamp: datetime
+ ) -> DocumentMetadata:
"""Build core required metadata fields."""
return {
"source_url": str(metadata_map.get("source_url", "")),
@@ -11778,7 +12540,12 @@ class R2RStorage(BaseStorage):
"char_count": _as_int(metadata_map.get("char_count")),
}
- def _add_optional_metadata_fields(self, metadata: DocumentMetadata, doc_map: dict[str, object], metadata_map: dict[str, object]) -> None:
+ def _add_optional_metadata_fields(
+ self,
+ metadata: DocumentMetadata,
+ doc_map: dict[str, object],
+ metadata_map: dict[str, object],
+ ) -> None:
"""Add optional metadata fields if present."""
self._add_title_and_description(metadata, doc_map, metadata_map)
self._add_content_categorization(metadata, metadata_map)
@@ -11787,7 +12554,12 @@ class R2RStorage(BaseStorage):
self._add_processing_fields(metadata, metadata_map)
self._add_quality_scores(metadata, metadata_map)
- def _add_title_and_description(self, metadata: DocumentMetadata, doc_map: dict[str, object], metadata_map: dict[str, object]) -> None:
+ def _add_title_and_description(
+ self,
+ metadata: DocumentMetadata,
+ doc_map: dict[str, object],
+ metadata_map: dict[str, object],
+ ) -> None:
"""Add title and description fields."""
if title := (doc_map.get("title") or metadata_map.get("title")):
metadata["title"] = cast(str | None, title)
@@ -11797,7 +12569,9 @@ class R2RStorage(BaseStorage):
elif description := metadata_map.get("description"):
metadata["description"] = cast(str | None, description)
- def _add_content_categorization(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None:
+ def _add_content_categorization(
+ self, metadata: DocumentMetadata, metadata_map: dict[str, object]
+ ) -> None:
"""Add content categorization fields."""
if tags := metadata_map.get("tags"):
metadata["tags"] = [str(tag) for tag in tags] if isinstance(tags, list) else []
@@ -11808,7 +12582,9 @@ class R2RStorage(BaseStorage):
if language := metadata_map.get("language"):
metadata["language"] = str(language)
- def _add_authorship_fields(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None:
+ def _add_authorship_fields(
+ self, metadata: DocumentMetadata, metadata_map: dict[str, object]
+ ) -> None:
"""Add authorship and source information fields."""
if author := metadata_map.get("author"):
metadata["author"] = str(author)
@@ -11817,7 +12593,9 @@ class R2RStorage(BaseStorage):
if site_name := metadata_map.get("site_name"):
metadata["site_name"] = str(site_name)
- def _add_structure_fields(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None:
+ def _add_structure_fields(
+ self, metadata: DocumentMetadata, metadata_map: dict[str, object]
+ ) -> None:
"""Add document structure fields."""
if heading_hierarchy := metadata_map.get("heading_hierarchy"):
metadata["heading_hierarchy"] = (
@@ -11832,7 +12610,9 @@ class R2RStorage(BaseStorage):
if has_links := metadata_map.get("has_links"):
metadata["has_links"] = bool(has_links)
- def _add_processing_fields(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None:
+ def _add_processing_fields(
+ self, metadata: DocumentMetadata, metadata_map: dict[str, object]
+ ) -> None:
"""Add processing-related metadata fields."""
if extraction_method := metadata_map.get("extraction_method"):
metadata["extraction_method"] = str(extraction_method)
@@ -11841,7 +12621,9 @@ class R2RStorage(BaseStorage):
if last_modified := metadata_map.get("last_modified"):
metadata["last_modified"] = _as_datetime(last_modified)
- def _add_quality_scores(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None:
+ def _add_quality_scores(
+ self, metadata: DocumentMetadata, metadata_map: dict[str, object]
+ ) -> None:
"""Add quality score fields with safe float conversion."""
if readability_score := metadata_map.get("readability_score"):
try:
@@ -11944,7 +12726,7 @@ class R2RStorage(BaseStorage):
async def count(self, *, collection_name: str | None = None) -> int:
"""Get document count in collection."""
endpoint = self.endpoint
- client = AsyncClient()
+ client = self._create_http_client()
try:
# Get collections and find the count for the specific collection
response = await client.get(f"{endpoint}/v3/collections")
@@ -11965,6 +12747,9 @@ class R2RStorage(BaseStorage):
finally:
await client.aclose()
+ # This should never be reached, but it satisfies the static analyzer
+ return 0
+
@override
async def close(self) -> None:
"""Close R2R client."""
@@ -12012,7 +12797,7 @@ class R2RStorage(BaseStorage):
async def list_collections(self) -> list[str]:
"""List all available collections."""
endpoint = self.endpoint
- client = AsyncClient()
+ client = self._create_http_client()
try:
response = await client.get(f"{endpoint}/v3/collections")
response.raise_for_status()
@@ -12029,6 +12814,9 @@ class R2RStorage(BaseStorage):
finally:
await client.aclose()
+ # This should never be reached, but it satisfies the static analyzer
+ return []
+
async def list_collections_detailed(self) -> list[dict[str, object]]:
"""List all available collections with detailed information."""
try:
@@ -12092,7 +12880,7 @@ class R2RStorage(BaseStorage):
offset: int = 0,
*,
collection_name: str | None = None,
- ) -> list[dict[str, object]]:
+ ) -> list[DocumentInfo]:
"""
List documents in R2R with pagination.
@@ -12105,14 +12893,14 @@ class R2RStorage(BaseStorage):
List of document dictionaries with metadata
"""
try:
- documents: list[dict[str, object]] = []
+ documents: list[DocumentInfo] = []
if collection_name:
# Get collection ID first
collection_id = await self._ensure_collection(collection_name)
# Use the collections API to list documents in a specific collection
endpoint = self.endpoint
- client = AsyncClient()
+ client = self._create_http_client()
try:
params = {"offset": offset, "limit": limit}
response = await client.get(
@@ -12145,20 +12933,19 @@ class R2RStorage(BaseStorage):
title = str(doc_map.get("title", "Untitled"))
metadata = _as_mapping(doc_map.get("metadata", {}))
- documents.append(
- {
- "id": doc_id,
- "title": title,
- "source_url": str(metadata.get("source_url", "")),
- "description": str(metadata.get("description", "")),
- "content_type": str(metadata.get("content_type", "text/plain")),
- "content_preview": str(doc_map.get("content", ""))[:200] + "..."
- if doc_map.get("content")
- else "",
- "word_count": _as_int(metadata.get("word_count", 0)),
- "timestamp": str(doc_map.get("created_at", "")),
- }
- )
+ document_info: DocumentInfo = {
+ "id": doc_id,
+ "title": title,
+ "source_url": str(metadata.get("source_url", "")),
+ "description": str(metadata.get("description", "")),
+ "content_type": str(metadata.get("content_type", "text/plain")),
+ "content_preview": str(doc_map.get("content", ""))[:200] + "..."
+ if doc_map.get("content")
+ else "",
+ "word_count": _as_int(metadata.get("word_count", 0)),
+ "timestamp": str(doc_map.get("created_at", "")),
+ }
+ documents.append(document_info)
return documents
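# Illustrative sketch of the typed rows returned by list_documents above, assuming
# DocumentInfo is a TypedDict with exactly the keys populated in the loop; the
# values below are placeholders.
sample_row: DocumentInfo = {
    "id": "doc-0001",
    "title": "Getting Started",
    "source_url": "https://docs.example.com/start",
    "description": "Introductory page",
    "content_type": "text/plain",
    "content_preview": "Welcome to the documentation...",
    "word_count": 1250,
    "timestamp": "2024-01-01T00:00:00Z",
}
print(f"{sample_row['title']} ({sample_row['word_count']} words) - {sample_row['source_url']}")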
@@ -12169,10 +12956,191 @@ class R2RStorage(BaseStorage):
"""Base storage interface."""
+import asyncio
+import logging
+import random
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
+from types import TracebackType
+from typing import Final
+import httpx
+from pydantic import SecretStr
+
+from ..core.exceptions import StorageError
from ..core.models import Document, StorageConfig
+from .types import CollectionSummary, DocumentInfo
+
+LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
+
+
+class TypedHttpClient:
+ """
+ A properly typed HTTP client wrapper for HTTPX.
+
+ Provides consistent exception handling and type annotations
+ for storage adapters that use HTTP APIs.
+
+ Note: Some type checkers (Pylance) may report warnings about HTTPX types
+ due to library compatibility issues. The code functions correctly at runtime.
+ """
+
+ client: httpx.AsyncClient
+ _base_url: str
+
+ def __init__(
+ self,
+ base_url: str,
+ *,
+ api_key: SecretStr | None = None,
+ timeout: float = 30.0,
+ headers: dict[str, str] | None = None,
+ max_connections: int = 100,
+ max_keepalive_connections: int = 20,
+ ):
+ """
+ Initialize the typed HTTP client.
+
+ Args:
+ base_url: Base URL for all requests
+ api_key: Optional API key for authentication
+ timeout: Request timeout in seconds
+ headers: Additional headers to include with requests
+ max_connections: Maximum total connections in pool
+ max_keepalive_connections: Maximum keepalive connections
+ """
+ self._base_url = base_url
+
+ # Build headers with optional authentication
+ client_headers: dict[str, str] = headers or {}
+ if api_key:
+ client_headers["Authorization"] = f"Bearer {api_key.get_secret_value()}"
+
+ # Create typed client configuration with connection pooling
+ limits = httpx.Limits(
+ max_connections=max_connections, max_keepalive_connections=max_keepalive_connections
+ )
+ timeout_config = httpx.Timeout(connect=5.0, read=timeout, write=30.0, pool=10.0)
+ self.client = httpx.AsyncClient(
+ base_url=base_url, headers=client_headers, timeout=timeout_config, limits=limits
+ )
+
+ async def request(
+ self,
+ method: str,
+ path: str,
+ *,
+ allow_404: bool = False,
+ json: dict[str, object] | None = None,
+ data: dict[str, object] | None = None,
+ files: dict[str, tuple[str, bytes, str]] | None = None,
+ params: dict[str, str | bool] | None = None,
+ max_retries: int = 3,
+ retry_delay: float = 1.0,
+ ) -> httpx.Response | None:
+ """
+ Perform an HTTP request with consistent error handling and retries.
+
+ Args:
+ method: HTTP method (GET, POST, DELETE, etc.)
+ path: URL path relative to base_url
+ allow_404: If True, return None for 404 responses instead of raising
+ json: JSON data to send
+ data: Form data to send
+ files: Files to upload
+ params: Query parameters
+ max_retries: Maximum number of retry attempts
+ retry_delay: Base delay between retries in seconds
+
+ Returns:
+ HTTP response object, or None if allow_404=True and status is 404
+
+ Raises:
+ StorageError: If request fails after retries
+ """
+ last_exception: Exception | None = None
+
+ for attempt in range(max_retries + 1):
+ try:
+ response = await self.client.request(
+ method, path, json=json, data=data, files=files, params=params
+ )
+ response.raise_for_status()
+ return response
+ except httpx.HTTPStatusError as e:
+ # Handle 404 as special case if requested
+ if allow_404 and e.response.status_code == 404:
+ LOGGER.debug("Resource not found (404): %s %s", method, path)
+ return None
+
+ # Don't retry client errors (4xx), except 408 and 429 which are retryable
+ if 400 <= e.response.status_code < 500 and e.response.status_code not in [429, 408]:
+ raise StorageError(
+ f"HTTP {e.response.status_code} error from {self._base_url}: {e}"
+ ) from e
+
+ last_exception = e
+ if attempt < max_retries:
+ # Exponential backoff with jitter for retryable errors
+ delay = retry_delay * (2**attempt) + random.uniform(0, 1)
+ LOGGER.warning(
+ "HTTP %d error on attempt %d/%d, retrying in %.2fs: %s",
+ e.response.status_code,
+ attempt + 1,
+ max_retries + 1,
+ delay,
+ e,
+ )
+ await asyncio.sleep(delay)
+
+ except httpx.HTTPError as e:
+ last_exception = e
+ if attempt < max_retries:
+ # Retry transport errors with backoff
+ delay = retry_delay * (2**attempt) + random.uniform(0, 1)
+ LOGGER.warning(
+ "HTTP transport error on attempt %d/%d, retrying in %.2fs: %s",
+ attempt + 1,
+ max_retries + 1,
+ delay,
+ e,
+ )
+ await asyncio.sleep(delay)
+
+ # All retries exhausted - last_exception should always be set if we reach here
+ if last_exception is None:
+ raise StorageError(
+ f"Request to {self._base_url} failed after {max_retries + 1} attempts with unknown error"
+ )
+
+ if isinstance(last_exception, httpx.HTTPStatusError):
+ raise StorageError(
+ f"HTTP {last_exception.response.status_code} error from {self._base_url} after {max_retries + 1} attempts: {last_exception}"
+ ) from last_exception
+ else:
+ raise StorageError(
+ f"HTTP transport error to {self._base_url} after {max_retries + 1} attempts: {last_exception}"
+ ) from last_exception
+
+ async def close(self) -> None:
+ """Close the HTTP client and cleanup resources."""
+ try:
+ await self.client.aclose()
+ except Exception as e:
+ LOGGER.warning("Error closing HTTP client: %s", e)
+
+ async def __aenter__(self) -> "TypedHttpClient":
+ """Async context manager entry."""
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ """Async context manager exit."""
+ await self.close()
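# Illustrative sketch of driving TypedHttpClient from a storage adapter. The base
# URL and path are placeholders; allow_404 turns a 404 into None instead of an error.
import asyncio

async def _fetch_collection_names() -> list[str]:
    async with TypedHttpClient("http://localhost:7272", timeout=15.0) as http:
        response = await http.request("GET", "/v3/collections", allow_404=True)
        if response is None:                     # endpoint missing or empty deployment
            return []
        payload = response.json()
        results = payload.get("results", []) if isinstance(payload, dict) else []
        return [str(item.get("name", "")) for item in results if isinstance(item, dict)]

# asyncio.run(_fetch_collection_names())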
class BaseStorage(ABC):
@@ -12266,11 +13234,30 @@ class BaseStorage(ABC):
# Check staleness if timestamp is available
if "timestamp" in document.metadata:
from datetime import UTC, datetime, timedelta
+
timestamp_obj = document.metadata["timestamp"]
+
+ # Handle both datetime objects and ISO strings
if isinstance(timestamp_obj, datetime):
timestamp = timestamp_obj
- cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
- return timestamp >= cutoff
+ # Ensure timezone awareness
+ if timestamp.tzinfo is None:
+ timestamp = timestamp.replace(tzinfo=UTC)
+ elif isinstance(timestamp_obj, str):
+ try:
+ timestamp = datetime.fromisoformat(timestamp_obj)
+ # Ensure timezone awareness
+ if timestamp.tzinfo is None:
+ timestamp = timestamp.replace(tzinfo=UTC)
+ except ValueError:
+ # If parsing fails, assume document is stale
+ return False
+ else:
+ # Unknown timestamp format, assume stale
+ return False
+
+ cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
+ return timestamp >= cutoff
# If no timestamp, assume it exists and is valid
return True
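# Illustrative sketch isolating the staleness rule applied by check_exists above:
# accept datetime objects or ISO-8601 strings, force timezone awareness, and treat
# anything unparseable as stale. The helper name is hypothetical.
from datetime import UTC, datetime, timedelta

def _is_fresh(timestamp: datetime | str, stale_after_days: int = 30) -> bool:
    if isinstance(timestamp, str):
        try:
            timestamp = datetime.fromisoformat(timestamp)
        except ValueError:
            return False                         # unparseable -> stale
    if timestamp.tzinfo is None:
        timestamp = timestamp.replace(tzinfo=UTC)
    cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
    return timestamp >= cutoff

# _is_fresh("2024-01-01T00:00:00+00:00") -> False once the entry is older than 30 days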
@@ -12333,12 +13320,12 @@ class BaseStorage(ABC):
"""
return []
- async def describe_collections(self) -> list[dict[str, object]]:
+ async def describe_collections(self) -> list[CollectionSummary]:
"""
Describe available collections with metadata (if supported by backend).
Returns:
- List of collection metadata dictionaries, empty list if not supported
+ List of collection metadata, empty list if not supported
"""
return []
@@ -12375,7 +13362,7 @@ class BaseStorage(ABC):
offset: int = 0,
*,
collection_name: str | None = None,
- ) -> list[dict[str, object]]:
+ ) -> list[DocumentInfo]:
"""
List documents in the storage backend (if supported).
@@ -12385,7 +13372,7 @@ class BaseStorage(ABC):
collection_name: Collection to list documents from
Returns:
- List of document dictionaries with metadata
+ List of document information with metadata
Raises:
NotImplementedError: If backend doesn't support document listing
@@ -12406,34 +13393,58 @@ class BaseStorage(ABC):
"""Open WebUI storage adapter."""
import asyncio
+import contextlib
import logging
-from typing import TYPE_CHECKING, Final, TypedDict, cast
+import time
+from typing import Final, NamedTuple, TypedDict
-import httpx
from typing_extensions import override
-if TYPE_CHECKING:
- # Type checking imports - these will be ignored at runtime
- from httpx import AsyncClient, ConnectError, HTTPStatusError, RequestError
-else:
- # Runtime imports that work properly
- AsyncClient = httpx.AsyncClient
- ConnectError = httpx.ConnectError
- HTTPStatusError = httpx.HTTPStatusError
- RequestError = httpx.RequestError
-
from ..core.exceptions import StorageError
from ..core.models import Document, StorageConfig
-from .base import BaseStorage
+from .base import BaseStorage, TypedHttpClient
+from .types import CollectionSummary, DocumentInfo
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
+class OpenWebUIFileResponse(TypedDict, total=False):
+ """OpenWebUI API file response structure."""
+
+ id: str
+ filename: str
+ name: str
+ content_type: str
+ size: int
+ created_at: str
+ meta: dict[str, str | int]
+
+
+class OpenWebUIKnowledgeBase(TypedDict, total=False):
+ """OpenWebUI knowledge base response structure."""
+
+ id: str
+ name: str
+ description: str
+ files: list[OpenWebUIFileResponse]
+ data: dict[str, str]
+ created_at: str
+ updated_at: str
+
+
+class CacheEntry(NamedTuple):
+ """Cache entry with value and expiration time."""
+
+ value: str
+ expires_at: float
+
+
class OpenWebUIStorage(BaseStorage):
"""Storage adapter for Open WebUI knowledge endpoints."""
- client: AsyncClient
- _knowledge_cache: dict[str, str]
+ http_client: TypedHttpClient
+ _knowledge_cache: dict[str, CacheEntry]
+ _cache_ttl: float
def __init__(self, config: StorageConfig):
"""
@@ -12444,16 +13455,13 @@ class OpenWebUIStorage(BaseStorage):
"""
super().__init__(config)
- headers: dict[str, str] = {}
- if config.api_key:
- headers["Authorization"] = f"Bearer {config.api_key}"
-
- self.client = AsyncClient(
+ self.http_client = TypedHttpClient(
base_url=str(config.endpoint),
- headers=headers,
+ api_key=config.api_key,
timeout=30.0,
)
self._knowledge_cache = {}
+ self._cache_ttl = 300.0 # 5 minutes TTL
@override
async def initialize(self) -> None:
@@ -12464,60 +13472,61 @@ class OpenWebUIStorage(BaseStorage):
self.config.collection_name,
create=True,
)
-
- except ConnectError as e:
- raise StorageError(f"Connection to OpenWebUI failed: {e}") from e
- except HTTPStatusError as e:
- raise StorageError(f"OpenWebUI returned error {e.response.status_code}: {e}") from e
- except RequestError as e:
- raise StorageError(f"Request to OpenWebUI failed: {e}") from e
except Exception as e:
raise StorageError(f"Failed to initialize Open WebUI: {e}") from e
async def _create_collection(self, name: str) -> str:
"""Create knowledge base in Open WebUI."""
- try:
- response = await self.client.post(
- "/api/v1/knowledge/create",
- json={
- "name": name,
- "description": "Documents ingested from various sources",
- "data": {},
- "access_control": None,
- },
- )
- response.raise_for_status()
- result = response.json()
- knowledge_id = result.get("id")
+ response = await self.http_client.request(
+ "POST",
+ "/api/v1/knowledge/create",
+ json={
+ "name": name,
+ "description": "Documents ingested from various sources",
+ "data": {},
+ "access_control": None,
+ },
+ )
+ if response is None:
+ raise StorageError("Unexpected None response from knowledge base creation")
+ result = response.json()
+ knowledge_id = result.get("id")
- if not knowledge_id or not isinstance(knowledge_id, str):
- raise StorageError("Knowledge base creation failed: no ID returned")
+ if not knowledge_id or not isinstance(knowledge_id, str):
+ raise StorageError("Knowledge base creation failed: no ID returned")
- return str(knowledge_id)
+ return str(knowledge_id)
- except ConnectError as e:
- raise StorageError(f"Connection to OpenWebUI failed during creation: {e}") from e
- except HTTPStatusError as e:
- raise StorageError(
- f"OpenWebUI returned error {e.response.status_code} during creation: {e}"
- ) from e
- except RequestError as e:
- raise StorageError(f"Request to OpenWebUI failed during creation: {e}") from e
- except Exception as e:
- raise StorageError(f"Failed to create knowledge base: {e}") from e
-
- async def _fetch_knowledge_bases(self) -> list[dict[str, object]]:
+ async def _fetch_knowledge_bases(self) -> list[OpenWebUIKnowledgeBase]:
"""Return the list of knowledge bases from the API."""
- response = await self.client.get("/api/v1/knowledge/list")
- response.raise_for_status()
+ response = await self.http_client.request("GET", "/api/v1/knowledge/list")
+ if response is None:
+ return []
data = response.json()
if not isinstance(data, list):
return []
- normalized: list[dict[str, object]] = []
+ normalized: list[OpenWebUIKnowledgeBase] = []
for item in data:
- if isinstance(item, dict):
- item_dict: dict[str, object] = item
- normalized.append({str(k): v for k, v in item_dict.items()})
+ if (
+ isinstance(item, dict)
+ and "id" in item
+ and "name" in item
+ and isinstance(item["id"], str)
+ and isinstance(item["name"], str)
+ ):
+ # Create a new dict with known structure
+ kb_item: OpenWebUIKnowledgeBase = {
+ "id": item["id"],
+ "name": item["name"],
+ "description": item.get("description", ""),
+ "created_at": item.get("created_at", ""),
+ "updated_at": item.get("updated_at", ""),
+ }
+ if "files" in item and isinstance(item["files"], list):
+ kb_item["files"] = item["files"]
+ if "data" in item and isinstance(item["data"], dict):
+ kb_item["data"] = item["data"]
+ normalized.append(kb_item)
return normalized
async def _get_knowledge_id(
@@ -12532,22 +13541,29 @@ class OpenWebUIStorage(BaseStorage):
if not target:
raise StorageError("Knowledge base name is required")
- if cached := self._knowledge_cache.get(target):
- return cached
+ # Check cache with TTL
+ if cached_entry := self._knowledge_cache.get(target):
+ if time.time() < cached_entry.expires_at:
+ return cached_entry.value
+ else:
+ # Entry expired, remove it
+ del self._knowledge_cache[target]
knowledge_bases = await self._fetch_knowledge_bases()
for kb in knowledge_bases:
if kb.get("name") == target:
kb_id = kb.get("id")
if isinstance(kb_id, str):
- self._knowledge_cache[target] = kb_id
+ expires_at = time.time() + self._cache_ttl
+ self._knowledge_cache[target] = CacheEntry(kb_id, expires_at)
return kb_id
if not create:
return None
knowledge_id = await self._create_collection(target)
- self._knowledge_cache[target] = knowledge_id
+ expires_at = time.time() + self._cache_ttl
+ self._knowledge_cache[target] = CacheEntry(knowledge_id, expires_at)
return knowledge_id
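# Illustrative sketch of the TTL cache handling in _get_knowledge_id above, reusing
# the CacheEntry tuple defined earlier; the module-level names here are hypothetical.
import time

_KB_CACHE: dict[str, CacheEntry] = {}
_KB_CACHE_TTL = 300.0  # seconds, matching _cache_ttl above

def _cache_get(name: str) -> str | None:
    entry = _KB_CACHE.get(name)
    if entry is None:
        return None
    if time.time() >= entry.expires_at:   # expired: evict and report a miss
        del _KB_CACHE[name]
        return None
    return entry.value

def _cache_put(name: str, knowledge_id: str) -> None:
    _KB_CACHE[name] = CacheEntry(knowledge_id, time.time() + _KB_CACHE_TTL)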
@override
@@ -12573,15 +13589,17 @@ class OpenWebUIStorage(BaseStorage):
# Use document title from metadata if available, otherwise fall back to ID
filename = document.metadata.get("title") or f"doc_{document.id}"
# Ensure filename has proper extension
- if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
+ if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
filename = f"{filename}.txt"
files = {"file": (filename, document.content.encode(), "text/plain")}
- response = await self.client.post(
+ response = await self.http_client.request(
+ "POST",
"/api/v1/files/",
files=files,
params={"process": True, "process_in_background": False},
)
- response.raise_for_status()
+ if response is None:
+ raise StorageError("Unexpected None response from file upload")
file_data = response.json()
file_id = file_data.get("id")
@@ -12590,19 +13608,12 @@ class OpenWebUIStorage(BaseStorage):
raise StorageError("File upload failed: no file ID returned")
# Step 2: Add file to knowledge base
- response = await self.client.post(
- f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
+ response = await self.http_client.request(
+ "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
)
- response.raise_for_status()
return str(file_id)
- except ConnectError as e:
- raise StorageError(f"Connection to OpenWebUI failed: {e}") from e
- except HTTPStatusError as e:
- raise StorageError(f"OpenWebUI returned error {e.response.status_code}: {e}") from e
- except RequestError as e:
- raise StorageError(f"Request to OpenWebUI failed: {e}") from e
except Exception as e:
raise StorageError(f"Failed to store document: {e}") from e
@@ -12631,15 +13642,19 @@ class OpenWebUIStorage(BaseStorage):
# Use document title from metadata if available, otherwise fall back to ID
filename = doc.metadata.get("title") or f"doc_{doc.id}"
# Ensure filename has proper extension
- if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
+ if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
filename = f"{filename}.txt"
files = {"file": (filename, doc.content.encode(), "text/plain")}
- upload_response = await self.client.post(
+ upload_response = await self.http_client.request(
+ "POST",
"/api/v1/files/",
files=files,
params={"process": True, "process_in_background": False},
)
- upload_response.raise_for_status()
+ if upload_response is None:
+ raise StorageError(
+ f"Unexpected None response from file upload for document {doc.id}"
+ )
file_data = upload_response.json()
file_id = file_data.get("id")
@@ -12649,10 +13664,9 @@ class OpenWebUIStorage(BaseStorage):
f"File upload failed for document {doc.id}: no file ID returned"
)
- attach_response = await self.client.post(
- f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
+ await self.http_client.request(
+ "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
)
- attach_response.raise_for_status()
return str(file_id)
@@ -12667,7 +13681,8 @@ class OpenWebUIStorage(BaseStorage):
if isinstance(result, Exception):
failures.append(f"{doc.id}: {result}")
else:
- file_ids.append(cast(str, result))
+ if isinstance(result, str):
+ file_ids.append(result)
if failures:
LOGGER.warning(
@@ -12678,14 +13693,6 @@ class OpenWebUIStorage(BaseStorage):
return file_ids
- except ConnectError as e:
- raise StorageError(f"Connection to OpenWebUI failed during batch: {e}") from e
- except HTTPStatusError as e:
- raise StorageError(
- f"OpenWebUI returned error {e.response.status_code} during batch: {e}"
- ) from e
- except RequestError as e:
- raise StorageError(f"Request to OpenWebUI failed during batch: {e}") from e
except Exception as e:
raise StorageError(f"Failed to store batch: {e}") from e
@@ -12703,11 +13710,88 @@ class OpenWebUIStorage(BaseStorage):
Returns:
Always None - retrieval not supported
"""
+ _ = document_id, collection_name # Explicitly discard unused parameters
# OpenWebUI uses file-based storage without direct document retrieval
- # This will cause the base check_exists method to return False,
- # which means documents will always be re-scraped for OpenWebUI
raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID")
+ @override
+ async def check_exists(
+ self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
+ ) -> bool:
+ """
+ Check if a document exists in OpenWebUI knowledge base by searching files.
+
+ Args:
+ document_id: Document ID to check (usually based on source URL)
+ collection_name: Knowledge base name
+ stale_after_days: Consider document stale after this many days
+
+ Returns:
+ True if document exists and is not stale, False otherwise
+ """
+ try:
+ from datetime import UTC, datetime, timedelta
+
+ # Get knowledge base
+ knowledge_id = await self._get_knowledge_id(collection_name, create=False)
+ if not knowledge_id:
+ return False
+
+ # Get detailed knowledge base info to access files
+ response = await self.http_client.request("GET", f"/api/v1/knowledge/{knowledge_id}")
+ if response is None:
+ return False
+
+ kb_data = response.json()
+ files = kb_data.get("files", [])
+
+ # Look for file with matching document ID or source URL in metadata
+ cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
+
+ def _parse_openwebui_timestamp(timestamp_str: str) -> datetime | None:
+ """Parse OpenWebUI timestamp with proper timezone handling."""
+ try:
+ # Handle both 'Z' suffix and explicit timezone
+ normalized = timestamp_str.replace("Z", "+00:00")
+ parsed = datetime.fromisoformat(normalized)
+ # Ensure timezone awareness
+ if parsed.tzinfo is None:
+ parsed = parsed.replace(tzinfo=UTC)
+ return parsed
+ except (ValueError, AttributeError):
+ return None
+
+ def _check_file_freshness(file_info: dict[str, object]) -> bool:
+ """Check if file is fresh enough based on creation date."""
+ created_at = file_info.get("created_at")
+ if not isinstance(created_at, str):
+ # No date info available; treat as stale to be safe
+ return False
+
+ file_date = _parse_openwebui_timestamp(created_at)
+ return file_date is not None and file_date >= cutoff
+
+ for file_info in files:
+ if not isinstance(file_info, dict):
+ continue
+
+ file_id = file_info.get("id")
+ if str(file_id) == document_id:
+ return _check_file_freshness(file_info)
+
+ # Also check meta.source_url if available for URL-based document IDs
+ meta = file_info.get("meta", {})
+ if isinstance(meta, dict):
+ source_url = meta.get("source_url")
+ if source_url and document_id in str(source_url):
+ return _check_file_freshness(file_info)
+
+ return False
+
+ except Exception as e:
+ LOGGER.debug("Error checking document existence in OpenWebUI: %s", e)
+ return False
+
@override
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
"""
@@ -12728,35 +13812,16 @@ class OpenWebUIStorage(BaseStorage):
return False
# Remove file from knowledge base
- response = await self.client.post(
- f"/api/v1/knowledge/{knowledge_id}/file/remove", json={"file_id": document_id}
+ await self.http_client.request(
+ "POST",
+ f"/api/v1/knowledge/{knowledge_id}/file/remove",
+ json={"file_id": document_id},
)
- response.raise_for_status()
- delete_response = await self.client.delete(f"/api/v1/files/{document_id}")
- if delete_response.status_code == 404:
- return True
- delete_response.raise_for_status()
+ await self.http_client.request("DELETE", f"/api/v1/files/{document_id}", allow_404=True)
return True
-
- except ConnectError as exc:
- LOGGER.error(
- "Failed to reach OpenWebUI when deleting file %s", document_id, exc_info=exc
- )
- return False
- except HTTPStatusError as exc:
- LOGGER.error(
- "OpenWebUI returned status error %s when deleting file %s",
- exc.response.status_code if exc.response else "unknown",
- document_id,
- exc_info=exc,
- )
- return False
- except RequestError as exc:
- LOGGER.error("Request error deleting file %s from OpenWebUI", document_id, exc_info=exc)
- return False
except Exception as exc:
- LOGGER.error("Unexpected error deleting file %s", document_id, exc_info=exc)
+ LOGGER.error("Error deleting file %s from OpenWebUI", document_id, exc_info=exc)
return False
async def list_collections(self) -> list[str]:
@@ -12775,12 +13840,6 @@ class OpenWebUIStorage(BaseStorage):
for kb in knowledge_bases
]
- except ConnectError as e:
- raise StorageError(f"Connection to OpenWebUI failed: {e}") from e
- except HTTPStatusError as e:
- raise StorageError(f"OpenWebUI returned error {e.response.status_code}: {e}") from e
- except RequestError as e:
- raise StorageError(f"Request to OpenWebUI failed: {e}") from e
except Exception as e:
raise StorageError(f"Failed to list knowledge bases: {e}") from e
@@ -12801,8 +13860,9 @@ class OpenWebUIStorage(BaseStorage):
return True
# Delete the knowledge base using the OpenWebUI API
- response = await self.client.delete(f"/api/v1/knowledge/{knowledge_id}/delete")
- response.raise_for_status()
+ await self.http_client.request(
+ "DELETE", f"/api/v1/knowledge/{knowledge_id}/delete", allow_404=True
+ )
# Remove from cache if it exists
if collection_name in self._knowledge_cache:
@@ -12811,45 +13871,26 @@ class OpenWebUIStorage(BaseStorage):
LOGGER.info("Successfully deleted knowledge base: %s", collection_name)
return True
- except HTTPStatusError as e:
- # Handle 404 as success (already deleted)
- if e.response.status_code == 404:
- LOGGER.info("Knowledge base %s was already deleted or not found", collection_name)
- return True
- LOGGER.error(
- "OpenWebUI returned error %s when deleting knowledge base %s",
- e.response.status_code,
- collection_name,
- exc_info=e,
- )
- return False
- except ConnectError as e:
- LOGGER.error(
- "Failed to reach OpenWebUI when deleting knowledge base %s",
- collection_name,
- exc_info=e,
- )
- return False
- except RequestError as e:
- LOGGER.error(
- "Request error deleting knowledge base %s from OpenWebUI",
- collection_name,
- exc_info=e,
- )
- return False
except Exception as e:
- LOGGER.error("Unexpected error deleting knowledge base %s", collection_name, exc_info=e)
+ if hasattr(e, "response"):
+ response_attr = getattr(e, "response", None)
+ if response_attr is not None and hasattr(response_attr, "status_code"):
+ with contextlib.suppress(Exception):
+ status_code = response_attr.status_code
+ if status_code == 404:
+ LOGGER.info(
+ "Knowledge base %s was already deleted or not found",
+ collection_name,
+ )
+ return True
+ LOGGER.error(
+ "Error deleting knowledge base %s from OpenWebUI",
+ collection_name,
+ exc_info=e,
+ )
return False
- class CollectionSummary(TypedDict):
- """Structure describing a knowledge base summary."""
-
- name: str
- count: int
- size_mb: float
-
-
- async def _get_knowledge_base_count(self, kb: dict[str, object]) -> int:
+ async def _get_knowledge_base_count(self, kb: OpenWebUIKnowledgeBase) -> int:
"""Get the file count for a knowledge base."""
kb_id = kb.get("id")
name = kb.get("name", "Unknown")
@@ -12859,17 +13900,21 @@ class OpenWebUIStorage(BaseStorage):
return await self._count_files_from_detailed_info(str(kb_id), str(name), kb)
- def _count_files_from_basic_info(self, kb: dict[str, object]) -> int:
+ def _count_files_from_basic_info(self, kb: OpenWebUIKnowledgeBase) -> int:
"""Count files from basic knowledge base info."""
files = kb.get("files", [])
return len(files) if isinstance(files, list) and files is not None else 0
- async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: dict[str, object]) -> int:
+ async def _count_files_from_detailed_info(
+ self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase
+ ) -> int:
"""Count files by fetching detailed knowledge base info."""
try:
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
- detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}")
- detail_response.raise_for_status()
+ detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
+ if detail_response is None:
+ LOGGER.warning(f"Knowledge base '{name}' (ID: {kb_id}) not found")
+ return self._count_files_from_basic_info(kb)
detailed_kb = detail_response.json()
files = detailed_kb.get("files", [])
@@ -12882,21 +13927,18 @@ class OpenWebUIStorage(BaseStorage):
LOGGER.warning(f"Failed to get detailed info for KB '{name}' (ID: {kb_id}): {e}")
return self._count_files_from_basic_info(kb)
- async def describe_collections(self) -> list[dict[str, object]]:
+ async def describe_collections(self) -> list[CollectionSummary]:
"""Return metadata about each knowledge base."""
try:
knowledge_bases = await self._fetch_knowledge_bases()
- collections: list[dict[str, object]] = []
+ collections: list[CollectionSummary] = []
for kb in knowledge_bases:
- if not isinstance(kb, dict):
- continue
-
count = await self._get_knowledge_base_count(kb)
name = kb.get("name", "Unknown")
size_mb = count * 0.5 # rough heuristic
- summary: dict[str, object] = {
+ summary: CollectionSummary = {
"name": str(name),
"count": count,
"size_mb": float(size_mb),
@@ -12940,8 +13982,10 @@ class OpenWebUIStorage(BaseStorage):
return 0
# Get detailed knowledge base information to get accurate file count
- detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}")
- detail_response.raise_for_status()
+ detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
+ if detail_response is None:
+ LOGGER.warning(f"Knowledge base '{collection_name}' (ID: {kb_id}) not found")
+ return self._count_files_from_basic_info(kb)
detailed_kb = detail_response.json()
files = detailed_kb.get("files", [])
@@ -12954,7 +13998,7 @@ class OpenWebUIStorage(BaseStorage):
LOGGER.warning(f"Failed to get count for collection '{collection_name}': {e}")
return 0
- async def get_knowledge_by_name(self, name: str) -> dict[str, object] | None:
+ async def get_knowledge_by_name(self, name: str) -> OpenWebUIKnowledgeBase | None:
"""
Get knowledge base details by name.
@@ -12965,18 +14009,33 @@ class OpenWebUIStorage(BaseStorage):
Knowledge base details or None if not found
"""
try:
- response = await self.client.get("/api/v1/knowledge/list")
- response.raise_for_status()
+ response = await self.http_client.request("GET", "/api/v1/knowledge/list")
+ if response is None:
+ return None
knowledge_bases = response.json()
- return next(
- (
- {str(k): v for k, v in kb.items()}
- for kb in knowledge_bases
- if isinstance(kb, dict) and kb.get("name") == name
- ),
- None,
- )
+ # Find and properly type the knowledge base
+ for kb in knowledge_bases:
+ if (
+ isinstance(kb, dict)
+ and kb.get("name") == name
+ and "id" in kb
+ and isinstance(kb["id"], str)
+ ):
+ # Create properly typed response
+ result: OpenWebUIKnowledgeBase = {
+ "id": kb["id"],
+ "name": str(kb["name"]),
+ "description": kb.get("description", ""),
+ "created_at": kb.get("created_at", ""),
+ "updated_at": kb.get("updated_at", ""),
+ }
+ if "files" in kb and isinstance(kb["files"], list):
+ result["files"] = kb["files"]
+ if "data" in kb and isinstance(kb["data"], dict):
+ result["data"] = kb["data"]
+ return result
+ return None
except Exception as e:
raise StorageError(f"Failed to get knowledge base by name: {e}") from e
@@ -12992,6 +14051,7 @@ class OpenWebUIStorage(BaseStorage):
exc_tb: object | None,
) -> None:
"""Async context manager exit."""
+ _ = exc_type, exc_val, exc_tb # Mark as used
await self.close()
async def list_documents(
@@ -13000,7 +14060,7 @@ class OpenWebUIStorage(BaseStorage):
offset: int = 0,
*,
collection_name: str | None = None,
- ) -> list[dict[str, object]]:
+ ) -> list[DocumentInfo]:
"""
List documents (files) in a knowledge base.
@@ -13050,11 +14110,8 @@ class OpenWebUIStorage(BaseStorage):
paginated_files = files[offset : offset + limit]
# Convert to document format with safe field access
- documents: list[dict[str, object]] = []
+ documents: list[DocumentInfo] = []
for i, file_info in enumerate(paginated_files):
- if not isinstance(file_info, dict):
- continue
-
# Safely extract fields with fallbacks
doc_id = str(file_info.get("id", f"file_{i}"))
@@ -13068,7 +14125,11 @@ class OpenWebUIStorage(BaseStorage):
filename = file_info["name"]
# Check meta.name (from FileModelResponse schema)
elif isinstance(file_info.get("meta"), dict):
- filename = file_info["meta"].get("name")
+ meta = file_info.get("meta")
+ if isinstance(meta, dict):
+ filename_value = meta.get("name")
+ if isinstance(filename_value, str):
+ filename = filename_value
# Final fallback
if not filename:
@@ -13078,28 +14139,28 @@ class OpenWebUIStorage(BaseStorage):
# Extract size from meta if available
size = 0
- if isinstance(file_info.get("meta"), dict):
- size = file_info["meta"].get("size", 0)
+ meta = file_info.get("meta")
+ if isinstance(meta, dict):
+ size_value = meta.get("size", 0)
+ size = int(size_value) if isinstance(size_value, (int, float)) else 0
else:
- size = file_info.get("size", 0)
+ size_value = file_info.get("size", 0)
+ size = int(size_value) if isinstance(size_value, (int, float)) else 0
# Estimate word count from file size (very rough approximation)
word_count = max(1, int(size / 6)) if isinstance(size, (int, float)) else 0
- documents.append(
- {
- "id": doc_id,
- "title": filename,
- "source_url": "", # OpenWebUI files don't typically have source URLs
- "description": f"File: {filename}",
- "content_type": str(file_info.get("content_type", "text/plain")),
- "content_preview": f"File uploaded to OpenWebUI: {filename}",
- "word_count": word_count,
- "timestamp": str(
- file_info.get("created_at") or file_info.get("timestamp", "")
- ),
- }
- )
+ doc_info: DocumentInfo = {
+ "id": doc_id,
+ "title": filename,
+ "source_url": "", # OpenWebUI files don't typically have source URLs
+ "description": f"File: {filename}",
+ "content_type": str(file_info.get("content_type", "text/plain")),
+ "content_preview": f"File uploaded to OpenWebUI: {filename}",
+ "word_count": word_count,
+ "timestamp": str(file_info.get("created_at") or file_info.get("timestamp", "")),
+ }
+ documents.append(doc_info)
return documents
@@ -13126,21 +14187,17 @@ class OpenWebUIStorage(BaseStorage):
async def close(self) -> None:
"""Close client connection."""
- if hasattr(self, "client") and self.client:
- try:
- await self.client.aclose()
- except Exception as e:
- import logging
-
- logging.warning(f"Error closing OpenWebUI client: {e}")
+ if hasattr(self, "http_client"):
+ await self.http_client.close()
"""Weaviate storage adapter."""
-from collections.abc import AsyncGenerator, Mapping, Sequence
+import asyncio
+from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from datetime import UTC, datetime
-from typing import Literal, Self, TypeAlias, cast, overload
+from typing import Literal, Self, TypeAlias, TypeVar, cast, overload
from uuid import UUID
import weaviate
@@ -13159,8 +14216,10 @@ from ..core.exceptions import StorageError
from ..core.models import Document, DocumentMetadata, IngestionSource, StorageConfig
from ..utils.vectorizer import Vectorizer
from .base import BaseStorage
+from .types import CollectionSummary, DocumentInfo
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
+T = TypeVar("T")
class WeaviateStorage(BaseStorage):
@@ -13182,6 +14241,28 @@ class WeaviateStorage(BaseStorage):
self.vectorizer = Vectorizer(config)
self._default_collection = self._normalize_collection_name(config.collection_name)
+ async def _run_sync(self, func: Callable[..., T], *args: object, **kwargs: object) -> T:
+ """
+ Run synchronous Weaviate operations in thread pool to avoid blocking event loop.
+
+ Args:
+ func: Synchronous function to run
+ *args: Positional arguments for the function
+ **kwargs: Keyword arguments for the function
+
+ Returns:
+ Result of the function call
+
+ Raises:
+ StorageError: If the operation fails
+ """
+ try:
+ return await asyncio.to_thread(func, *args, **kwargs)
+ except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
+ raise StorageError(f"Weaviate operation failed: {e}") from e
+ except Exception as e:
+ raise StorageError(f"Unexpected error in Weaviate operation: {e}") from e
+
@override
async def initialize(self) -> None:
"""Initialize Weaviate client and create collection if needed."""
@@ -13190,7 +14271,7 @@ class WeaviateStorage(BaseStorage):
self.client = weaviate.WeaviateClient(
connection_params=weaviate.connect.ConnectionParams.from_url(
url=str(self.config.endpoint),
- grpc_port=50051, # Default gRPC port
+ grpc_port=self.config.grpc_port or 50051,
),
additional_config=weaviate.classes.init.AdditionalConfig(
timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120),
@@ -13198,7 +14279,7 @@ class WeaviateStorage(BaseStorage):
)
# Connect to the client
- self.client.connect()
+ await self._run_sync(self.client.connect)
# Ensure the default collection exists
await self._ensure_collection(self._default_collection)
@@ -13213,8 +14294,8 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
try:
- client = cast(weaviate.WeaviateClient, self.client)
- client.collections.create(
+ await self._run_sync(
+ self.client.collections.create,
name=collection_name,
properties=[
Property(
@@ -13243,7 +14324,7 @@ class WeaviateStorage(BaseStorage):
],
vectorizer_config=Configure.Vectorizer.none(),
)
- except Exception as e:
+ except (WeaviateConnectionError, WeaviateBatchError) as e:
raise StorageError(f"Failed to create collection: {e}") from e
@staticmethod
@@ -13251,13 +14332,9 @@ class WeaviateStorage(BaseStorage):
"""Normalize vector payloads returned by Weaviate into a float list."""
if isinstance(vector_raw, Mapping):
default_vector = vector_raw.get("default")
- return WeaviateStorage._extract_vector(
- cast(VectorContainer, default_vector)
- )
+ return WeaviateStorage._extract_vector(cast(VectorContainer, default_vector))
- if not isinstance(vector_raw, Sequence) or isinstance(
- vector_raw, (str, bytes, bytearray)
- ):
+ if not isinstance(vector_raw, Sequence) or isinstance(vector_raw, (str, bytes, bytearray)):
return None
items = list(vector_raw)
@@ -13272,9 +14349,7 @@ class WeaviateStorage(BaseStorage):
except (TypeError, ValueError):
return None
- if isinstance(first_item, Sequence) and not isinstance(
- first_item, (str, bytes, bytearray)
- ):
+ if isinstance(first_item, Sequence) and not isinstance(first_item, (str, bytes, bytearray)):
inner_items = list(first_item)
if all(isinstance(item, (int, float)) for item in inner_items):
try:
@@ -13305,8 +14380,7 @@ class WeaviateStorage(BaseStorage):
properties: object,
*,
context: str,
- ) -> Mapping[str, object]:
- ...
+ ) -> Mapping[str, object]: ...
@staticmethod
@overload
@@ -13315,8 +14389,7 @@ class WeaviateStorage(BaseStorage):
*,
context: str,
allow_missing: Literal[False],
- ) -> Mapping[str, object]:
- ...
+ ) -> Mapping[str, object]: ...
@staticmethod
@overload
@@ -13325,8 +14398,7 @@ class WeaviateStorage(BaseStorage):
*,
context: str,
allow_missing: Literal[True],
- ) -> Mapping[str, object] | None:
- ...
+ ) -> Mapping[str, object] | None: ...
@staticmethod
def _coerce_properties(
@@ -13348,6 +14420,29 @@ class WeaviateStorage(BaseStorage):
return cast(Mapping[str, object], properties)
+ @staticmethod
+ def _build_document_properties(doc: Document) -> dict[str, object]:
+ """
+ Build Weaviate properties dict from document.
+
+ Args:
+ doc: Document to build properties for
+
+ Returns:
+ Properties dict suitable for Weaviate
+ """
+ return {
+ "content": doc.content,
+ "source_url": doc.metadata["source_url"],
+ "title": doc.metadata.get("title", ""),
+ "description": doc.metadata.get("description", ""),
+ "timestamp": doc.metadata["timestamp"].isoformat(),
+ "content_type": doc.metadata["content_type"],
+ "word_count": doc.metadata["word_count"],
+ "char_count": doc.metadata["char_count"],
+ "source": doc.source.value,
+ }
+
def _normalize_collection_name(self, collection_name: str | None) -> str:
"""Return a canonicalized collection name, defaulting to configured value."""
candidate = collection_name or self.config.collection_name
@@ -13364,7 +14459,7 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
existing = client.collections.list_all()
if collection_name not in existing:
await self._create_collection(collection_name)
@@ -13384,7 +14479,7 @@ class WeaviateStorage(BaseStorage):
if ensure_exists:
await self._ensure_collection(normalized)
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
return client.collections.get(normalized), normalized
@override
@@ -13408,26 +14503,19 @@ class WeaviateStorage(BaseStorage):
)
# Prepare properties
- properties = {
- "content": document.content,
- "source_url": document.metadata["source_url"],
- "title": document.metadata.get("title", ""),
- "description": document.metadata.get("description", ""),
- "timestamp": document.metadata["timestamp"].isoformat(),
- "content_type": document.metadata["content_type"],
- "word_count": document.metadata["word_count"],
- "char_count": document.metadata["char_count"],
- "source": document.source.value,
- }
+ properties = self._build_document_properties(document)
# Insert with vector
- result = collection.data.insert(
- properties=properties, vector=document.vector, uuid=str(document.id)
+ result = await self._run_sync(
+ collection.data.insert,
+ properties=properties,
+ vector=document.vector,
+ uuid=str(document.id),
)
return str(result)
- except Exception as e:
+ except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
raise StorageError(f"Failed to store document: {e}") from e
@override
@@ -13448,32 +14536,24 @@ class WeaviateStorage(BaseStorage):
collection_name, ensure_exists=True
)
- # Vectorize documents without vectors
- for doc in documents:
- if doc.vector is None:
- doc.vector = await self.vectorizer.vectorize(doc.content)
+ # Vectorize documents without vectors using batch processing
+ to_vectorize = [(i, doc) for i, doc in enumerate(documents) if doc.vector is None]
+ if to_vectorize:
+ contents = [doc.content for _, doc in to_vectorize]
+ vectors = await self.vectorizer.vectorize_batch(contents)
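+ # zip(strict=False) stops at the shorter sequence; any document left unmatched keeps vector=None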
+ for (idx, _), vector in zip(to_vectorize, vectors, strict=False):
+ documents[idx].vector = vector
# Prepare batch data for insert_many
batch_objects = []
for doc in documents:
- properties = {
- "content": doc.content,
- "source_url": doc.metadata["source_url"],
- "title": doc.metadata.get("title", ""),
- "description": doc.metadata.get("description", ""),
- "timestamp": doc.metadata["timestamp"].isoformat(),
- "content_type": doc.metadata["content_type"],
- "word_count": doc.metadata["word_count"],
- "char_count": doc.metadata["char_count"],
- "source": doc.source.value,
- }
-
+ properties = self._build_document_properties(doc)
batch_objects.append(
DataObject(properties=properties, vector=doc.vector, uuid=str(doc.id))
)
# Insert batch using insert_many
- response = collection.data.insert_many(batch_objects)
+ response = await self._run_sync(collection.data.insert_many, batch_objects)
successful_ids: list[str] = []
error_indices = set(response.errors.keys()) if response else set()
@@ -13498,11 +14578,7 @@ class WeaviateStorage(BaseStorage):
return successful_ids
- except WeaviateBatchError as e:
- raise StorageError(f"Batch operation failed: {e}") from e
- except WeaviateConnectionError as e:
- raise StorageError(f"Connection to Weaviate failed: {e}") from e
- except Exception as e:
+ except (WeaviateBatchError, WeaviateConnectionError, WeaviateQueryError) as e:
raise StorageError(f"Failed to store batch: {e}") from e
@override
@@ -13522,7 +14598,7 @@ class WeaviateStorage(BaseStorage):
collection, resolved_name = await self._prepare_collection(
collection_name, ensure_exists=False
)
- result = collection.query.fetch_object_by_id(document_id)
+ result = await self._run_sync(collection.query.fetch_object_by_id, document_id)
if not result:
return None
@@ -13532,13 +14608,30 @@ class WeaviateStorage(BaseStorage):
result.properties,
context="fetch_object_by_id",
)
+ # Parse timestamp to datetime for consistent metadata format
+ from datetime import UTC, datetime
+
+ timestamp_raw = props.get("timestamp")
+ timestamp_parsed: datetime
+ try:
+ if isinstance(timestamp_raw, str):
+ timestamp_parsed = datetime.fromisoformat(timestamp_raw)
+ if timestamp_parsed.tzinfo is None:
+ timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
+ elif isinstance(timestamp_raw, datetime):
+ timestamp_parsed = timestamp_raw
+ if timestamp_parsed.tzinfo is None:
+ timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
+ else:
+ timestamp_parsed = datetime.now(UTC)
+ except (ValueError, TypeError):
+ timestamp_parsed = datetime.now(UTC)
+
metadata_dict = {
"source_url": str(props["source_url"]),
"title": str(props.get("title")) if props.get("title") else None,
- "description": str(props.get("description"))
- if props.get("description")
- else None,
- "timestamp": str(props["timestamp"]),
+ "description": str(props.get("description")) if props.get("description") else None,
+ "timestamp": timestamp_parsed,
"content_type": str(props["content_type"]),
"word_count": int(str(props["word_count"])),
"char_count": int(str(props["char_count"])),
@@ -13561,11 +14654,13 @@ class WeaviateStorage(BaseStorage):
except WeaviateConnectionError as e:
# Connection issues should be logged and return None
import logging
+
logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}")
return None
except Exception as e:
# Log unexpected errors for debugging
import logging
+
logging.warning(f"Unexpected error retrieving document {document_id}: {e}")
return None
@@ -13574,9 +14669,7 @@ class WeaviateStorage(BaseStorage):
metadata_dict = {
"source_url": str(props["source_url"]),
"title": str(props.get("title")) if props.get("title") else None,
- "description": str(props.get("description"))
- if props.get("description")
- else None,
+ "description": str(props.get("description")) if props.get("description") else None,
"timestamp": str(props["timestamp"]),
"content_type": str(props["content_type"]),
"word_count": int(str(props["word_count"])),
@@ -13599,6 +14692,7 @@ class WeaviateStorage(BaseStorage):
return max(0.0, 1.0 - distance_value)
except (TypeError, ValueError) as e:
import logging
+
logging.debug(f"Invalid distance value {raw_distance}: {e}")
return None
@@ -13643,37 +14737,39 @@ class WeaviateStorage(BaseStorage):
collection_name: str | None = None,
) -> AsyncGenerator[Document, None]:
"""
- Search for documents in Weaviate.
+ Search for documents in Weaviate using hybrid search.
Args:
query: Search query
limit: Maximum results
- threshold: Similarity threshold
+ threshold: Similarity threshold (not used in hybrid search)
Yields:
Matching documents
"""
try:
- query_vector = await self.vectorizer.vectorize(query)
+ if not self.client:
+ raise StorageError("Weaviate client not initialized")
+
collection, resolved_name = await self._prepare_collection(
collection_name, ensure_exists=False
)
- results = collection.query.near_vector(
- near_vector=query_vector,
- limit=limit,
- distance=1 - threshold,
- return_metadata=["distance"],
- )
+ # Try hybrid search first, fall back to BM25 keyword search
+ try:
+ response = await self._run_sync(
+ collection.query.hybrid, query=query, limit=limit, return_metadata=["score"]
+ )
+ except (WeaviateQueryError, StorageError):
+ # Fall back to BM25 if hybrid search is not supported or fails
+ response = await self._run_sync(
+ collection.query.bm25, query=query, limit=limit, return_metadata=["score"]
+ )
- for result in results.objects:
- yield self._build_search_document(result, resolved_name)
+ for obj in response.objects:
+ yield self._build_document_from_search(obj, resolved_name)
- except WeaviateQueryError as e:
- raise StorageError(f"Search query failed: {e}") from e
- except WeaviateConnectionError as e:
- raise StorageError(f"Connection to Weaviate failed during search: {e}") from e
- except Exception as e:
+ except (WeaviateQueryError, WeaviateConnectionError) as e:
raise StorageError(f"Search failed: {e}") from e
@override
@@ -13689,7 +14785,7 @@ class WeaviateStorage(BaseStorage):
"""
try:
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
- collection.data.delete_by_id(document_id)
+ await self._run_sync(collection.data.delete_by_id, document_id)
return True
except WeaviateQueryError as e:
raise StorageError(f"Delete operation failed: {e}") from e
@@ -13726,20 +14822,20 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
return list(client.collections.list_all())
- except Exception as e:
+ except WeaviateConnectionError as e:
raise StorageError(f"Failed to list collections: {e}") from e
- async def describe_collections(self) -> list[dict[str, object]]:
+ async def describe_collections(self) -> list[CollectionSummary]:
"""Return metadata for each Weaviate collection."""
if not self.client:
raise StorageError("Weaviate client not initialized")
try:
- client = cast(weaviate.WeaviateClient, self.client)
- collections: list[dict[str, object]] = []
+ client = self.client
+ collections: list[CollectionSummary] = []
for name in client.collections.list_all():
collection_obj = client.collections.get(name)
if not collection_obj:
@@ -13747,13 +14843,12 @@ class WeaviateStorage(BaseStorage):
count = collection_obj.aggregate.over_all(total_count=True).total_count or 0
size_mb = count * 0.01
- collections.append(
- {
- "name": name,
- "count": count,
- "size_mb": size_mb,
- }
- )
+ collection_summary: CollectionSummary = {
+ "name": name,
+ "count": count,
+ "size_mb": size_mb,
+ }
+ collections.append(collection_summary)
return collections
except Exception as e:
@@ -13777,7 +14872,7 @@ class WeaviateStorage(BaseStorage):
)
# Query for sample documents
- response = collection.query.fetch_objects(limit=limit)
+ response = await self._run_sync(collection.query.fetch_objects, limit=limit)
documents = []
for obj in response.objects:
@@ -13850,9 +14945,7 @@ class WeaviateStorage(BaseStorage):
return {
"source_url": str(props.get("source_url", "")),
"title": str(props.get("title", "")) if props.get("title") else None,
- "description": str(props.get("description", ""))
- if props.get("description")
- else None,
+ "description": str(props.get("description", "")) if props.get("description") else None,
"timestamp": datetime.fromisoformat(
str(props.get("timestamp", datetime.now(UTC).isoformat()))
),
@@ -13875,6 +14968,7 @@ class WeaviateStorage(BaseStorage):
return float(raw_score)
except (TypeError, ValueError) as e:
import logging
+
logging.debug(f"Invalid score value {raw_score}: {e}")
return None
@@ -13918,31 +15012,11 @@ class WeaviateStorage(BaseStorage):
Returns:
List of matching documents
"""
- try:
- if not self.client:
- raise StorageError("Weaviate client not initialized")
-
- collection, resolved_name = await self._prepare_collection(
- collection_name, ensure_exists=False
- )
-
- # Try hybrid search first, fall back to BM25 keyword search
- try:
- response = collection.query.hybrid(
- query=query, limit=limit, return_metadata=["score"]
- )
- except Exception:
- response = collection.query.bm25(
- query=query, limit=limit, return_metadata=["score"]
- )
-
- return [
- self._build_document_from_search(obj, resolved_name)
- for obj in response.objects
- ]
-
- except Exception as e:
- raise StorageError(f"Failed to search documents: {e}") from e
+ # Delegate to the unified search method
+ results = []
+ async for document in self.search(query, limit=limit, collection_name=collection_name):
+ results.append(document)
+ return results
async def list_documents(
self,
@@ -13950,7 +15024,7 @@ class WeaviateStorage(BaseStorage):
offset: int = 0,
*,
collection_name: str | None = None,
- ) -> list[dict[str, object]]:
+ ) -> list[DocumentInfo]:
"""
List documents in the collection with pagination.
@@ -13968,11 +15042,14 @@ class WeaviateStorage(BaseStorage):
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
# Query documents with pagination
- response = collection.query.fetch_objects(
- limit=limit, offset=offset, return_metadata=["creation_time"]
+ response = await self._run_sync(
+ collection.query.fetch_objects,
+ limit=limit,
+ offset=offset,
+ return_metadata=["creation_time"],
)
- documents: list[dict[str, object]] = []
+ documents: list[DocumentInfo] = []
for obj in response.objects:
props = self._coerce_properties(
obj.properties,
@@ -13991,7 +15068,7 @@ class WeaviateStorage(BaseStorage):
else:
word_count = 0
- doc_info: dict[str, object] = {
+ doc_info: DocumentInfo = {
"id": str(obj.uuid),
"title": str(props.get("title", "Untitled")),
"source_url": str(props.get("source_url", "")),
@@ -14034,7 +15111,9 @@ class WeaviateStorage(BaseStorage):
)
delete_filter = Filter.by_id().contains_any(document_ids)
- response = collection.data.delete_many(where=delete_filter, verbose=True)
+ response = await self._run_sync(
+ collection.data.delete_many, where=delete_filter, verbose=True
+ )
if objects := getattr(response, "objects", None):
for result_obj in objects:
@@ -14076,20 +15155,22 @@ class WeaviateStorage(BaseStorage):
# Get documents matching filter
if where_filter:
- response = collection.query.fetch_objects(
+ response = await self._run_sync(
+ collection.query.fetch_objects,
filters=where_filter,
limit=1000, # Max batch size
)
else:
- response = collection.query.fetch_objects(
- limit=1000 # Max batch size
+ response = await self._run_sync(
+ collection.query.fetch_objects,
+ limit=1000, # Max batch size
)
# Delete matching documents
deleted_count = 0
for obj in response.objects:
try:
- collection.data.delete_by_id(obj.uuid)
+ await self._run_sync(collection.data.delete_by_id, obj.uuid)
deleted_count += 1
except Exception:
continue
@@ -14113,7 +15194,7 @@ class WeaviateStorage(BaseStorage):
target = self._normalize_collection_name(collection_name)
# Delete the collection using the client's collections API
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
client.collections.delete(target)
return True
@@ -14135,20 +15216,29 @@ class WeaviateStorage(BaseStorage):
await self.close()
async def close(self) -> None:
- """Close client connection."""
+ """Close client connection and vectorizer HTTP client."""
if self.client:
try:
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
client.close()
- except Exception as e:
+ except (WeaviateConnectionError, AttributeError) as e:
import logging
+
logging.warning(f"Error closing Weaviate client: {e}")
+ # Close vectorizer HTTP client to prevent resource leaks
+ try:
+ await self.vectorizer.close()
+ except (AttributeError, OSError) as e:
+ import logging
+
+ logging.warning(f"Error closing vectorizer client: {e}")
+
def __del__(self) -> None:
"""Clean up client connection as fallback."""
if self.client:
try:
- client = cast(weaviate.WeaviateClient, self.client)
+ client = self.client
client.close()
except Exception:
pass # Ignore errors in destructor
diff --git a/tests/conftest.py b/tests/conftest.py
index 8b77ee4..ce9a4e0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -59,6 +59,7 @@ class StubbedResponse:
def raise_for_status(self) -> None:
if self.status_code < 400:
return
+
# Create a minimal exception for testing - we don't need full httpx objects for tests
class TestHTTPError(Exception):
request: object | None
@@ -232,7 +233,7 @@ class AsyncClientStub:
# Convert params to the format expected by other methods
converted_params: dict[str, object] | None = None
if params:
- converted_params = {k: v for k, v in params.items()}
+ converted_params = dict(params)
method_upper = method.upper()
if method_upper == "GET":
@@ -441,7 +442,6 @@ def r2r_spec() -> OpenAPISpec:
return OpenAPISpec.from_file(PROJECT_ROOT / "r2r.json")
-
@pytest.fixture(scope="session")
def firecrawl_spec() -> OpenAPISpec:
"""Load Firecrawl OpenAPI specification."""
@@ -539,7 +539,9 @@ def firecrawl_client_stub(
links = [SimpleNamespace(url=cast(str, item.get("url", ""))) for item in links_data]
return SimpleNamespace(success=payload.get("success", True), links=links)
- async def scrape(self, url: str, formats: list[str] | None = None, **_: object) -> SimpleNamespace:
+ async def scrape(
+ self, url: str, formats: list[str] | None = None, **_: object
+ ) -> SimpleNamespace:
payload = cast(MockResponseData, self._service.scrape_response(url, formats))
data = cast(MockResponseData, payload.get("data", {}))
metadata_payload = cast(MockResponseData, data.get("metadata", {}))
diff --git a/tests/openapi_mocks.py b/tests/openapi_mocks.py
index 0a418c7..b1c052e 100644
--- a/tests/openapi_mocks.py
+++ b/tests/openapi_mocks.py
@@ -158,7 +158,9 @@ class OpenAPISpec:
return []
return [segment for segment in path.strip("/").split("/") if segment]
- def find_operation(self, method: str, path: str) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
+ def find_operation(
+ self, method: str, path: str
+ ) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
method = method.lower()
normalized = "/" if path in {"", "/"} else "/" + path.strip("/")
actual_segments = self._split_path(normalized)
@@ -188,9 +190,7 @@ class OpenAPISpec:
status: str | None = None,
) -> tuple[int, Any]:
responses = operation.get("responses", {})
- target_status = status or (
- "200" if "200" in responses else next(iter(responses), "200")
- )
+ target_status = status or ("200" if "200" in responses else next(iter(responses), "200"))
status_code = 200
try:
status_code = int(target_status)
@@ -222,7 +222,7 @@ class OpenAPISpec:
base = self.generate({"$ref": ref})
if overrides is None or not isinstance(base, Mapping):
return copy.deepcopy(base)
- merged = copy.deepcopy(base)
+ merged: dict[str, Any] = dict(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
merged[key] = self.generate_from_mapping(
@@ -237,8 +237,8 @@ class OpenAPISpec:
self,
base: Mapping[str, Any],
overrides: Mapping[str, Any],
- ) -> Mapping[str, Any]:
- result = copy.deepcopy(base)
+ ) -> dict[str, Any]:
+ result: dict[str, Any] = dict(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(result.get(key), Mapping):
result[key] = self.generate_from_mapping(result[key], value)
@@ -380,18 +380,20 @@ class OpenWebUIMockService(OpenAPIMockService):
def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
- payload.update({
- "id": knowledge["id"],
- "user_id": knowledge["user_id"],
- "name": knowledge["name"],
- "description": knowledge.get("description", ""),
- "data": copy.deepcopy(knowledge.get("data", {})),
- "meta": copy.deepcopy(knowledge.get("meta", {})),
- "access_control": copy.deepcopy(knowledge.get("access_control", {})),
- "created_at": knowledge.get("created_at", self._timestamp()),
- "updated_at": knowledge.get("updated_at", self._timestamp()),
- "files": copy.deepcopy(knowledge.get("files", [])),
- })
+ payload.update(
+ {
+ "id": knowledge["id"],
+ "user_id": knowledge["user_id"],
+ "name": knowledge["name"],
+ "description": knowledge.get("description", ""),
+ "data": copy.deepcopy(knowledge.get("data", {})),
+ "meta": copy.deepcopy(knowledge.get("meta", {})),
+ "access_control": copy.deepcopy(knowledge.get("access_control", {})),
+ "created_at": knowledge.get("created_at", self._timestamp()),
+ "updated_at": knowledge.get("updated_at", self._timestamp()),
+ "files": copy.deepcopy(knowledge.get("files", [])),
+ }
+ )
return payload
def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
@@ -435,7 +437,9 @@ class OpenWebUIMockService(OpenAPIMockService):
# Delegate to parent
return super().handle(method=method, path=path, json=json, params=params, files=files)
- def _handle_knowledge_endpoints(self, method: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]:
+ def _handle_knowledge_endpoints(
+ self, method: str, segments: list[str], json: Mapping[str, Any] | None
+ ) -> tuple[int, Any]:
"""Handle knowledge-related API endpoints."""
method_upper = method.upper()
segment_count = len(segments)
@@ -471,7 +475,9 @@ class OpenWebUIMockService(OpenAPIMockService):
)
return 200, entry
- def _handle_specific_knowledge(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]:
+ def _handle_specific_knowledge(
+ self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None
+ ) -> tuple[int, Any]:
"""Handle endpoints for specific knowledge entries."""
knowledge_id = segments[3]
knowledge = self._knowledge.get(knowledge_id)
@@ -486,7 +492,9 @@ class OpenWebUIMockService(OpenAPIMockService):
# File operations
if segment_count == 6 and segments[4] == "file":
- return self._handle_knowledge_file_operations(method_upper, segments[5], knowledge, json or {})
+ return self._handle_knowledge_file_operations(
+ method_upper, segments[5], knowledge, json or {}
+ )
# Delete knowledge
if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete":
@@ -495,7 +503,13 @@ class OpenWebUIMockService(OpenAPIMockService):
return 404, {"detail": "Knowledge operation not found"}
- def _handle_knowledge_file_operations(self, method_upper: str, operation: str, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
+ def _handle_knowledge_file_operations(
+ self,
+ method_upper: str,
+ operation: str,
+ knowledge: dict[str, Any],
+ payload: Mapping[str, Any],
+ ) -> tuple[int, Any]:
"""Handle file operations on knowledge entries."""
if method_upper != "POST":
return 405, {"detail": "Method not allowed"}
@@ -507,7 +521,9 @@ class OpenWebUIMockService(OpenAPIMockService):
return 404, {"detail": "File operation not found"}
- def _add_file_to_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
+ def _add_file_to_knowledge(
+ self, knowledge: dict[str, Any], payload: Mapping[str, Any]
+ ) -> tuple[int, Any]:
"""Add a file to a knowledge entry."""
file_id = str(payload.get("file_id", ""))
if not file_id:
@@ -524,15 +540,21 @@ class OpenWebUIMockService(OpenAPIMockService):
return 200, self._knowledge_files_response(knowledge)
- def _remove_file_from_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
+ def _remove_file_from_knowledge(
+ self, knowledge: dict[str, Any], payload: Mapping[str, Any]
+ ) -> tuple[int, Any]:
"""Remove a file from a knowledge entry."""
file_id = str(payload.get("file_id", ""))
- knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
+ knowledge["files"] = [
+ item for item in knowledge.get("files", []) if item.get("id") != file_id
+ ]
knowledge["updated_at"] = self._timestamp()
return 200, self._knowledge_files_response(knowledge)
- def _handle_file_endpoints(self, method: str, segments: list[str], files: Any) -> tuple[int, Any]:
+ def _handle_file_endpoints(
+ self, method: str, segments: list[str], files: Any
+ ) -> tuple[int, Any]:
"""Handle file-related API endpoints."""
method_upper = method.upper()
@@ -561,7 +583,9 @@ class OpenWebUIMockService(OpenAPIMockService):
"""Delete a file and remove from all knowledge entries."""
self._files.pop(file_id, None)
for knowledge in self._knowledge.values():
- knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
+ knowledge["files"] = [
+ item for item in knowledge.get("files", []) if item.get("id") != file_id
+ ]
return 200, {"deleted": True}
@@ -620,10 +644,14 @@ class R2RMockService(OpenAPIMockService):
None,
)
- def _set_collection_document_ids(self, collection_id: str, document_id: str, *, add: bool) -> None:
+ def _set_collection_document_ids(
+ self, collection_id: str, document_id: str, *, add: bool
+ ) -> None:
collection = self._collections.get(collection_id)
if collection is None:
- collection = self.create_collection(name=f"Collection {collection_id}", collection_id=collection_id)
+ collection = self.create_collection(
+ name=f"Collection {collection_id}", collection_id=collection_id
+ )
documents = collection.setdefault("documents", [])
if add:
if document_id not in documents:
@@ -677,7 +705,9 @@ class R2RMockService(OpenAPIMockService):
self._set_collection_document_ids(collection_id, document_id, add=False)
return True
- def append_document_metadata(self, document_id: str, metadata_list: list[dict[str, Any]]) -> dict[str, Any] | None:
+ def append_document_metadata(
+ self, document_id: str, metadata_list: list[dict[str, Any]]
+ ) -> dict[str, Any] | None:
entry = self._documents.get(document_id)
if entry is None:
return None
@@ -712,7 +742,9 @@ class R2RMockService(OpenAPIMockService):
# Delegate to parent
return super().handle(method=method, path=path, json=json, params=params, files=files)
- def _handle_collections_endpoint(self, method_upper: str, json: Mapping[str, Any] | None) -> tuple[int, Any]:
+ def _handle_collections_endpoint(
+ self, method_upper: str, json: Mapping[str, Any] | None
+ ) -> tuple[int, Any]:
"""Handle collection-related endpoints."""
if method_upper == "GET":
return self._list_collections()
@@ -732,10 +764,12 @@ class R2RMockService(OpenAPIMockService):
clone["document_count"] = len(clone.get("documents", []))
results.append(clone)
- payload.update({
- "results": results,
- "total_entries": len(results),
- })
+ payload.update(
+ {
+ "results": results,
+ "total_entries": len(results),
+ }
+ )
return 200, payload
def _create_collection_endpoint(self, body: Mapping[str, Any]) -> tuple[int, Any]:
@@ -748,7 +782,9 @@ class R2RMockService(OpenAPIMockService):
payload.update({"results": entry})
return 200, payload
- def _handle_documents_endpoint(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any) -> tuple[int, Any]:
+ def _handle_documents_endpoint(
+ self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any
+ ) -> tuple[int, Any]:
"""Handle document-related endpoints."""
if method_upper == "POST" and len(segments) == 2:
return self._create_document_endpoint(files)
@@ -774,11 +810,13 @@ class R2RMockService(OpenAPIMockService):
"#/components/schemas/R2RResults_IngestionResponse_"
)
ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse")
- ingestion.update({
- "message": ingestion.get("message") or "Ingestion task queued successfully.",
- "document_id": document["id"],
- "task_id": ingestion.get("task_id") or str(uuid4()),
- })
+ ingestion.update(
+ {
+ "message": ingestion.get("message") or "Ingestion task queued successfully.",
+ "document_id": document["id"],
+ "task_id": ingestion.get("task_id") or str(uuid4()),
+ }
+ )
response_payload["results"] = ingestion
return 202, response_payload
@@ -835,7 +873,11 @@ class R2RMockService(OpenAPIMockService):
def _extract_content_from_files(self, raw_text_entry: Any) -> str:
"""Extract content from files entry."""
- if isinstance(raw_text_entry, tuple) and len(raw_text_entry) >= 2 and raw_text_entry[1] is not None:
+ if (
+ isinstance(raw_text_entry, tuple)
+ and len(raw_text_entry) >= 2
+ and raw_text_entry[1] is not None
+ ):
return str(raw_text_entry[1])
return ""
@@ -886,7 +928,7 @@ class R2RMockService(OpenAPIMockService):
results = response_payload.get("results")
if isinstance(results, Mapping):
- results = copy.deepcopy(results)
+ results = dict(results)
results.update({"success": success})
else:
results = {"success": success}
@@ -894,7 +936,9 @@ class R2RMockService(OpenAPIMockService):
response_payload["results"] = results
return (200 if success else 404), response_payload
- def _update_document_metadata(self, doc_id: str, json: Mapping[str, Any] | None) -> tuple[int, Any]:
+ def _update_document_metadata(
+ self, doc_id: str, json: Mapping[str, Any] | None
+ ) -> tuple[int, Any]:
"""Update document metadata."""
metadata_list = [dict(item) for item in json] if isinstance(json, list) else []
document = self.append_document_metadata(doc_id, metadata_list)
@@ -919,7 +963,15 @@ class FirecrawlMockService(OpenAPIMockService):
def register_map_result(self, origin: str, links: list[str]) -> None:
self._maps[origin] = list(links)
- def register_page(self, url: str, *, markdown: str | None = None, html: str | None = None, metadata: Mapping[str, Any] | None = None, links: list[str] | None = None) -> None:
+ def register_page(
+ self,
+ url: str,
+ *,
+ markdown: str | None = None,
+ html: str | None = None,
+ metadata: Mapping[str, Any] | None = None,
+ links: list[str] | None = None,
+ ) -> None:
self._pages[url] = {
"markdown": markdown,
"html": html,
diff --git a/tests/unit/automations/conftest.py b/tests/unit/automations/conftest.py
new file mode 100644
index 0000000..7e05a3b
--- /dev/null
+++ b/tests/unit/automations/conftest.py
@@ -0,0 +1,15 @@
+from __future__ import annotations
+
+import pytest
+
+from ingest_pipeline.automations import AUTOMATION_TEMPLATES, get_automation_yaml_templates
+
+
+@pytest.fixture(scope="function")
+def automation_templates_copy() -> dict[str, str]:
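+ """Return a fresh copy of the automation YAML templates."""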
+ return get_automation_yaml_templates()
+
+
+@pytest.fixture(scope="function")
+def automation_templates_reference() -> dict[str, str]:
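+ """Return the module-level AUTOMATION_TEMPLATES mapping for comparison against the copy."""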
+ return AUTOMATION_TEMPLATES
diff --git a/tests/unit/automations/test_init.py b/tests/unit/automations/test_init.py
new file mode 100644
index 0000000..8749439
--- /dev/null
+++ b/tests/unit/automations/test_init.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+import pytest
+
+
+@pytest.mark.parametrize(
+ ("template_key", "expected_snippet"),
+ [
+ ("cancel_long_running", "Cancel Long Running Ingestion Flows"),
+ ("retry_failed", "Retry Failed Ingestion Flows"),
+ ("resource_monitoring", "Manage Work Pool Based on Resources"),
+ ],
+)
+def test_automation_templates_include_expected_entries(
+ automation_templates_copy,
+ template_key: str,
+ expected_snippet: str,
+) -> None:
+ assert expected_snippet in automation_templates_copy[template_key]
+
+
+def test_get_automation_yaml_templates_returns_copy(
+ automation_templates_copy,
+ automation_templates_reference,
+) -> None:
+ automation_templates_copy["cancel_long_running"] = "overridden"
+
+ assert "overridden" not in automation_templates_reference["cancel_long_running"]
diff --git a/tests/unit/cli/test_main_cli.py b/tests/unit/cli/test_main_cli.py
index 8cf49d2..ed5e3dd 100644
--- a/tests/unit/cli/test_main_cli.py
+++ b/tests/unit/cli/test_main_cli.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+from collections.abc import AsyncGenerator
from types import SimpleNamespace
from uuid import uuid4
@@ -22,7 +23,9 @@ from ingest_pipeline.core.models import (
),
)
@pytest.mark.asyncio
-async def test_run_ingestion_collection_resolution(monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str) -> None:
+async def test_run_ingestion_collection_resolution(
+ monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str
+) -> None:
recorded: dict[str, object] = {}
async def fake_flow(**kwargs: object) -> IngestionResult:
@@ -128,7 +131,7 @@ async def test_run_search_collects_results(monkeypatch: pytest.MonkeyPatch) -> N
threshold: float = 0.7,
*,
collection_name: str | None = None,
- ) -> object:
+ ) -> AsyncGenerator[SimpleNamespace, None]:
yield SimpleNamespace(title="Title", content="Body text", score=0.91)
dummy_settings = SimpleNamespace(
diff --git a/tests/unit/flows/test_ingestion_flow.py b/tests/unit/flows/test_ingestion_flow.py
index 80c01bd..252d6d3 100644
--- a/tests/unit/flows/test_ingestion_flow.py
+++ b/tests/unit/flows/test_ingestion_flow.py
@@ -6,9 +6,11 @@ from unittest.mock import AsyncMock
import pytest
from ingest_pipeline.core.models import (
+ Document,
IngestionJob,
IngestionSource,
StorageBackend,
+ StorageConfig,
)
from ingest_pipeline.flows import ingestion
from ingest_pipeline.flows.ingestion import (
@@ -18,6 +20,7 @@ from ingest_pipeline.flows.ingestion import (
filter_existing_documents_task,
ingest_documents_task,
)
+from ingest_pipeline.storage.base import BaseStorage
@pytest.mark.parametrize(
@@ -52,15 +55,34 @@ async def test_filter_existing_documents_task_filters_known_urls(
new_url = "https://new.example.com"
existing_id = str(FirecrawlIngestor.compute_document_id(existing_url))
- class StubStorage:
- display_name = "stub-storage"
+ class StubStorage(BaseStorage):
+ def __init__(self) -> None:
+ super().__init__(
+ StorageConfig(backend=StorageBackend.WEAVIATE, endpoint="http://test.local")
+ )
+
+ @property
+ def display_name(self) -> str:
+ return "stub-storage"
+
+ async def initialize(self) -> None:
+ pass
+
+ async def store(self, document: Document, *, collection_name: str | None = None) -> str:
+ return "stub-id"
+
+ async def store_batch(
+ self, documents: list[Document], *, collection_name: str | None = None
+ ) -> list[str]:
+ return ["stub-id"] * len(documents)
+
+ async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
+ return True
async def check_exists(
self,
document_id: str,
*,
collection_name: str | None = None,
- stale_after_days: int,
+ stale_after_days: int = 30,
) -> bool:
return document_id == existing_id
@@ -102,7 +124,9 @@ async def test_ingest_documents_task_invokes_helpers(
) -> object:
return storage_sentinel
- async def fake_create_ingestor(job_arg: IngestionJob, config_block_name: str | None = None) -> object:
+ async def fake_create_ingestor(
+ job_arg: IngestionJob, config_block_name: str | None = None
+ ) -> object:
return ingestor_sentinel
monkeypatch.setattr(ingestion, "_create_ingestor", fake_create_ingestor)
diff --git a/tests/unit/flows/test_scheduler.py b/tests/unit/flows/test_scheduler.py
index ed3b8ee..6ea8e1c 100644
--- a/tests/unit/flows/test_scheduler.py
+++ b/tests/unit/flows/test_scheduler.py
@@ -4,6 +4,8 @@ from datetime import timedelta
from types import SimpleNamespace
import pytest
+from prefect.deployments.runner import RunnerDeployment
+from prefect.schedules import Cron, Interval
from ingest_pipeline.flows import scheduler
@@ -17,7 +19,6 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No
captured |= kwargs
return SimpleNamespace(**kwargs)
-
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
deployment = scheduler.create_scheduled_deployment(
@@ -28,8 +29,14 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No
cron_expression="0 * * * *",
)
- assert captured["schedule"].cron == "0 * * * *"
- assert captured["parameters"]["source_type"] == "web"
+ schedule = captured["schedule"]
+ # Check that it's a cron schedule by verifying it has a cron attribute
+ assert hasattr(schedule, "cron")
+ assert schedule.cron == "0 * * * *"
+
+ parameters = captured["parameters"]
+ assert isinstance(parameters, dict)
+ assert parameters["source_type"] == "web"
assert deployment.tags == ["web", "weaviate"]
@@ -42,7 +49,6 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
captured |= kwargs
return SimpleNamespace(**kwargs)
-
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
deployment = scheduler.create_scheduled_deployment(
@@ -55,7 +61,11 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
tags=["custom"],
)
- assert captured["schedule"].interval == timedelta(minutes=15)
+ schedule = captured["schedule"]
+ # Check that it's an interval schedule by verifying it has an interval attribute
+ assert hasattr(schedule, "interval")
+ assert schedule.interval == timedelta(minutes=15)
+
assert captured["tags"] == ["custom"]
assert deployment.parameters["storage_backend"] == "open_webui"
@@ -63,12 +73,13 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
def test_serve_deployments_invokes_prefect(monkeypatch: pytest.MonkeyPatch) -> None:
called: dict[str, object] = {}
- def fake_serve(*deployments: object, limit: int) -> None:
+ def fake_serve(*deployments: RunnerDeployment, limit: int) -> None:
called["deployments"] = deployments
called["limit"] = limit
monkeypatch.setattr(scheduler, "prefect_serve", fake_serve)
+ # Create a mock deployment using SimpleNamespace to avoid Prefect complexity
deployment = SimpleNamespace(name="only")
scheduler.serve_deployments([deployment])
diff --git a/tests/unit/ingestors/test_repomix_ingestor.py b/tests/unit/ingestors/test_repomix_ingestor.py
index 3090a6a..41f05a3 100644
--- a/tests/unit/ingestors/test_repomix_ingestor.py
+++ b/tests/unit/ingestors/test_repomix_ingestor.py
@@ -9,7 +9,7 @@ from ingest_pipeline.ingestors.repomix import RepomixIngestor
@pytest.mark.parametrize(
("content", "expected_keys"),
(
- ("## File: src/app.py\nprint(\"hi\")", ["src/app.py"]),
+ ('## File: src/app.py\nprint("hi")', ["src/app.py"]),
("plain content without markers", ["repository"]),
),
)
@@ -44,7 +44,7 @@ def test_chunk_content_respects_max_size(
("file_path", "content", "expected"),
(
("src/app.py", "def feature():\n return True", "python"),
- ("scripts/run", "#!/usr/bin/env python\nprint(\"ok\")", "python"),
+ ("scripts/run", '#!/usr/bin/env python\nprint("ok")', "python"),
("documentation.md", "# Title", "markdown"),
("unknown.ext", "text", None),
),
@@ -81,5 +81,8 @@ def test_create_document_enriches_metadata() -> None:
assert document.metadata["repository_name"] == "demo"
assert document.metadata["branch_name"] == "main"
assert document.metadata["commit_hash"] == "deadbeef"
- assert document.metadata["title"].endswith("(chunk 1)")
+
+ title = document.metadata["title"]
+ assert title is not None
+ assert title.endswith("(chunk 1)")
assert document.collection == job.storage_backend.value
diff --git a/tests/unit/storage/conftest.py b/tests/unit/storage/conftest.py
new file mode 100644
index 0000000..eed606e
--- /dev/null
+++ b/tests/unit/storage/conftest.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+import importlib
+import sys
+import types
+
+import pytest
+
+
+@pytest.fixture(scope="function")
+def storage_module() -> types.ModuleType:
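+    """Return the ingest_pipeline.storage package for export inspection."""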
+ return importlib.import_module("ingest_pipeline.storage")
+
+
+@pytest.fixture(scope="function")
+def storage_module_without_r2r(request: pytest.FixtureRequest) -> types.ModuleType:
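+    """Reload the storage package with stub r2r modules so R2RStorage falls back to None."""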
+ original_package = sys.modules.pop("ingest_pipeline.storage.r2r", None)
+ original_storage = sys.modules.pop("ingest_pipeline.storage.r2r.storage", None)
+
+ dummy_package = types.ModuleType("ingest_pipeline.storage.r2r")
+ dummy_package.__path__ = []
+ dummy_storage = types.ModuleType("ingest_pipeline.storage.r2r.storage")
+
+ sys.modules["ingest_pipeline.storage.r2r"] = dummy_package
+ sys.modules["ingest_pipeline.storage.r2r.storage"] = dummy_storage
+
+ module = importlib.reload(importlib.import_module("ingest_pipeline.storage"))
+
+ def _restore() -> None:
+ if original_package is not None:
+ sys.modules["ingest_pipeline.storage.r2r"] = original_package
+ else:
+ sys.modules.pop("ingest_pipeline.storage.r2r", None)
+ if original_storage is not None:
+ sys.modules["ingest_pipeline.storage.r2r.storage"] = original_storage
+ else:
+ sys.modules.pop("ingest_pipeline.storage.r2r.storage", None)
+ importlib.reload(importlib.import_module("ingest_pipeline.storage"))
+
+ request.addfinalizer(_restore)
+ return module
diff --git a/tests/unit/storage/test_base_storage.py b/tests/unit/storage/test_base_storage.py
index 6e9a860..7fffc8b 100644
--- a/tests/unit/storage/test_base_storage.py
+++ b/tests/unit/storage/test_base_storage.py
@@ -69,7 +69,9 @@ async def test_check_exists_uses_staleness(document_factory, storage_config, del
"""Return True only when document timestamp is within freshness window."""
timestamp = datetime.now(UTC) - timedelta(days=delta_days)
- document = document_factory(content=f"doc-{delta_days}", metadata_updates={"timestamp": timestamp})
+ document = document_factory(
+ content=f"doc-{delta_days}", metadata_updates={"timestamp": timestamp}
+ )
storage = StubStorage(storage_config, result=document)
outcome = await storage.check_exists("identifier", stale_after_days=30)
diff --git a/tests/unit/storage/test_openwebui.py b/tests/unit/storage/test_openwebui.py
index 4055bbf..42f0b8e 100644
--- a/tests/unit/storage/test_openwebui.py
+++ b/tests/unit/storage/test_openwebui.py
@@ -99,7 +99,9 @@ async def test_store_batch_handles_multiple_documents(
file_ids = await storage.store_batch([first, second])
assert len(file_ids) == 2
- files_payloads: list[Any] = [request["files"] for request in httpx_stub.requests if request["method"] == "POST"]
+ files_payloads: list[Any] = [
+ request["files"] for request in httpx_stub.requests if request["method"] == "POST"
+ ]
assert any(payload is not None for payload in files_payloads)
knowledge_entry = openwebui_service.find_knowledge_by_name(storage_config.collection_name)
assert knowledge_entry is not None
diff --git a/tests/unit/storage/test_package_init.py b/tests/unit/storage/test_package_init.py
new file mode 100644
index 0000000..bd8424d
--- /dev/null
+++ b/tests/unit/storage/test_package_init.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+import pytest
+
+
+@pytest.mark.parametrize(
+ "symbol_name",
+ ("BaseStorage", "WeaviateStorage", "OpenWebUIStorage", "R2RStorage"),
+)
+def test_storage_exports_expected_symbols(storage_module, symbol_name: str) -> None:
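+    """Each public storage class is exported via the package __all__."""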
+ assert symbol_name in storage_module.__all__
+
+
+def test_storage_r2r_available_by_default(storage_module) -> None:
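+    """R2RStorage resolves by default when the optional r2r modules import."""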
+ assert storage_module.R2RStorage is not None
+
+
+def test_storage_optional_import_fallback(storage_module_without_r2r) -> None:
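+    """R2RStorage falls back to None when the r2r modules are unavailable."""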
+ assert storage_module_without_r2r.R2RStorage is None
diff --git a/tests/unit/storage/test_r2r_helpers.py b/tests/unit/storage/test_r2r_helpers.py
index 6ef2af7..e6220a9 100644
--- a/tests/unit/storage/test_r2r_helpers.py
+++ b/tests/unit/storage/test_r2r_helpers.py
@@ -5,6 +5,7 @@ from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, Self
+import httpx
import pytest
from ingest_pipeline.core.models import StorageConfig
@@ -23,7 +24,6 @@ def r2r_client_stub(
monkeypatch: pytest.MonkeyPatch,
r2r_service,
) -> object:
-
class DummyR2RException(Exception):
def __init__(self, message: str, status_code: int | None = None) -> None:
super().__init__(message)
@@ -39,10 +39,13 @@ def r2r_client_stub(
def raise_for_status(self) -> None:
if self.status_code >= 400:
- from httpx import HTTPStatusError
- raise HTTPStatusError("HTTP error", request=None, response=self)
-
-
+ # Create minimal mock request and response for HTTPStatusError
+ mock_request = httpx.Request("GET", "http://test.local")
+ mock_response = httpx.Response(
+ status_code=self.status_code,
+ request=mock_request,
+ )
+ raise httpx.HTTPStatusError("HTTP error", request=mock_request, response=mock_response)
class MockAsyncClient:
def __init__(self, service: Any) -> None:
@@ -53,15 +56,23 @@ def r2r_client_stub(
# Return existing collections
collections = []
for collection_id, collection_data in self._service._collections.items():
- collections.append({
- "id": collection_id,
- "name": collection_data["name"],
- "description": collection_data.get("description", ""),
- })
+ collections.append(
+ {
+ "id": collection_id,
+ "name": collection_data["name"],
+ "description": collection_data.get("description", ""),
+ }
+ )
return MockResponse({"results": collections})
return MockResponse({})
- async def post(self, url: str, *, json: dict[str, Any] | None = None, files: dict[str, Any] | None = None) -> MockResponse:
+ async def post(
+ self,
+ url: str,
+ *,
+ json: dict[str, Any] | None = None,
+ files: dict[str, Any] | None = None,
+ ) -> MockResponse:
if "/v3/collections" in url and json:
return self._handle_collection_creation(json)
if "/v3/documents" in url and files:
@@ -76,13 +87,15 @@ def r2r_client_stub(
collection_id=new_collection_id,
description=json.get("description", ""),
)
- return MockResponse({
- "results": {
- "id": new_collection_id,
- "name": json["name"],
- "description": json.get("description", ""),
+ return MockResponse(
+ {
+ "results": {
+ "id": new_collection_id,
+ "name": json["name"],
+ "description": json.get("description", ""),
+ }
}
- })
+ )
def _handle_document_creation(self, files: dict[str, Any]) -> MockResponse:
"""Handle document creation via POST with files."""
@@ -102,12 +115,14 @@ def r2r_client_stub(
# Update collection document count if needed
self._update_collection_document_count(files)
- return MockResponse({
- "results": {
- "document_id": document_id,
- "message": "Document created successfully",
+ return MockResponse(
+ {
+ "results": {
+ "document_id": document_id,
+ "message": "Document created successfully",
+ }
}
- })
+ )
def _extract_document_id(self, files: dict[str, Any]) -> str:
"""Extract document ID from files."""
@@ -120,6 +135,7 @@ def r2r_client_stub(
def _extract_metadata(self, files: dict[str, Any]) -> dict[str, Any]:
"""Extract and parse metadata from files."""
import json as json_lib
+
metadata_str = files.get("metadata", (None, "{}"))[1]
return json_lib.loads(metadata_str) if metadata_str else {}
@@ -137,9 +153,15 @@ def r2r_client_stub(
"""Parse collection IDs from files entry."""
import json as json_lib
- collection_ids_str = collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids
+ collection_ids_str = (
+ collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids
+ )
try:
- collection_list = json_lib.loads(collection_ids_str) if isinstance(collection_ids_str, str) else collection_ids_str
+ collection_list = (
+ json_lib.loads(collection_ids_str)
+ if isinstance(collection_ids_str, str)
+ else collection_ids_str
+ )
if isinstance(collection_list, list) and collection_list:
first_collection = collection_list[0]
if isinstance(first_collection, dict) and "id" in first_collection:
@@ -159,7 +181,6 @@ def r2r_client_stub(
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
return None
-
class DocumentsAPI:
def __init__(self, service: Any) -> None:
self._service = service
@@ -186,10 +207,7 @@ def r2r_client_stub(
self._service = service
async def search(self, query: str, search_settings: Mapping[str, Any]) -> dict[str, Any]:
- results = [
- {"document_id": doc_id, "score": 1.0}
- for doc_id in self._service._documents
- ]
+ results = [{"document_id": doc_id, "score": 1.0} for doc_id in self._service._documents]
return {"results": results}
class DummyClient:
diff --git a/tests/unit/tui/__pycache__/__init__.cpython-312.pyc b/tests/unit/tui/__pycache__/__init__.cpython-312.pyc
index 0316e93..4568bc8 100644
Binary files a/tests/unit/tui/__pycache__/__init__.cpython-312.pyc and b/tests/unit/tui/__pycache__/__init__.cpython-312.pyc differ
diff --git a/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc b/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc
index 0e464ef..b0e86f2 100644
Binary files a/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc and b/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc differ
diff --git a/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc b/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc
index 4ce1491..d325884 100644
Binary files a/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc and b/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc differ
diff --git a/tests/unit/tui/test_dashboard_screen.py b/tests/unit/tui/test_dashboard_screen.py
index a919494..8ae15f2 100644
--- a/tests/unit/tui/test_dashboard_screen.py
+++ b/tests/unit/tui/test_dashboard_screen.py
@@ -12,6 +12,7 @@ from ingest_pipeline.cli.tui.screens.documents import DocumentManagementScreen
from ingest_pipeline.cli.tui.widgets.tables import EnhancedDataTable
from ingest_pipeline.core.models import Document, StorageBackend, StorageConfig
from ingest_pipeline.storage.base import BaseStorage
+from ingest_pipeline.storage.types import DocumentInfo
if TYPE_CHECKING:
from ingest_pipeline.cli.tui.utils.storage_manager import StorageManager
@@ -35,7 +36,9 @@ class StorageManagerStub:
backends: dict[StorageBackend, BaseStorage]
is_initialized: bool
- def __init__(self, collections: list[CollectionInfo], backends: dict[StorageBackend, BaseStorage]) -> None:
+ def __init__(
+ self, collections: list[CollectionInfo], backends: dict[StorageBackend, BaseStorage]
+ ) -> None:
self._collections = collections
self._available = list(backends.keys())
self.backends = backends
@@ -71,13 +74,19 @@ class PassiveStorage(BaseStorage):
async def initialize(self) -> None:
return None
- async def store(self, document: Document, *, collection_name: str | None = None) -> str: # pragma: no cover - unused
+ async def store(
+ self, document: Document, *, collection_name: str | None = None
+ ) -> str: # pragma: no cover - unused
raise NotImplementedError
- async def store_batch(self, documents: list[Document], *, collection_name: str | None = None) -> list[str]: # pragma: no cover - unused
+ async def store_batch(
+ self, documents: list[Document], *, collection_name: str | None = None
+ ) -> list[str]: # pragma: no cover - unused
raise NotImplementedError
- async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: # pragma: no cover - unused
+ async def delete(
+ self, document_id: str, *, collection_name: str | None = None
+ ) -> bool: # pragma: no cover - unused
raise NotImplementedError
@@ -165,6 +174,7 @@ async def test_collection_overview_table_reflects_collections() -> None:
assert screen.total_documents == 1292
assert screen.active_backends == 2
+
class DocumentStorageStub(BaseStorage):
def __init__(self) -> None:
super().__init__(
@@ -178,16 +188,24 @@ class DocumentStorageStub(BaseStorage):
async def initialize(self) -> None:
return None
- async def store(self, document: Document, *, collection_name: str | None = None) -> str: # pragma: no cover - unused
+ async def store(
+ self, document: Document, *, collection_name: str | None = None
+ ) -> str: # pragma: no cover - unused
raise NotImplementedError
- async def store_batch(self, documents: list[Document], *, collection_name: str | None = None) -> list[str]: # pragma: no cover - unused
+ async def store_batch(
+ self, documents: list[Document], *, collection_name: str | None = None
+ ) -> list[str]: # pragma: no cover - unused
raise NotImplementedError
- async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: # pragma: no cover - unused
+ async def delete(
+ self, document_id: str, *, collection_name: str | None = None
+ ) -> bool: # pragma: no cover - unused
raise NotImplementedError
- async def list_documents(self, limit: int = 100, offset: int = 0, *, collection_name: str | None = None) -> list[dict[str, object]]:
+ async def list_documents(
+ self, limit: int = 100, offset: int = 0, *, collection_name: str | None = None
+ ) -> list[DocumentInfo]:
return [
{
"id": "doc-1234567890",
diff --git a/tests/unit/tui/test_storage_manager.py b/tests/unit/tui/test_storage_manager.py
index f95833d..e506f4a 100644
--- a/tests/unit/tui/test_storage_manager.py
+++ b/tests/unit/tui/test_storage_manager.py
@@ -4,14 +4,23 @@ from types import SimpleNamespace
import pytest
-from ingest_pipeline.cli.tui.utils.storage_manager import MultiStorageAdapter, StorageManager
+from ingest_pipeline.cli.tui.utils.storage_manager import (
+ MultiStorageAdapter,
+ StorageCapabilities,
+ StorageManager,
+)
+from typing import cast
+
+from ingest_pipeline.config.settings import Settings
from ingest_pipeline.core.exceptions import StorageError
from ingest_pipeline.core.models import Document, StorageBackend, StorageConfig
from ingest_pipeline.storage.base import BaseStorage
class StubStorage(BaseStorage):
- def __init__(self, config: StorageConfig, *, documents: list[Document] | None = None, fail: bool = False) -> None:
+ def __init__(
+ self, config: StorageConfig, *, documents: list[Document] | None = None, fail: bool = False
+    ) -> None:
super().__init__(config)
self.documents = documents or []
self.fail = fail
@@ -26,7 +35,9 @@ class StubStorage(BaseStorage):
raise RuntimeError("store failed")
return f"{self.config.backend.value}-single"
- async def store_batch(self, documents: list[Document], *, collection_name: str | None = None) -> list[str]:
+ async def store_batch(
+ self, documents: list[Document], *, collection_name: str | None = None
+    ) -> list[str]:
self.stored.extend(documents)
if self.fail:
raise RuntimeError("batch failed")
@@ -50,7 +61,7 @@ class StubStorage(BaseStorage):
threshold: float = 0.7,
*,
collection_name: str | None = None,
    ):
for document in self.documents:
yield document
@@ -58,18 +69,58 @@ class StubStorage(BaseStorage):
return None
+class CollectionStubStorage(StubStorage):
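+    """Stub storage that reports canned collection names and per-collection counts."""
+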
+ def __init__(
+ self,
+ config: StorageConfig,
+ *,
+ collections: list[str],
+ counts: dict[str, int],
+    ) -> None:
+ super().__init__(config)
+ self.collections = collections
+ self.counts = counts
+
+ async def list_collections(self) -> list[str]:
+ return self.collections
+
+ async def count(self, *, collection_name: str | None = None) -> int:
+ if collection_name is None:
+ raise ValueError("collection name required")
+ return self.counts[collection_name]
+
+
+class FailingStatusStorage(StubStorage):
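+    """Stub storage whose list_collections always raises, exercising status error paths."""
+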
+ async def list_collections(self) -> list[str]:
+ raise RuntimeError("status unavailable")
+
+
+class ClosableStubStorage(StubStorage):
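+    """Stub storage that records whether close() was awaited."""
+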
+ def __init__(self, config: StorageConfig) -> None:
+ super().__init__(config)
+ self.closed = False
+
+ async def close(self) -> None:
+ self.closed = True
+
+
+class FailingCloseStorage(StubStorage):
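+    """Stub storage whose close() raises, exercising cleanup error handling."""
+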
+ async def close(self) -> None:
+ raise RuntimeError("close failure")
+
+
@pytest.mark.asyncio
async def test_multi_storage_adapter_reports_replication_failure(document_factory) -> None:
primary_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint="http://weaviate.local",
collection_name="primary",
    )
secondary_config = StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint="http://chat.local",
collection_name="secondary",
    )
primary = StubStorage(primary_config)
secondary = StubStorage(secondary_config, fail=True)
@@ -82,33 +133,33 @@ async def test_multi_storage_adapter_reports_replication_failure(document_factor
def test_storage_manager_build_multi_storage_adapter_deduplicates(document_factory) -> None:
- settings = SimpleNamespace(
+ settings = cast(Settings, SimpleNamespace(
weaviate_endpoint="http://weaviate.local",
weaviate_api_key=None,
openwebui_endpoint="http://chat.local",
openwebui_api_key=None,
r2r_endpoint=None,
r2r_api_key=None,
- )
+    ))
manager = StorageManager(settings)
weaviate_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint="http://weaviate.local",
collection_name="primary",
    )
openwebui_config = StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint="http://chat.local",
collection_name="secondary",
    )
manager.backends[StorageBackend.WEAVIATE] = StubStorage(weaviate_config)
manager.backends[StorageBackend.OPEN_WEBUI] = StubStorage(openwebui_config)
adapter = manager.build_multi_storage_adapter(
[StorageBackend.WEAVIATE, StorageBackend.WEAVIATE, StorageBackend.OPEN_WEBUI]
    )
assert len(adapter._storages) == 2
assert adapter._storages[0].config.backend == StorageBackend.WEAVIATE
@@ -116,14 +167,14 @@ def test_storage_manager_build_multi_storage_adapter_deduplicates(document_facto
def test_storage_manager_build_multi_storage_adapter_missing_backend() -> None:
- settings = SimpleNamespace(
+ settings = cast(Settings, SimpleNamespace(
weaviate_endpoint="http://weaviate.local",
weaviate_api_key=None,
openwebui_endpoint="http://chat.local",
openwebui_api_key=None,
r2r_endpoint=None,
r2r_api_key=None,
- )
+ ))
manager = StorageManager(settings)
with pytest.raises(ValueError):
@@ -132,41 +183,294 @@ def test_storage_manager_build_multi_storage_adapter_missing_backend() -> None:
@pytest.mark.asyncio
async def test_storage_manager_search_across_backends_groups_results(document_factory) -> None:
- settings = SimpleNamespace(
+ settings = cast(Settings, SimpleNamespace(
weaviate_endpoint="http://weaviate.local",
weaviate_api_key=None,
openwebui_endpoint="http://chat.local",
openwebui_api_key=None,
r2r_endpoint=None,
r2r_api_key=None,
- )
+ ))
manager = StorageManager(settings)
- document_weaviate = document_factory(content="alpha", metadata_updates={"source_url": "https://alpha"})
- document_openwebui = document_factory(content="beta", metadata_updates={"source_url": "https://beta"})
+ document_weaviate = document_factory(
+ content="alpha", metadata_updates={"source_url": "https://alpha"}
+    )
+ document_openwebui = document_factory(
+ content="beta", metadata_updates={"source_url": "https://beta"}
+    )
manager.backends[StorageBackend.WEAVIATE] = StubStorage(
StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint="http://weaviate.local",
collection_name="primary",
        ),
documents=[document_weaviate],
    )
manager.backends[StorageBackend.OPEN_WEBUI] = StubStorage(
StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint="http://chat.local",
collection_name="secondary",
        ),
documents=[document_openwebui],
    )
results = await manager.search_across_backends(
"query",
limit=5,
backends=[StorageBackend.WEAVIATE, StorageBackend.OPEN_WEBUI],
    )
assert results[StorageBackend.WEAVIATE][0].content == "alpha"
assert results[StorageBackend.OPEN_WEBUI][0].content == "beta"
+
+
+@pytest.mark.asyncio
+async def test_multi_storage_adapter_store_batch_replicates_to_all_backends(document_factory) -> None:
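+    """Replicate a stored batch to every backend and return the primary backend's ids."""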
+ primary_config = StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint="http://weaviate.local",
+ collection_name="primary",
+    )
+ secondary_config = StorageConfig(
+ backend=StorageBackend.OPEN_WEBUI,
+ endpoint="http://chat.local",
+ collection_name="secondary",
+    )
+
+ primary = StubStorage(primary_config)
+ secondary = StubStorage(secondary_config)
+ adapter = MultiStorageAdapter([primary, secondary, secondary])
+
+ first_document = document_factory(content="first")
+ second_document = document_factory(content="second")
+
+ document_ids = await adapter.store_batch([first_document, second_document])
+
+ assert document_ids == ["weaviate-0", "weaviate-1"]
+ assert adapter._storages[0] is primary
+ assert primary.stored[0].content == "first"
+ assert secondary.stored[1].content == "second"
+
+
+@pytest.mark.asyncio
+async def test_multi_storage_adapter_delete_reports_secondary_failures() -> None:
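+    """Raise StorageError naming the secondary backend when its delete fails."""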
+ primary_config = StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint="http://weaviate.local",
+ collection_name="primary",
+    )
+ secondary_config = StorageConfig(
+ backend=StorageBackend.OPEN_WEBUI,
+ endpoint="http://chat.local",
+ collection_name="secondary",
+    )
+
+ primary = StubStorage(primary_config)
+ secondary = StubStorage(secondary_config, fail=True)
+ adapter = MultiStorageAdapter([primary, secondary])
+
+ with pytest.raises(StorageError) as exc_info:
+ await adapter.delete("identifier")
+
+ assert "open_webui" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_storage_manager_initialize_all_backends_registers_capabilities(monkeypatch) -> None:
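+    """Initialize all configured backends and record their capability levels."""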
+ settings = cast(Settings, SimpleNamespace(
+ weaviate_endpoint="http://weaviate.local",
+ weaviate_api_key="key",
+ openwebui_endpoint="http://chat.local",
+ openwebui_api_key="token",
+ r2r_endpoint="http://r2r.local",
+ r2r_api_key="secret",
+ ))
+ manager = StorageManager(settings)
+
+ monkeypatch.setattr(
+ "ingest_pipeline.cli.tui.utils.storage_manager.WeaviateStorage",
+ StubStorage,
+    )
+ monkeypatch.setattr(
+ "ingest_pipeline.cli.tui.utils.storage_manager.OpenWebUIStorage",
+ StubStorage,
+    )
+ monkeypatch.setattr(
+ "ingest_pipeline.cli.tui.utils.storage_manager.R2RStorage",
+ StubStorage,
+    )
+
+ results = await manager.initialize_all_backends()
+
+ assert results[StorageBackend.WEAVIATE] is True
+ assert results[StorageBackend.OPEN_WEBUI] is True
+ assert results[StorageBackend.R2R] is True
+ assert manager.get_available_backends() == [
+ StorageBackend.WEAVIATE,
+ StorageBackend.OPEN_WEBUI,
+ StorageBackend.R2R,
+ ]
+ assert manager.capabilities[StorageBackend.WEAVIATE] == StorageCapabilities.VECTOR_SEARCH
+ assert manager.capabilities[StorageBackend.OPEN_WEBUI] == StorageCapabilities.KNOWLEDGE_BASE
+ assert manager.capabilities[StorageBackend.R2R] == StorageCapabilities.FULL_FEATURED
+ assert manager.supports_advanced_features(StorageBackend.R2R) is True
+ assert manager.supports_advanced_features(StorageBackend.WEAVIATE) is False
+ assert manager.is_initialized is True
+ assert isinstance(manager.get_backend(StorageBackend.R2R), StubStorage)
+
+
+@pytest.mark.asyncio
+async def test_storage_manager_initialize_all_backends_handles_missing_config() -> None:
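+    """Mark every backend unavailable while still flagging the manager as initialized."""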
+ settings = cast(Settings, SimpleNamespace(
+ weaviate_endpoint=None,
+ weaviate_api_key=None,
+ openwebui_endpoint="http://chat.local",
+ openwebui_api_key=None,
+ r2r_endpoint=None,
+ r2r_api_key=None,
+ ))
+ manager = StorageManager(settings)
+
+ results = await manager.initialize_all_backends()
+
+ assert results[StorageBackend.WEAVIATE] is False
+ assert results[StorageBackend.OPEN_WEBUI] is False
+ assert results[StorageBackend.R2R] is False
+ assert manager.get_available_backends() == []
+ assert manager.is_initialized is True
+
+
+@pytest.mark.asyncio
+async def test_storage_manager_get_all_collections_merges_counts_and_backends() -> None:
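+    """Merge the shared collection from both backends into a single entry."""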
+ settings = cast(Settings, SimpleNamespace(
+ weaviate_endpoint="http://weaviate.local",
+ weaviate_api_key=None,
+ openwebui_endpoint="http://chat.local",
+ openwebui_api_key=None,
+ r2r_endpoint=None,
+ r2r_api_key=None,
+ ))
+ manager = StorageManager(settings)
+
+ weaviate_storage = CollectionStubStorage(
+ StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint="http://weaviate.local",
+ collection_name="shared",
+        ),
+ collections=["shared", ""],
+ counts={"shared": 2},
+    )
+ openwebui_storage = CollectionStubStorage(
+ StorageConfig(
+ backend=StorageBackend.OPEN_WEBUI,
+ endpoint="http://chat.local",
+ collection_name="secondary",
+        ),
+ collections=["shared"],
+ counts={"shared": -1},
+    )
+ manager.backends = {
+ StorageBackend.WEAVIATE: weaviate_storage,
+ StorageBackend.OPEN_WEBUI: openwebui_storage,
+ }
+
+ collections = await manager.get_all_collections()
+
+ assert len(collections) == 1
+ assert collections[0]["name"] == "shared"
+ assert collections[0]["count"] == 2
+ assert collections[0]["backend"] == ["weaviate", "open_webui"]
+ assert collections[0]["type"] == "weaviate"
+ assert collections[0]["size_mb"] == pytest.approx(0.02)
+
+
+@pytest.mark.asyncio
+async def test_storage_manager_get_backend_status_reports_failures() -> None:
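+    """Report healthy backend stats and capture the error from a failing backend."""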
+ settings = cast(Settings, SimpleNamespace(
+ weaviate_endpoint="http://weaviate.local",
+ weaviate_api_key=None,
+ openwebui_endpoint="http://chat.local",
+ openwebui_api_key=None,
+ r2r_endpoint=None,
+ r2r_api_key=None,
+ ))
+ manager = StorageManager(settings)
+
+ healthy_storage = CollectionStubStorage(
+ StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint="http://weaviate.local",
+ collection_name="primary",
+        ),
+ collections=["collection", "archive"],
+ counts={"collection": 2, "archive": 1},
+    )
+ failing_storage = FailingStatusStorage(
+ StorageConfig(
+ backend=StorageBackend.OPEN_WEBUI,
+ endpoint="http://chat.local",
+ collection_name="secondary",
+        )
+    )
+ manager.backends = {
+ StorageBackend.WEAVIATE: healthy_storage,
+ StorageBackend.OPEN_WEBUI: failing_storage,
+ }
+ manager.capabilities[StorageBackend.WEAVIATE] = StorageCapabilities.VECTOR_SEARCH
+
+ status = await manager.get_backend_status()
+
+ assert status[StorageBackend.WEAVIATE]["available"] is True
+ assert status[StorageBackend.WEAVIATE]["collections"] == 2
+ assert status[StorageBackend.WEAVIATE]["total_documents"] == 3
+ assert status[StorageBackend.WEAVIATE]["capabilities"] == StorageCapabilities.VECTOR_SEARCH
+ assert str(status[StorageBackend.WEAVIATE]["endpoint"]) == "http://weaviate.local/"
+ assert status[StorageBackend.OPEN_WEBUI]["available"] is False
+ assert status[StorageBackend.OPEN_WEBUI]["capabilities"] == StorageCapabilities.NONE
+ assert "status unavailable" in str(status[StorageBackend.OPEN_WEBUI]["error"])
+
+
+@pytest.mark.asyncio
+async def test_storage_manager_close_all_clears_state() -> None:
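+    """Close every backend and reset manager state even when one close() fails."""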
+ settings = cast(Settings, SimpleNamespace(
+ weaviate_endpoint="http://weaviate.local",
+ weaviate_api_key=None,
+ openwebui_endpoint="http://chat.local",
+ openwebui_api_key=None,
+ r2r_endpoint=None,
+ r2r_api_key=None,
+ ))
+ manager = StorageManager(settings)
+
+ closable_storage = ClosableStubStorage(
+ StorageConfig(
+ backend=StorageBackend.WEAVIATE,
+ endpoint="http://weaviate.local",
+ collection_name="primary",
+        )
+    )
+ failing_close_storage = FailingCloseStorage(
+ StorageConfig(
+ backend=StorageBackend.OPEN_WEBUI,
+ endpoint="http://chat.local",
+ collection_name="secondary",
+        )
+    )
+ manager.backends = {
+ StorageBackend.WEAVIATE: closable_storage,
+ StorageBackend.OPEN_WEBUI: failing_close_storage,
+ }
+ manager.capabilities[StorageBackend.WEAVIATE] = StorageCapabilities.VECTOR_SEARCH
+ manager.capabilities[StorageBackend.OPEN_WEBUI] = StorageCapabilities.KNOWLEDGE_BASE
+ manager._initialized = True
+
+ await manager.close_all()
+
+ assert closable_storage.closed is True
+ assert manager.backends == {}
+ assert manager.capabilities == {}
+ assert manager.is_initialized is False
diff --git a/tests/unit/utils/__pycache__/__init__.cpython-312.pyc b/tests/unit/utils/__pycache__/__init__.cpython-312.pyc
index 481ffeb..7fb1140 100644
Binary files a/tests/unit/utils/__pycache__/__init__.cpython-312.pyc and b/tests/unit/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc b/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc
index f5e51d2..ab4fce7 100644
Binary files a/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc and b/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc differ
diff --git a/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc b/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc
index d3d034c..d5aad3d 100644
Binary files a/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc and b/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc differ
diff --git a/tests/unit/utils/conftest.py b/tests/unit/utils/conftest.py
new file mode 100644
index 0000000..9bdce8b
--- /dev/null
+++ b/tests/unit/utils/conftest.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+import asyncio
+from collections.abc import Awaitable, Callable
+from typing import Protocol
+
+import pytest
+
+from ingest_pipeline.utils.async_helpers import AsyncTaskManager
+
+Delays = tuple[float, ...]
+Results = tuple[float, ...]
+
+
+class ProcessorFactory(Protocol):
+ def __call__(self, value: int) -> Awaitable[int]: ...
+
+
+@pytest.fixture(scope="function")
+def async_manager_factory() -> Callable[[int], AsyncTaskManager]:
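+    """Build AsyncTaskManager instances with the requested concurrency limit."""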
+ def _factory(limit: int) -> AsyncTaskManager:
+ return AsyncTaskManager(max_concurrent=limit)
+
+ return _factory
+
+
+@pytest.fixture(scope="function")
+def concurrency_measure() -> Callable[[AsyncTaskManager, Delays], Awaitable[tuple[Results, int]]]:
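+    """Run delayed tasks through a manager and report results plus peak concurrency."""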
+ async def _measure(manager: AsyncTaskManager, delays: Delays) -> tuple[Results, int]:
+ active = 0
+ peak = 0
+ lock = asyncio.Lock()
+
+ async def tracked(delay: float) -> float:
+ nonlocal active, peak
+ async with lock:
+ active += 1
+ peak = peak if peak >= active else active
+ await asyncio.sleep(delay)
+ async with lock:
+ active -= 1
+ return delay
+
+ tasks = tuple(tracked(delay) for delay in delays)
+ results = await manager.run_tasks(tasks)
+ # Filter out any exceptions and only keep float results
+ float_results = [r for r in results if isinstance(r, (int, float))]
+ return tuple(float_results), peak
+
+ return _measure
+
+
+@pytest.fixture(scope="function")
+def success_failure_coroutines() -> tuple[Awaitable[str], Awaitable[str], Awaitable[str]]:
+ async def success() -> str:
+ await asyncio.sleep(0)
+ return "ok"
+
+ async def failure() -> str:
+ await asyncio.sleep(0)
+ raise RuntimeError("boom")
+
+ return (success(), failure(), success())
+
+
+@pytest.fixture(scope="function")
+def batch_inputs() -> tuple[int, ...]:
+ return (1, 2, 3, 4, 5)
+
+
+@pytest.fixture(scope="function")
+def batch_processor_record() -> tuple[list[int], ProcessorFactory]:
+ recorded: list[int] = []
+
+ async def processor(value: int) -> int:
+ recorded.append(value)
+ await asyncio.sleep(0)
+ return value * 2
+
+ return recorded, processor
+
+
+@pytest.fixture(scope="function")
+def error_processor() -> ProcessorFactory:
+ async def processor(_: int) -> int:
+ await asyncio.sleep(0)
+ raise ValueError("failure")
+
+ # Type checker doesn't understand this function raises, which is expected for error testing
+ return processor # type: ignore[return-value]
+
+
+@pytest.fixture(scope="function")
+def semaphore_case() -> tuple[asyncio.Semaphore, Callable[[], Awaitable[str]], list[str]]:
+ semaphore = asyncio.Semaphore(0)
+ events: list[str] = []
+
+ async def tracked() -> str:
+ events.append("entered")
+ await asyncio.sleep(0)
+ events.append("completed")
+ return "done"
+
+ def factory() -> Awaitable[str]:
+ return tracked()
+
+ return semaphore, factory, events
+
+
+@pytest.fixture(scope="function")
+def batch_expected_results(batch_inputs: tuple[int, ...]) -> tuple[int, ...]:
+ return tuple(value * 2 for value in batch_inputs)
diff --git a/tests/unit/utils/test_async_helpers.py b/tests/unit/utils/test_async_helpers.py
new file mode 100644
index 0000000..0cce241
--- /dev/null
+++ b/tests/unit/utils/test_async_helpers.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+import asyncio
+from collections.abc import Callable
+
+import pytest
+
+from ingest_pipeline.utils.async_helpers import AsyncTaskManager, batch_process, run_with_semaphore
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("limit", "delays"),
+ [(2, (0.02, 0.01, 0.03, 0.02))],
+)
+async def test_async_task_manager_limits_concurrency(
+ async_manager_factory: Callable[[int], AsyncTaskManager],
+ concurrency_measure,
+ limit: int,
+ delays: tuple[float, ...],
+) -> None:
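+    """Keep concurrent task execution at or below the configured limit."""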
+ manager = async_manager_factory(limit)
+ results, peak = await concurrency_measure(manager, delays)
+
+ assert results == delays
+ assert peak <= limit
+
+
+@pytest.mark.asyncio
+async def test_async_task_manager_collects_exceptions(
+ async_manager_factory: Callable[[int], AsyncTaskManager],
+ success_failure_coroutines,
+) -> None:
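+    """Return exceptions in place of results when return_exceptions is enabled."""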
+ manager = async_manager_factory(3)
+ outcomes = await manager.run_tasks(success_failure_coroutines, return_exceptions=True)
+
+ assert outcomes[0] == "ok"
+ assert isinstance(outcomes[1], RuntimeError)
+ assert outcomes[2] == "ok"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("batch_size", "max_concurrent"),
+ [(2, 1)],
+)
+async def test_batch_process_handles_multiple_batches(
+ batch_inputs,
+ batch_processor_record,
+ batch_expected_results,
+ batch_size: int,
+ max_concurrent: int,
+) -> None:
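+    """Process every input across batches and preserve submission order."""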
+ recorded, processor = batch_processor_record
+
+ results = await batch_process(
+ list(batch_inputs),
+ processor,
+ batch_size=batch_size,
+ max_concurrent=max_concurrent,
+ )
+
+ assert tuple(results) == batch_expected_results
+ assert recorded == list(batch_inputs)
+
+
+@pytest.mark.asyncio
+async def test_batch_process_propagates_errors(error_processor) -> None:
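+    """Propagate processor exceptions out of batch_process."""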
+ with pytest.raises(ValueError):
+ await batch_process([1, 2], error_processor, batch_size=1)
+
+
+@pytest.mark.asyncio
+async def test_run_with_semaphore_blocks_until_available(semaphore_case) -> None:
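+    """Defer coroutine execution until the semaphore is released."""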
+ semaphore, factory, events = semaphore_case
+ pending = asyncio.create_task(run_with_semaphore(semaphore, factory()))
+
+ await asyncio.sleep(0.01)
+ assert events == []
+
+ semaphore.release()
+ result = await pending
+
+ assert result == "done"
+ assert events == ["entered", "completed"]
diff --git a/tests/unit/utils/test_metadata_tagger.py b/tests/unit/utils/test_metadata_tagger.py
index 6056a00..3012839 100644
--- a/tests/unit/utils/test_metadata_tagger.py
+++ b/tests/unit/utils/test_metadata_tagger.py
@@ -86,14 +86,10 @@ async def test_tag_batch_processes_documents(
"""Apply tagging to each document in order."""
first_payload = {
- "choices": [
- {"message": {"content": json.dumps({"summary": "First summary"})}}
- ]
+ "choices": [{"message": {"content": json.dumps({"summary": "First summary"})}}]
}
second_payload = {
- "choices": [
- {"message": {"content": json.dumps({"summary": "Second summary"})}}
- ]
+ "choices": [{"message": {"content": json.dumps({"summary": "Second summary"})}}]
}
httpx_stub.queue_json(first_payload)
httpx_stub.queue_json(second_payload)
@@ -113,11 +109,7 @@ async def test_tag_batch_processes_documents(
async def test_tag_document_raises_on_invalid_json(httpx_stub, document_factory) -> None:
"""Raise ingestion error when the LLM response is malformed."""
- payload = {
- "choices": [
- {"message": {"content": "not-json"}}
- ]
- }
+ payload = {"choices": [{"message": {"content": "not-json"}}]}
httpx_stub.queue_json(payload)
document = document_factory(content="broken payload")
diff --git a/typings/__init__.py b/typings/__init__.py
index 1efb184..e4997b7 100644
--- a/typings/__init__.py
+++ b/typings/__init__.py
@@ -5,9 +5,11 @@ from typing import TypedDict
class EmbeddingData(TypedDict):
"""Structure for embedding data from API response."""
+
embedding: list[float]
class EmbeddingResponse(TypedDict):
"""Structure for OpenAI-compatible embedding API response."""
+
data: list[EmbeddingData]
diff --git a/typings/__init__.pyi b/typings/__init__.pyi
index 929db06..06e608f 100644
--- a/typings/__init__.pyi
+++ b/typings/__init__.pyi
@@ -4,9 +4,10 @@ from typing import TypedDict
class EmbeddingData(TypedDict):
"""Structure for embedding data from API response."""
+
     embedding: list[float]
class EmbeddingResponse(TypedDict):
"""Structure for OpenAI-compatible embedding API response."""
+
data: list[EmbeddingData]
diff --git a/typings/__pycache__/__init__.cpython-312.pyc b/typings/__pycache__/__init__.cpython-312.pyc
index 903cd04..a5e78ad 100644
Binary files a/typings/__pycache__/__init__.cpython-312.pyc and b/typings/__pycache__/__init__.cpython-312.pyc differ
diff --git a/typings/httpx.pyi b/typings/httpx.pyi
index 32fbc0a..6008c98 100644
--- a/typings/httpx.pyi
+++ b/typings/httpx.pyi
@@ -8,33 +8,17 @@ class Response:
def raise_for_status(self) -> None: ...
def json(self) -> dict[str, object]: ...
-
class AsyncClient:
"""Async HTTP client."""
def __init__(
- self,
- *,
- timeout: float | None = None,
- headers: dict[str, str] | None = None
+ self, *, timeout: float | None = None, headers: dict[str, str] | None = None
) -> None: ...
-
async def get(
- self,
- url: str,
- *,
- params: dict[str, object] | dict[str, int] | None = None
+ self, url: str, *, params: dict[str, object] | dict[str, int] | None = None
) -> Response: ...
-
- async def post(
- self,
- url: str,
- *,
- json: dict[str, object] | None = None
- ) -> Response: ...
-
+ async def post(self, url: str, *, json: dict[str, object] | None = None) -> Response: ...
async def aclose(self) -> None: ...
-
# Make AsyncClient available for import
__all__ = ["AsyncClient", "Response"]
diff --git a/typings/prefect.pyi b/typings/prefect.pyi
index 6ed4156..2f62582 100644
--- a/typings/prefect.pyi
+++ b/typings/prefect.pyi
@@ -1,11 +1,14 @@
"""Type stubs for Prefect to fix missing type information."""
-from typing import Any, Awaitable, Callable, TypeVar, overload
+from collections.abc import Awaitable, Callable
+from typing import Any, TypeVar, overload
+
from typing_extensions import ParamSpec
# Prefect-specific types for cache key functions
class TaskRunContext:
"""Prefect task run context."""
+
pass
P = ParamSpec("P")
@@ -37,7 +40,6 @@ def task(
retry_condition_fn: Any = None,
viz_return_value: Any = None,
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: ...
-
@overload
def task(
*,
@@ -64,17 +66,14 @@ def task(
retry_condition_fn: Any = None,
viz_return_value: Any = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
-
@overload
def task(
fn: Callable[P, Awaitable[T]],
) -> Callable[P, Awaitable[T]]: ...
-
@overload
def task(
fn: Callable[P, T],
) -> Callable[P, T]: ...
-
@overload
def flow(
*,
@@ -99,7 +98,6 @@ def flow(
on_cancellation: list[Any] | None = None,
on_running: list[Any] | None = None,
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: ...
-
@overload
def flow(
fn: Callable[P, Awaitable[T]],
@@ -111,4 +109,4 @@ class Logger:
def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
def error(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
-def get_run_logger() -> Logger: ...
\ No newline at end of file
+def get_run_logger() -> Logger: ...
diff --git a/typings/prefect/blocks/core.pyi b/typings/prefect/blocks/core.pyi
index 44a878a..b0b4cbd 100644
--- a/typings/prefect/blocks/core.pyi
+++ b/typings/prefect/blocks/core.pyi
@@ -1,7 +1,6 @@
"""Type stubs for Prefect blocks core module."""
from typing import Any, ClassVar, TypeVar
-from typing_extensions import Self
T = TypeVar("T", bound="Block")
@@ -12,8 +11,6 @@ class Block:
@classmethod
async def aload(cls: type[T], name: str, validate: bool = True, client: Any = None) -> T: ...
-
@classmethod
def load(cls: type[T], name: str, validate: bool = True, client: Any = None) -> T: ...
-
- def save(self, name: str, overwrite: bool = False) -> None: ...
\ No newline at end of file
+ def save(self, name: str, overwrite: bool = False) -> None: ...
diff --git a/typings/r2r/__init__.pyi b/typings/r2r/__init__.pyi
index c4bce07..9309f29 100644
--- a/typings/r2r/__init__.pyi
+++ b/typings/r2r/__init__.pyi
@@ -5,18 +5,15 @@ from typing import Generic, Protocol, TypeVar, overload
_T_co = TypeVar("_T_co", covariant=True)
-
class SupportsModelDump(Protocol):
def model_dump(self) -> dict[str, object]: ...
-
class CollectionResponse(SupportsModelDump):
id: str
name: str
description: str | None
owner_id: str | None
-
class CollectionSequence(SupportsModelDump, Protocol):
@overload
def __getitem__(self, index: int) -> CollectionResponse: ...
@@ -26,64 +23,52 @@ class CollectionSequence(SupportsModelDump, Protocol):
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[CollectionResponse]: ...
-
class DocumentSequence(SupportsModelDump, Protocol):
def model_dump(self) -> dict[str, object]: ...
-
class UserSequence(SupportsModelDump, Protocol):
def model_dump(self) -> dict[str, object]: ...
-
class GenericBooleanResponse(SupportsModelDump):
success: bool
-
class GenericMessageResponse(SupportsModelDump):
message: str | None
-
class DocumentCreateResponse(SupportsModelDump):
"""Response from document creation."""
+
id: str
message: str | None = None
-
class DocumentRetrieveResponse(SupportsModelDump):
"""Response from document retrieval."""
+
id: str
content: str | None = None
metadata: dict[str, object] | None = None
-
class SearchResult(SupportsModelDump):
"""Individual search result."""
+
id: str
content: str | None = None
metadata: dict[str, object] | None = None
score: float | None = None
-
class SearchResponse(SupportsModelDump):
"""Response from search operations."""
+
     results: list[SearchResult]
class R2RResults(Generic[_T_co], SupportsModelDump):
results: _T_co
def model_dump(self) -> dict[str, object]: ...
-
class WrappedCollectionsResponse(R2RResults[CollectionSequence]): ...
-
-
class WrappedDocumentsResponse(R2RResults[DocumentSequence]): ...
-
-
class WrappedUsersResponse(R2RResults[UserSequence]): ...
-
class R2RException(Exception): ...
-
class R2RClientException(Exception): ...
class _SystemClient:
@@ -97,8 +82,12 @@ class _CollectionsClient:
limit: int | None = ...,
owner_only: bool | None = ...,
) -> WrappedCollectionsResponse: ...
- async def create(self, name: str, description: str | None = ...) -> R2RResults[CollectionResponse]: ...
- async def add_document(self, id: str, document_id: str) -> R2RResults[GenericMessageResponse]: ...
+ async def create(
+ self, name: str, description: str | None = ...
+ ) -> R2RResults[CollectionResponse]: ...
+ async def add_document(
+ self, id: str, document_id: str
+ ) -> R2RResults[GenericMessageResponse]: ...
async def delete(self, id: str) -> None: ...
async def retrieve(self, id: str) -> R2RResults[CollectionResponse]: ...
async def update(