diff --git a/.claude/settings.local.json b/.claude/settings.local.json index a3a1353..a4bbc33 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -8,7 +8,8 @@ "Bash(cat:*)", "mcp__firecrawl__firecrawl_search", "mcp__firecrawl__firecrawl_scrape", - "Bash(python:*)" + "Bash(python:*)", + "mcp__coder__coder_report_task" ], "deny": [], "ask": [] diff --git a/docs/feeds.md b/docs/feeds.md index b8c7f69..97f96fb 100644 --- a/docs/feeds.md +++ b/docs/feeds.md @@ -1,263 +1,255 @@ -## Codebase Analysis Report: RAG Manager Ingestion Pipeline +TL;DR / Highest‑impact fixes (do these first) -**Status:** Validated against current codebase implementation -**Target:** Enhanced implementation guidance for efficient agent execution +Event-loop blocking in the TUI (user-visible stutter): +time.sleep(2) is called at the end of an ingestion run on the IngestionScreen. Even though this runs in a thread worker, it blocks that worker and delays UI transition; prefer a scheduled UI callback instead. -This analysis has been validated against the actual codebase structure and provides implementation-specific details for executing recommended improvements. The codebase demonstrates solid architecture with clear separation of concerns between ingestion flows, storage adapters, and TUI components. +repomix-output (2) -### Architecture Overview -- **Storage Backends**: Weaviate, OpenWebUI, R2R with unified `BaseStorage` interface -- **TUI Framework**: Textual-based with reactive components and async worker patterns -- **Orchestration**: Prefect flows with retry logic and progress callbacks -- **Configuration**: Pydantic-based settings with environment variable support +Blocking Weaviate client calls inside async methods: +Several async methods in WeaviateStorage call the synchronous Weaviate client directly (connect(), collections.create, queries, inserts, deletes). Wrap those in asyncio.to_thread(...) (or equivalent) to avoid freezing the loop when these calls take time. -### Validated Implementation Analysis +Embedding at scale: use batch vectorization: +store_batch vectorizes each document one by one; you already have Vectorizer.vectorize_batch. Switching to batch reduces HTTP round trips and improves throughput under backpressure. -### 1. Bug Fixes & Potential Issues +Connection lifecycle: close the vectorizer client: +WeaviateStorage.close() closes the Weaviate client but not the httpx.AsyncClient inside the Vectorizer. Add an await to close it to prevent leaked connections/sockets under heavy usage. -These are areas where the code may not function as intended or could lead to errors. +Broadened exception handling in UI utilities: +Multiple places catch Exception broadly, making failures ambiguous and harder to surface to users (e.g., storage manager and list builders). Narrow these to expected exceptions and fail fast with user-friendly messages where appropriate. -*
- - HIGH PRIORITY: `R2RStorage.store_batch` inefficient looping (Lines 161-179) - +repomix-output (2) - * **File:** `ingest_pipeline/storage/r2r/storage.py:161-179` - * **Issue:** CONFIRMED - Method loops through documents calling `_store_single_document` individually - * **Impact:** ~5-10x performance degradation for batch operations - * **Implementation:** Check R2R v3 API for bulk endpoints; current implementation uses `/v3/documents` per document - * **Effort:** Medium (API research + refactor) - * **Priority:** High - affects all R2R ingestion workflows -
+Correctness / Bugs -*
- - MEDIUM PRIORITY: Mixed HTTP client usage in `R2RStorage` (Lines 80, 99, 258) - +Blocking sleep in TUI: +time.sleep(2) after posting notifications and before app.pop_screen() blocks the worker thread; use Textual timers instead. - * **File:** `ingest_pipeline/storage/r2r/storage.py:80,99,258` - * **Issue:** VALIDATED - Mixes `R2RAsyncClient` (line 80) with direct `httpx.AsyncClient` (lines 99, 258) - * **Specific Methods:** `initialize()`, `_ensure_collection()`, `_attempt_document_creation()` - * **Impact:** Inconsistent auth/header handling, connection pooling inefficiency - * **Implementation:** Extend `R2RAsyncClient` or create adapter pattern for missing endpoints - * **Test Coverage:** Check if affected methods have unit tests before refactoring - * **Effort:** Medium (requires SDK analysis) -
+repomix-output (2) -*
- - MEDIUM PRIORITY: TUI blocking during storage init (Line 91) - +Synchronous SDK in async contexts (Weaviate): - * **File:** `ingest_pipeline/cli/tui/utils/runners.py:91` - * **Issue:** CONFIRMED - `await storage_manager.initialize_all_backends()` blocks TUI startup - * **Current Implementation:** 30s timeout per backend in `StorageManager.initialize_all_backends()` - * **User Impact:** Frozen terminal for up to 90s if all backends timeout - * **Solution:** Move to `CollectionOverviewScreen.on_mount()` as `@work` task - * **Dependencies:** `dashboard.py:304` already has worker pattern for `refresh_collections` - * **Implementation:** Use existing loading indicators and status updates (lines 308-312) - * **Effort:** Low (pattern exists, needs relocation) -
+initialize() calls self.client.connect() (sync). Wrap with asyncio.to_thread(self.client.connect). -*
- - LOW PRIORITY: Weak URL validation in `IngestionScreen` (Lines 240-260) - +repomix-output (2) - * **File:** `ingest_pipeline/cli/tui/screens/ingestion.py:240-260` - * **Issue:** CONFIRMED - Method accepts `foo/bar` as valid (line 258) - * **Security Risk:** Medium - malicious URLs could be passed to ingestors - * **Current Logic:** Basic prefix checks only (http/https/file://) - * **Enhancement:** Add `pathlib.Path.exists()` for file:// paths, `.git` directory check for repos - * **Dependencies:** Import `pathlib` and add proper regex validation - * **Alternative:** Use `validators` library (not currently imported) - * **Effort:** Low (validation logic only) -
+Object operations such as collection.data.insert(...), collection.query.fetch_objects(...), collection.data.delete_many(...), and client.collections.delete(...) are sync calls invoked from async methods. These can stall the event loop under latency; use asyncio.to_thread(...) (see snippet below). -### 2. Code Redundancy & Refactoring Opportunities +HTTP client lifecycle (Vectorizer): +Vectorizer owns an httpx.AsyncClient but WeaviateStorage.close() doesn’t close it—add a close call to avoid resource leaks. -These suggestions aim to make the code more concise, maintainable, and reusable (D.R.Y. - Don't Repeat Yourself). +Heuristic “word_count” for OpenWebUI file listing: +word_count is estimated via size/6. That’s fine as a placeholder but can mislead paging and UI logic downstream—consider a sentinel or a clearer label (e.g., estimated_word_count). -*
- - HIGH IMPACT: Redundant collection logic in dashboard (Lines 356-424) - +repomix-output (2) - * **File:** `ingest_pipeline/cli/tui/screens/dashboard.py:356-424` - * **Issue:** CONFIRMED - `list_weaviate_collections()` and `list_openwebui_collections()` duplicate `StorageManager.get_all_collections()` - * **Code Duplication:** ~70 lines of redundant collection listing logic - * **Architecture Violation:** UI layer coupled to specific storage implementations - * **Current Usage:** `refresh_collections()` calls `get_all_collections()` (line 327), making methods obsolete - * **Action:** DELETE methods `list_weaviate_collections` and `list_openwebui_collections` - * **Impact:** Code reduction ~70 lines, improved maintainability - * **Risk:** Low - methods appear unused in current flow - * **Effort:** Low (deletion only) -
+Wide except Exception: +Broad catches appear in places like the storage manager and TUI screen update closures; they hide actionable errors. Prefer catching StorageError, IngestionError, or specific SDK exceptions—and surface toasts with actionable details. -*
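As a sketch of the narrowed handling suggested above (the `get_all_collections()` call site and `safe_notify` wiring are assumptions drawn from this analysis, not verbatim repo code):

try:
    collections = await storage_manager.get_all_collections()
except (StorageError, IngestionError) as exc:
    # Expected, user-relevant failures: surface an actionable toast and bail out.
    cast("CollectionManagementApp", self.app).safe_notify(
        f"Failed to load collections: {exc}", severity="error"
    )
    return
# Anything else is a genuine bug; let it propagate instead of swallowing it.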
- - MEDIUM IMPACT: Repetitive backend init pattern (Lines 255-291) - +repomix-output (2) - * **File:** `ingest_pipeline/cli/tui/utils/storage_manager.py:255-291` - * **Issue:** CONFIRMED - Pattern repeated 3x for each backend type - * **Code Structure:** Check settings → Create config → Add task (12 lines × 3 backends) - * **Current Backends:** Weaviate (258-267), OpenWebUI (270-279), R2R (282-291) - * **Refactor Pattern:** Create `BackendConfig` dataclass with `(backend_type, endpoint_setting, api_key_setting, storage_class)` - * **Implementation:** Loop over config list, reducing ~36 lines to ~15 lines - * **Extensibility:** Adding new backend becomes one-line config addition - * **Testing:** Ensure `asyncio.gather()` behavior unchanged (line 296) - * **Effort:** Medium (requires dataclass design + testing) -
+Performance & Scalability -*
- - MEDIUM IMPACT: Repeated Prefect block loading pattern (Lines 266-311) - +Batch embeddings: +In WeaviateStorage.store_batch you conditionally vectorize each doc inline. Use your existing vectorize_batch and map the results back to documents. This cuts request overhead and enables controlled concurrency (you already have AsyncTaskManager to help). - * **File:** `ingest_pipeline/flows/ingestion.py:266-311` - * **Issue:** CONFIRMED - Pattern in `_create_ingestor()` and `_create_storage()` methods - * **Duplication:** `Block.aload()` + fallback logic repeated 4x across both methods - * **Variable Resolution:** Batch size logic (lines 244-255) also needs abstraction - * **Helper Functions Needed:** - - `load_block_with_fallback(block_slug: str, default_config: T) -> T` - - `resolve_prefect_variable(var_name: str, default: T, type_cast: Type[T]) -> T` - * **Impact:** Cleaner flow logic, better error handling, type safety - * **Lines Reduced:** ~20 lines of repetitive code - * **Effort:** Medium (requires generic typing) -
+Async-friendly Weaviate calls: +Offload sync SDK operations to a thread so your Textual UI remains responsive while bulk inserting/querying/deleting. (See patch below.) -### 3. User Experience (UX) Enhancements +Retry & backoff are solid in your HTTP base client: +Your TypedHttpClient.request adds exponential backoff + jitter; this is great and worth keeping consistent across adapters. -These are suggestions to make your TUI more powerful, intuitive, and enjoyable for the user. +repomix-output (2) -*
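For reference, the exponential backoff + jitter pattern that TypedHttpClient.request already applies looks roughly like the standalone sketch below (illustrative only, not the project's actual implementation):

import asyncio
import random

import httpx


async def request_with_backoff(
    client: httpx.AsyncClient, method: str, url: str, *, retries: int = 3
) -> httpx.Response:
    """Retry transient failures with exponential backoff plus jitter."""
    for attempt in range(retries + 1):
        try:
            response = await client.request(method, url)
            response.raise_for_status()
            return response
        except (httpx.TransportError, httpx.HTTPStatusError):
            if attempt == retries:
                raise
            # 0.5s, 1s, 2s, ... plus up to 250 ms of jitter to avoid thundering herds
            await asyncio.sleep(0.5 * (2**attempt) + random.uniform(0, 0.25))
    raise RuntimeError("unreachable")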
- - HIGH IMPACT: Document content viewer modal (Add to documents.py) - +UX notes (Textual TUI) - * **Target File:** `ingest_pipeline/cli/tui/screens/documents.py` - * **Current State:** READY - `DocumentManagementScreen` has table selection (line 212) - * **Implementation:** - - Add `Binding("v", "view_document", "View")` to BINDINGS (line 27) - - Create `DocumentContentModal(ModalScreen)` with `ScrollableContainer` + `Markdown` - - Use existing `get_current_document()` method (line 212) - - Fetch full content via `storage.retrieve(document_id)` - * **Dependencies:** Import `ModalScreen`, `ScrollableContainer`, `Markdown` from textual - * **User Value:** HIGH - essential for content inspection workflow - * **Effort:** Low-Medium (~50 lines of modal code) - * **Pattern:** Follow existing modal patterns in codebase -
+Notifications are good—make them consistent: +app.safe_notify(...) exists—use that everywhere instead of self.notify(...) to normalize markup handling & safety. -*
- - HIGH IMPACT: Analytics tab visualization (Lines 164-189) - +repomix-output (2) - * **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py:164-189` - * **Current State:** PLACEHOLDER - Static widgets with dummy content - * **Data Source:** Use existing `self.collections` (line 65) populated by `refresh_collections()` - * **Implementation Options:** - 1. **Simple Text Chart:** ASCII bar chart using existing collections data - 2. **textual-plotext:** Add dependency + bar chart widget - 3. **Custom Widget:** Simple bar visualization with Static widgets - * **Metrics to Show:** - - Documents per collection (data available) - - Storage usage per backend (calculated in `_calculate_metrics()`) - - Ingestion timeline (requires timestamp tracking) - * **Effort:** Low-Medium (depends on visualization complexity) - * **Dependencies:** Consider `textual-plotext` or pure ASCII approach -
+Avoid visible “freeze” scenarios: +Replace the final time.sleep(2) with a timer and transition immediately; users see their actions complete without lag. (Patch below.) -*
- - MEDIUM IMPACT: Global search implementation (Button exists, needs screen) - +repomix-output (2) - * **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py` - * **Current State:** READY - "Search All" button exists (line 122), handler stubbed - * **Backend Support:** `StorageManager.search_across_backends()` method exists (line 413-441) - * **Implementation:** - - Create `GlobalSearchScreen(ModalScreen)` with search input + results table - - Use existing `search_across_backends()` method for data - - Add "Backend" column to results table showing data source - - Handle async search with loading indicators - * **Current Limitation:** Search only works for Weaviate (line 563), need to extend - * **Data Flow:** Input → `storage_manager.search_across_backends()` → Results display - * **Effort:** Medium (~100 lines for new screen + search logic) -
+Search table quick find: +You have EnhancedDataTable with quick-search messages; wire a shortcut (e.g., /) to focus a search input and filter rows live. -*
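Minimal wiring for that quick-find shortcut (the class name and "#search_input" id are illustrative; add the binding and action to whichever screen already composes the table and its filter input):

from textual.binding import Binding
from textual.screen import Screen
from textual.widgets import Input


class CollectionsTableScreen(Screen[None]):
    """Sketch: only the pieces needed for the quick-find shortcut are shown."""

    BINDINGS = [
        Binding("/", "focus_search", "Search", show=False),
    ]

    def action_focus_search(self) -> None:
        # Move focus to the filter box so rows can be narrowed without the mouse.
        self.query_one("#search_input", Input).focus()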
- - MEDIUM IMPACT: R2R advanced features integration (Widgets ready) - +repomix-output (2) - * **Target File:** `ingest_pipeline/cli/tui/screens/documents.py` - * **Available Widgets:** CONFIRMED - `ChunkViewer`, `EntityGraph`, `CollectionStats`, `DocumentOverview` in `r2r_widgets.py` - * **Current Implementation:** Basic document table only, R2R-specific features unused - * **Integration Points:** - - Add "R2R Details" button when `collection["type"] == "r2r"` (conditional UI) - - Create `R2RDocumentDetailsScreen` using existing widgets - - Use `StorageManager.get_r2r_storage()` method (exists at line 442) - * **R2R Methods Available:** - - `get_document_chunks()`, `extract_entities()`, `get_document_overview()` - * **User Value:** Medium-High for R2R users, showcases advanced features - * **Effort:** Low-Medium (widgets exist, need screen integration) -
+Theme file size & maintainability: +The theming system is thorough; consider splitting styles.py into smaller modules or generating CSS at build time to keep the Python file lean. (The responsive CSS generators are consolidated here.) -*
- - LOW IMPACT: Create collection dialog (Backend methods exist) - +repomix-output (2) - * **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py` - * **Backend Support:** CONFIRMED - `create_collection()` method exists for R2R storage (line 690) - * **Current State:** No "Create Collection" button in existing UI - * **Implementation:** - - Add "New Collection" button to dashboard action buttons - - Create `CreateCollectionModal` with name input + backend checkboxes - - Iterate over `storage_manager.get_available_backends()` for backend selection - - Call `storage.create_collection()` on selected backends - * **Backend Compatibility:** Check which storage backends support collection creation - * **User Value:** Low-Medium (manual workflow, not critical) - * **Effort:** Low-Medium (~75 lines for modal + integration) -
+Modularity / Redundancy -## Implementation Priority Matrix +Converge repeated property mapping: +WeaviateStorage.store and store_batch build the same properties dict; factor this into a small helper to keep schemas in one place (less drift, easier to extend). -### Quick Wins (High Impact, Low Effort) -1. **Delete redundant collection methods** (dashboard.py:356-424) - 5 min -2. **Fix TUI startup blocking** (runners.py:91) - 15 min -3. **Document content viewer modal** (documents.py) - 30 min +repomix-output (2) -### High Impact Fixes (Medium Effort) -1. **R2R batch operation optimization** (storage.py:161-179) - Research R2R v3 API + implementation -2. **Analytics tab visualization** (dashboard.py:164-189) - Choose visualization approach + implement -3. **Backend initialization refactoring** (storage_manager.py:255-291) - Dataclass design + testing +Common “describe/list/count” patterns across storages: +R2RStorage, OpenWebUIStorage, and WeaviateStorage present similar collection/document listing and counting methods. Consider a small “collection view” mixin with shared helpers; each backend implements only the API-specific steps. -### Technical Debt (Long-term) -1. **R2R client consistency** (storage.py) - SDK analysis + refactoring -2. **Prefect block loading helpers** (ingestion.py:266-311) - Generic typing + testing -3. **URL validation enhancement** (ingestion.py:240-260) - Security + validation logic +Security & Reliability -### Feature Enhancements (User Value) -1. **Global search implementation** - Medium effort, requires search backend extension -2. **R2R advanced features integration** - Showcase existing widget capabilities -3. **Create collection dialog** - Nice-to-have administrative feature +Input sanitization for LLM metadata: +Your MetadataTagger sanitizes and bounds fields (e.g., max lengths, language whitelist). This is a strong pattern—keep it. -## Agent Execution Notes +repomix-output (2) -**Context Efficiency Tips:** -- Focus on one priority tier at a time -- Read specific file ranges mentioned in line numbers -- Use existing patterns (worker decorators, modal screens, async methods) -- Test changes incrementally, especially async operations -- Verify import dependencies before implementation +Timeouts and typed HTTP clients: +You standardize HTTP clients, headers, and timeouts. Good foundation for consistent behavior & observability. -**Architecture Constraints:** -- Maintain async/await patterns throughout -- Follow Textual reactive widget patterns -- Preserve Prefect flow structure for orchestration -- Keep storage backend abstraction intact +repomix-output (2) -The codebase demonstrates excellent architectural foundations - these enhancements build upon existing strengths rather than requiring structural changes. 
\ No newline at end of file +Suggested patches (drop‑in) +1) Don’t block the UI when closing the ingestion screen + +Current (blocking): + +repomix-output (2) + +import time +time.sleep(2) +cast("CollectionManagementApp", self.app).pop_screen() + + +Safer (schedule via the app’s timer) + +def _pop() -> None: + try: + self.app.pop_screen() + except Exception: + pass + +# schedule from the worker thread +cast("CollectionManagementApp", self.app).call_from_thread( + lambda: self.app.set_timer(2.0, _pop) +) + +2) Offload Weaviate sync calls from async methods + +Example – insert (from WeaviateStorage.store) + +repomix-output (2) + +# before (sync in async method) +collection.data.insert(properties=properties, vector=vector) + +# after +await asyncio.to_thread(collection.data.insert, properties=properties, vector=vector) + + +Example – fetch/delete by filter (from delete_by_filter) + +repomix-output (2) + +response = await asyncio.to_thread( + collection.query.fetch_objects, filters=where_filter, limit=1000 +) +for obj in response.objects: + await asyncio.to_thread(collection.data.delete_by_id, obj.uuid) + + +Example – connect during initialize + +repomix-output (2) + +await asyncio.to_thread(self.client.connect) + +3) Batch embeddings in store_batch + +Current (per‑doc): + +repomix-output (2) + +for doc in documents: + if doc.vector is None: + doc.vector = await self.vectorizer.vectorize(doc.content) + + +Proposed (batch): + +repomix-output (2) + +# collect contents needing vectors +to_embed_idxs = [i for i, d in enumerate(documents) if d.vector is None] +if to_embed_idxs: + contents = [documents[i].content for i in to_embed_idxs] + vectors = await self.vectorizer.vectorize_batch(contents) + for j, idx in enumerate(to_embed_idxs): + documents[idx].vector = vectors[j] + +4) Close the Vectorizer HTTP client when storage closes + +Current: WeaviateStorage.close() only closes the Weaviate client. + +repomix-output (2) + +Add: + +async def close(self) -> None: + if self.client: + try: + cast(weaviate.WeaviateClient, self.client).close() + except Exception as e: + import logging + logging.warning("Error closing Weaviate client: %s", e) + # NEW: close vectorizer HTTP client too + try: + await self.vectorizer.client.aclose() + except Exception: + pass + + +(Your Vectorizer owns an httpx.AsyncClient with headers/timeouts set.) + +repomix-output (2) + +5) Prefer safe_notify consistently + +Replace direct self.notify(...) calls inside TUI screens with cast(AppType, self.app).safe_notify(...) (same severity, markup=False by default), centralizing styling/sanitization. + +repomix-output (2) + +Smaller improvements (quality of life) + +Quick search focus: Wire the EnhancedDataTable quick-search to a keyboard binding (e.g., /) so users can filter rows without reaching for the mouse. + +repomix-output (2) + +Refactor repeated properties dict: Extract the property construction in WeaviateStorage.store/store_batch into a helper to avoid drift and reduce line count. + +repomix-output (2) + +Styles organization: styles.py is hefty by design. Consider splitting “theme palette”, “components”, and “responsive” generators into separate modules to keep diffs small and reviews easier. + +repomix-output (2) + +Architecture & Modularity + +Storage adapters: The three adapters share “describe/list/count” concepts. Introduce a tiny shared “CollectionIntrospection” mixin (interfaces + default fallbacks), and keep only API specifics in each adapter. This will simplify the TUI’s StorageManager as well. 
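One possible shape for that CollectionIntrospection mixin (a sketch; the method names and default fallbacks are illustrative, not the adapters' current API):

class CollectionIntrospectionMixin:
    """Shared describe/list/count helpers; adapters keep only the API-specific parts."""

    async def list_collections(self) -> list[str]:
        raise NotImplementedError

    async def list_documents(self, collection_name: str) -> list[dict[str, object]]:
        raise NotImplementedError

    async def count_documents(self, collection_name: str) -> int:
        # Default fallback: backends without a native count reuse the listing.
        return len(await self.list_documents(collection_name))

    async def describe_collection(self, collection_name: str) -> dict[str, object]:
        # Default fallback; backends can override with richer, API-specific metadata.
        return {
            "name": collection_name,
            "document_count": await self.count_documents(collection_name),
        }

Each adapter (WeaviateStorage, OpenWebUIStorage, R2RStorage) would then override only the methods its API supports natively, and the TUI's StorageManager could target the mixin surface instead of per-backend special cases.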
+ +Ingestion flows: Good use of Prefect tasks/flows with retries, variables, and tagging. The Firecrawl→R2R specialization cleanly reuses common steps. The batch boundaries are clear and progress updates are threaded back to the UI. + +DX / Testing + +Unit tests: In this export I don’t see tests. Add lightweight tests for: + +Vector extraction/format parsing in Vectorizer (covers multiple providers). + +repomix-output (2) + +Weaviate adapters: property building, name normalization, vector extraction. + +repomix-output (2) + +MetadataTagger sanitization rules. + +repomix-output (2) + +TUI: use Textual’s pilot to test that notifications and timers trigger and that screens transition without blocking (verifies the sleep fix). + +Static analysis: You already have excellent typing. Add ruff + mypy --strict and a pre-commit config to keep it consistent across contributors. \ No newline at end of file diff --git a/ingest_pipeline/automations/__init__.py b/ingest_pipeline/automations/__init__.py index b50ca00..83a0e3b 100644 --- a/ingest_pipeline/automations/__init__.py +++ b/ingest_pipeline/automations/__init__.py @@ -19,7 +19,6 @@ actions: source: inferred enabled: true """, - "retry_failed": """ name: Retry Failed Ingestion Flows description: Retries failed ingestion flows with original parameters @@ -39,7 +38,6 @@ actions: validate_first: false enabled: true """, - "resource_monitoring": """ name: Manage Work Pool Based on Resources description: Pauses work pool when system resources are constrained diff --git a/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc index 7f1d8f8..04f5255 100644 Binary files a/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc b/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc index c0fabe5..bb38b8b 100644 Binary files a/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc and b/ingest_pipeline/cli/__pycache__/main.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/main.py b/ingest_pipeline/cli/main.py index 9022ad8..b52c4a0 100644 --- a/ingest_pipeline/cli/main.py +++ b/ingest_pipeline/cli/main.py @@ -139,6 +139,7 @@ def ingest( # If we're already in an event loop (e.g., in Jupyter), use nest_asyncio try: import nest_asyncio + nest_asyncio.apply() result = asyncio.run(run_with_progress()) except ImportError: @@ -449,7 +450,9 @@ def search( def blocks_command() -> None: """🧩 List and manage Prefect Blocks.""" console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]") - console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks") + console.print( + "Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks" + ) console.print("Use 'prefect block ls' to list available blocks") @@ -507,7 +510,9 @@ async def run_list_collections() -> None: weaviate_config = StorageConfig( backend=StorageBackend.WEAVIATE, endpoint=settings.weaviate_endpoint, - api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None, + api_key=SecretStr(settings.weaviate_api_key) + if settings.weaviate_api_key is not None + else None, collection_name="default", ) weaviate = WeaviateStorage(weaviate_config) @@ -528,7 +533,9 @@ async def run_list_collections() -> None: openwebui_config = StorageConfig( backend=StorageBackend.OPEN_WEBUI, 
endpoint=settings.openwebui_endpoint, - api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None, + api_key=SecretStr(settings.openwebui_api_key) + if settings.openwebui_api_key is not None + else None, collection_name="default", ) openwebui = OpenWebUIStorage(openwebui_config) @@ -593,7 +600,9 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in weaviate_config = StorageConfig( backend=StorageBackend.WEAVIATE, endpoint=settings.weaviate_endpoint, - api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None, + api_key=SecretStr(settings.weaviate_api_key) + if settings.weaviate_api_key is not None + else None, collection_name=collection or "default", ) weaviate = WeaviateStorage(weaviate_config) diff --git a/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc index 5b9dae0..c1b4d9c 100644 Binary files a/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc index fd540df..3cd000f 100644 Binary files a/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/app.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc index d2d5850..65dfb85 100644 Binary files a/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/models.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc b/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc index 940720f..6c53f50 100644 Binary files a/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc and b/ingest_pipeline/cli/tui/__pycache__/styles.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/app.py b/ingest_pipeline/cli/tui/app.py index 3e6651d..62d8d9d 100644 --- a/ingest_pipeline/cli/tui/app.py +++ b/ingest_pipeline/cli/tui/app.py @@ -31,7 +31,6 @@ else: # pragma: no cover - optional dependency fallback R2RStorage = BaseStorage - class CollectionManagementApp(App[None]): """Enhanced modern Textual application with comprehensive keyboard navigation.""" diff --git a/ingest_pipeline/cli/tui/layouts.py b/ingest_pipeline/cli/tui/layouts.py index c458b01..db26aaf 100644 --- a/ingest_pipeline/cli/tui/layouts.py +++ b/ingest_pipeline/cli/tui/layouts.py @@ -57,7 +57,9 @@ class ResponsiveGrid(Container): markup: bool = True, ) -> None: """Initialize responsive grid.""" - super().__init__(*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup) + super().__init__( + *children, name=name, id=id, classes=classes, disabled=disabled, markup=markup + ) self._columns: int = columns self._auto_fit: bool = auto_fit self._compact: bool = compact @@ -327,7 +329,9 @@ class CardLayout(ResponsiveGrid): ) -> None: """Initialize card layout with default settings for cards.""" # Default to auto-fit cards with minimum width - super().__init__(auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup) + super().__init__( + auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup + ) class SplitPane(Container): @@ -396,7 +400,9 @@ class SplitPane(Container): if self._vertical: cast(Widget, 
self).add_class("vertical") - pane_classes = ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane") + pane_classes = ( + ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane") + ) yield Container(self._left_content, classes=pane_classes[0]) yield Static("", classes="splitter") diff --git a/ingest_pipeline/cli/tui/models.py b/ingest_pipeline/cli/tui/models.py index 3aeaeda..25d2ac2 100644 --- a/ingest_pipeline/cli/tui/models.py +++ b/ingest_pipeline/cli/tui/models.py @@ -1,7 +1,7 @@ """Data models and TypedDict definitions for the TUI.""" from enum import IntEnum -from typing import Any, TypedDict +from typing import TypedDict class StorageCapabilities(IntEnum): @@ -47,7 +47,7 @@ class ChunkInfo(TypedDict): content: str start_index: int end_index: int - metadata: dict[str, Any] + metadata: dict[str, object] class EntityInfo(TypedDict): @@ -57,7 +57,7 @@ class EntityInfo(TypedDict): name: str type: str confidence: float - metadata: dict[str, Any] + metadata: dict[str, object] class FirecrawlOptions(TypedDict, total=False): @@ -77,7 +77,7 @@ class FirecrawlOptions(TypedDict, total=False): max_depth: int # Extraction options - extract_schema: dict[str, Any] | None + extract_schema: dict[str, object] | None extract_prompt: str | None diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc index 6847502..91587a4 100644 Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc index 80b8710..f85215b 100644 Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/dashboard.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc index d7c9466..2b98192 100644 Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/dialogs.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc index 5d61cd6..737396f 100644 Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/documents.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc index a64eb9a..efaa925 100644 Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/help.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc index a92317a..e42797b 100644 Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/ingestion.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc b/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc index 3ca3a76..2d58b0b 
100644 Binary files a/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc and b/ingest_pipeline/cli/tui/screens/__pycache__/search.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/screens/base.py b/ingest_pipeline/cli/tui/screens/base.py index b6c34df..a8d9742 100644 --- a/ingest_pipeline/cli/tui/screens/base.py +++ b/ingest_pipeline/cli/tui/screens/base.py @@ -29,7 +29,7 @@ class BaseScreen(Screen[object]): name: str | None = None, id: str | None = None, classes: str | None = None, - **kwargs: object + **kwargs: object, ) -> None: """Initialize base screen.""" super().__init__(name=name, id=id, classes=classes) @@ -55,7 +55,7 @@ class CRUDScreen(BaseScreen, Generic[T]): name: str | None = None, id: str | None = None, classes: str | None = None, - **kwargs: object + **kwargs: object, ) -> None: """Initialize CRUD screen.""" super().__init__(storage_manager, name=name, id=id, classes=classes) @@ -311,7 +311,7 @@ class FormScreen(ModalScreen[T], Generic[T]): name: str | None = None, id: str | None = None, classes: str | None = None, - **kwargs: object + **kwargs: object, ) -> None: """Initialize form screen.""" super().__init__(name=name, id=id, classes=classes) diff --git a/ingest_pipeline/cli/tui/screens/dashboard.py b/ingest_pipeline/cli/tui/screens/dashboard.py index dacd32e..b94f8de 100644 --- a/ingest_pipeline/cli/tui/screens/dashboard.py +++ b/ingest_pipeline/cli/tui/screens/dashboard.py @@ -1,7 +1,6 @@ """Main dashboard screen with collections overview.""" import logging -from datetime import datetime from typing import TYPE_CHECKING, Final from textual import work @@ -357,7 +356,6 @@ class CollectionOverviewScreen(Screen[None]): self.is_loading = False loading_indicator.display = False - async def update_collections_table(self) -> None: """Update the collections table with enhanced formatting.""" table = self.query_one("#collections_table", EnhancedDataTable) diff --git a/ingest_pipeline/cli/tui/screens/dialogs.py b/ingest_pipeline/cli/tui/screens/dialogs.py index 8bbcc64..b90f526 100644 --- a/ingest_pipeline/cli/tui/screens/dialogs.py +++ b/ingest_pipeline/cli/tui/screens/dialogs.py @@ -82,7 +82,10 @@ class ConfirmDeleteScreen(Screen[None]): try: if self.collection["type"] == "weaviate" and self.parent_screen.weaviate: # Delete Weaviate collection - if self.parent_screen.weaviate.client and self.parent_screen.weaviate.client.collections: + if ( + self.parent_screen.weaviate.client + and self.parent_screen.weaviate.client.collections + ): self.parent_screen.weaviate.client.collections.delete(self.collection["name"]) self.notify( f"Deleted Weaviate collection: {self.collection['name']}", @@ -100,7 +103,7 @@ class ConfirmDeleteScreen(Screen[None]): return # Check if the storage backend supports collection deletion - if not hasattr(storage_backend, 'delete_collection'): + if not hasattr(storage_backend, "delete_collection"): self.notify( f"❌ Collection deletion not supported for {self.collection['type']} backend", severity="error", @@ -113,10 +116,13 @@ class ConfirmDeleteScreen(Screen[None]): collection_name = str(self.collection["name"]) collection_type = str(self.collection["type"]) - self.notify(f"Deleting {collection_type} collection: {collection_name}...", severity="information") + self.notify( + f"Deleting {collection_type} collection: {collection_name}...", + severity="information", + ) # Use the standard delete_collection method for all backends - if hasattr(storage_backend, 'delete_collection'): + if hasattr(storage_backend, "delete_collection"): 
success = await storage_backend.delete_collection(collection_name) else: self.notify("❌ Backend does not support collection deletion", severity="error") @@ -149,7 +155,6 @@ class ConfirmDeleteScreen(Screen[None]): self.parent_screen.refresh_collections() - class ConfirmDocumentDeleteScreen(Screen[None]): """Screen for confirming document deletion.""" @@ -223,11 +228,11 @@ class ConfirmDocumentDeleteScreen(Screen[None]): try: results: dict[str, bool] = {} - if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage: + if hasattr(self.parent_screen, "storage") and self.parent_screen.storage: # Delete documents via storage # The storage should have delete_documents method for weaviate storage = self.parent_screen.storage - if hasattr(storage, 'delete_documents'): + if hasattr(storage, "delete_documents"): results = await storage.delete_documents( self.doc_ids, collection_name=self.collection["name"], @@ -280,7 +285,9 @@ class LogViewerScreen(ModalScreen[None]): yield Header(show_clock=True) yield Container( Static("📜 Live Application Logs", classes="title"), - Static("Logs update in real time. Press S to reveal the log file path.", classes="subtitle"), + Static( + "Logs update in real time. Press S to reveal the log file path.", classes="subtitle" + ), RichLog(id="log_stream", classes="log-stream", wrap=True, highlight=False), Static("", id="log_file_path", classes="subtitle"), classes="main_container log-viewer-container", @@ -291,13 +298,13 @@ class LogViewerScreen(ModalScreen[None]): """Attach this viewer to the parent application once mounted.""" self._log_widget = self.query_one(RichLog) - if hasattr(self.app, 'attach_log_viewer'): + if hasattr(self.app, "attach_log_viewer"): self.app.attach_log_viewer(self) # type: ignore[arg-type] def on_unmount(self) -> None: """Detach from the parent application when closed.""" - if hasattr(self.app, 'detach_log_viewer'): + if hasattr(self.app, "detach_log_viewer"): self.app.detach_log_viewer(self) # type: ignore[arg-type] def _get_log_widget(self) -> RichLog: @@ -340,4 +347,6 @@ class LogViewerScreen(ModalScreen[None]): if self._log_file is None: self.notify("File logging is disabled for this session.", severity="warning") else: - self.notify(f"Log file available at: {self._log_file}", severity="information", markup=False) + self.notify( + f"Log file available at: {self._log_file}", severity="information", markup=False + ) diff --git a/ingest_pipeline/cli/tui/screens/documents.py b/ingest_pipeline/cli/tui/screens/documents.py index 4d04e3c..2d24175 100644 --- a/ingest_pipeline/cli/tui/screens/documents.py +++ b/ingest_pipeline/cli/tui/screens/documents.py @@ -116,7 +116,9 @@ class DocumentManagementScreen(Screen[None]): content_type=str(doc.get("content_type", "text/plain")), content_preview=str(doc.get("content_preview", "")), word_count=( - lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0 + lambda wc_val: int(wc_val) + if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() + else 0 )(doc.get("word_count", 0)), timestamp=str(doc.get("timestamp", "")), ) @@ -126,7 +128,7 @@ class DocumentManagementScreen(Screen[None]): # For storage backends that don't support document listing, show a message self.notify( f"Document listing not supported for {self.storage.__class__.__name__}", - severity="information" + severity="information", ) self.documents = [] @@ -330,7 +332,9 @@ class DocumentManagementScreen(Screen[None]): """View the content of the currently selected document.""" if doc := 
self.get_current_document(): if self.storage: - self.app.push_screen(DocumentContentModal(doc, self.storage, self.collection["name"])) + self.app.push_screen( + DocumentContentModal(doc, self.storage, self.collection["name"]) + ) else: self.notify("No storage backend available", severity="error") else: @@ -381,13 +385,13 @@ class DocumentContentModal(ModalScreen[None]): yield Container( Static( f"📄 Document: {self.document['title'][:60]}{'...' if len(self.document['title']) > 60 else ''}", - classes="modal-header" + classes="modal-header", ), ScrollableContainer( Markdown("Loading document content...", id="document_content"), LoadingIndicator(id="content_loading"), - classes="modal-content" - ) + classes="modal-content", + ), ) async def on_mount(self) -> None: @@ -398,30 +402,29 @@ class DocumentContentModal(ModalScreen[None]): try: # Get full document content doc_content = await self.storage.retrieve( - self.document["id"], - collection_name=self.collection_name + self.document["id"], collection_name=self.collection_name ) # Format content for display if isinstance(doc_content, str): - formatted_content = f"""# {self.document['title']} + formatted_content = f"""# {self.document["title"]} -**Source:** {self.document.get('source_url', 'N/A')} -**Type:** {self.document.get('content_type', 'text/plain')} -**Words:** {self.document.get('word_count', 0):,} -**Timestamp:** {self.document.get('timestamp', 'N/A')} +**Source:** {self.document.get("source_url", "N/A")} +**Type:** {self.document.get("content_type", "text/plain")} +**Words:** {self.document.get("word_count", 0):,} +**Timestamp:** {self.document.get("timestamp", "N/A")} --- {doc_content} """ else: - formatted_content = f"""# {self.document['title']} + formatted_content = f"""# {self.document["title"]} -**Source:** {self.document.get('source_url', 'N/A')} -**Type:** {self.document.get('content_type', 'text/plain')} -**Words:** {self.document.get('word_count', 0):,} -**Timestamp:** {self.document.get('timestamp', 'N/A')} +**Source:** {self.document.get("source_url", "N/A")} +**Type:** {self.document.get("content_type", "text/plain")} +**Words:** {self.document.get("word_count", 0):,} +**Timestamp:** {self.document.get("timestamp", "N/A")} --- @@ -431,6 +434,8 @@ class DocumentContentModal(ModalScreen[None]): content_widget.update(formatted_content) except Exception as e: - content_widget.update(f"# Error Loading Document\n\nFailed to load document content: {e}") + content_widget.update( + f"# Error Loading Document\n\nFailed to load document content: {e}" + ) finally: loading.display = False diff --git a/ingest_pipeline/cli/tui/screens/ingestion.py b/ingest_pipeline/cli/tui/screens/ingestion.py index 5608228..6149a19 100644 --- a/ingest_pipeline/cli/tui/screens/ingestion.py +++ b/ingest_pipeline/cli/tui/screens/ingestion.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from textual import work from textual.app import ComposeResult @@ -105,12 +105,23 @@ class IngestionScreen(ModalScreen[None]): Label("📋 Source Type (Press 1/2/3):", classes="input-label"), Horizontal( Button("🌐 Web (1)", id="web_btn", variant="primary", classes="type-button"), - Button("📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"), - Button("📖 Documentation (3)", id="docs_btn", variant="default", classes="type-button"), + Button( + "📦 Repository (2)", id="repo_btn", variant="default", classes="type-button" + ), + Button( + "📖 Documentation (3)", + id="docs_btn", + 
variant="default", + classes="type-button", + ), classes="type_buttons", ), Rule(line_style="dashed"), - Label(f"🗄️ Target Storages ({len(self.available_backends)} available):", classes="input-label", id="backend_label"), + Label( + f"🗄️ Target Storages ({len(self.available_backends)} available):", + classes="input-label", + id="backend_label", + ), Container( *self._create_backend_checkbox_widgets(), classes="backend-selection", @@ -139,7 +150,6 @@ class IngestionScreen(ModalScreen[None]): yield LoadingIndicator(id="loading", classes="pulse") - def _create_backend_checkbox_widgets(self) -> list[Checkbox]: """Create checkbox widgets for each available backend.""" checkboxes: list[Checkbox] = [ @@ -219,19 +229,26 @@ class IngestionScreen(ModalScreen[None]): collection_name = collection_input.value.strip() if not source_url: - self.notify("🔍 Please enter a source URL", severity="error") + cast("CollectionManagementApp", self.app).safe_notify( + "🔍 Please enter a source URL", severity="error" + ) url_input.focus() return # Validate URL format if not self._validate_url(source_url): - self.notify("❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path", severity="error") + cast("CollectionManagementApp", self.app).safe_notify( + "❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path", + severity="error", + ) url_input.focus() return resolved_backends = self._resolve_selected_backends() if not resolved_backends: - self.notify("⚠️ Select at least one storage backend", severity="warning") + cast("CollectionManagementApp", self.app).safe_notify( + "⚠️ Select at least one storage backend", severity="warning" + ) return self.selected_backends = resolved_backends @@ -246,18 +263,16 @@ class IngestionScreen(ModalScreen[None]): url_lower = url.lower() # Allow HTTP/HTTPS URLs - if url_lower.startswith(('http://', 'https://')): + if url_lower.startswith(("http://", "https://")): # Additional validation could be added here return True # Allow file:// URLs for repository paths - if url_lower.startswith('file://'): + if url_lower.startswith("file://"): return True # Allow local file paths that look like repositories - return '/' in url and not url_lower.startswith( - ('javascript:', 'data:', 'vbscript:') - ) + return "/" in url and not url_lower.startswith(("javascript:", "data:", "vbscript:")) def _resolve_selected_backends(self) -> list[StorageBackend]: selected: list[StorageBackend] = [] @@ -288,13 +303,14 @@ class IngestionScreen(ModalScreen[None]): if not self.selected_backends: status_widget.update("📋 Selected: None") elif len(self.selected_backends) == 1: - backend_name = BACKEND_LABELS.get(self.selected_backends[0], self.selected_backends[0].value) + backend_name = BACKEND_LABELS.get( + self.selected_backends[0], self.selected_backends[0].value + ) status_widget.update(f"📋 Selected: {backend_name}") else: # Multiple backends selected backend_names = [ - BACKEND_LABELS.get(backend, backend.value) - for backend in self.selected_backends + BACKEND_LABELS.get(backend, backend.value) for backend in self.selected_backends ] if len(backend_names) <= 3: # Show all names if 3 or fewer @@ -319,7 +335,10 @@ class IngestionScreen(ModalScreen[None]): for backend in BACKEND_ORDER: if backend not in self.available_backends: continue - if backend.value.lower() == backend_name_lower or backend.name.lower() == backend_name_lower: + if ( + backend.value.lower() == backend_name_lower + or backend.name.lower() == backend_name_lower + ): matched_backends.append(backend) break return 
matched_backends or [self.available_backends[0]] @@ -351,6 +370,7 @@ class IngestionScreen(ModalScreen[None]): loading.display = False except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(_update) def progress_reporter(percent: int, message: str) -> None: @@ -362,6 +382,7 @@ class IngestionScreen(ModalScreen[None]): progress_text.update(message) except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(_update_progress) try: @@ -377,13 +398,18 @@ class IngestionScreen(ModalScreen[None]): for i, backend in enumerate(backends): progress_percent = 20 + (60 * i) // len(backends) - progress_reporter(progress_percent, f"🔗 Processing {backend.value} backend ({i+1}/{len(backends)})...") + progress_reporter( + progress_percent, + f"🔗 Processing {backend.value} backend ({i + 1}/{len(backends)})...", + ) try: # Run the Prefect flow for this backend using asyncio.run with timeout import asyncio - async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult: + async def run_flow_with_timeout( + current_backend: StorageBackend = backend, + ) -> IngestionResult: return await asyncio.wait_for( create_ingestion_flow( source_url=source_url, @@ -392,7 +418,7 @@ class IngestionScreen(ModalScreen[None]): collection_name=final_collection_name, progress_callback=progress_reporter, ), - timeout=600.0 # 10 minute timeout + timeout=600.0, # 10 minute timeout ) result = asyncio.run(run_flow_with_timeout()) @@ -401,25 +427,33 @@ class IngestionScreen(ModalScreen[None]): total_failed += result.documents_failed if result.error_messages: - flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages]) + flow_errors.extend( + [f"{backend.value}: {err}" for err in result.error_messages] + ) except TimeoutError: error_msg = f"{backend.value}: Timeout after 10 minutes" flow_errors.append(error_msg) progress_reporter(0, f"❌ {backend.value} timed out") - def notify_timeout(msg: str = f"⏰ {backend.value} flow timed out after 10 minutes") -> None: + + def notify_timeout( + msg: str = f"⏰ {backend.value} flow timed out after 10 minutes", + ) -> None: try: self.notify(msg, severity="error", markup=False) except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(notify_timeout) except Exception as exc: flow_errors.append(f"{backend.value}: {exc}") + def notify_error(msg: str = f"❌ {backend.value} flow failed: {exc}") -> None: try: self.notify(msg, severity="error", markup=False) except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(notify_error) successful = total_successful @@ -445,20 +479,37 @@ class IngestionScreen(ModalScreen[None]): cast("CollectionManagementApp", self.app).call_from_thread(notify_results) - import time - time.sleep(2) - cast("CollectionManagementApp", self.app).pop_screen() + def _pop() -> None: + try: + self.app.pop_screen() + except Exception: + pass + + # Schedule screen pop via timer instead of blocking + cast("CollectionManagementApp", self.app).call_from_thread( + lambda: self.app.set_timer(2.0, _pop) + ) except Exception as exc: # pragma: no cover - defensive progress_reporter(0, f"❌ Prefect flows error: {exc}") + def notify_error(msg: str = f"❌ Prefect flows failed: {exc}") -> None: try: self.notify(msg, severity="error") except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(notify_error) - import time - time.sleep(2) + + def _pop_on_error() -> None: + try: + self.app.pop_screen() + except Exception: + pass + 
+ # Schedule screen pop via timer for error case too + cast("CollectionManagementApp", self.app).call_from_thread( + lambda: self.app.set_timer(2.0, _pop_on_error) + ) finally: update_ui("hide_loading") - diff --git a/ingest_pipeline/cli/tui/screens/search.py b/ingest_pipeline/cli/tui/screens/search.py index 3b8427f..916dd0f 100644 --- a/ingest_pipeline/cli/tui/screens/search.py +++ b/ingest_pipeline/cli/tui/screens/search.py @@ -43,12 +43,24 @@ class SearchScreen(Screen[None]): @override def compose(self) -> ComposeResult: yield Header() + # Check if search is supported for this backend + backends = self.collection["backend"] + if isinstance(backends, str): + backends = [backends] + search_supported = "weaviate" in backends + search_indicator = "✅ Search supported" if search_supported else "❌ Search not supported" + yield Container( Static( - f"🔍 Search in: {self.collection['name']} ({self.collection['backend']})", + f"🔍 Search in: {self.collection['name']} ({', '.join(backends)}) - {search_indicator}", classes="title", ), - Static("Press / or Ctrl+F to focus search, Enter to search", classes="subtitle"), + Static( + "Press / or Ctrl+F to focus search, Enter to search" + if search_supported + else "Search functionality not available for this backend", + classes="subtitle", + ), Input(placeholder="Enter search query... (press Enter to search)", id="search_input"), Button("🔍 Search", id="search_btn", variant="primary"), Button("🗑️ Clear Results", id="clear_btn", variant="default"), @@ -127,7 +139,9 @@ class SearchScreen(Screen[None]): finally: loading.display = False - def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None: + def _setup_search_ui( + self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str + ) -> None: """Setup the search UI elements.""" loading.display = True status.update(f"🔍 Searching for '{query}'...") @@ -139,10 +153,18 @@ class SearchScreen(Screen[None]): if self.collection["type"] == "weaviate" and self.weaviate: return await self.search_weaviate(query) elif self.collection["type"] == "openwebui" and self.openwebui: - return await self.search_openwebui(query) + # OpenWebUI search is not yet implemented + self.notify("Search not supported for OpenWebUI collections", severity="warning") + return [] + elif self.collection["type"] == "r2r": + # R2R search would go here when implemented + self.notify("Search not supported for R2R collections", severity="warning") + return [] return [] - def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None: + def _populate_results_table( + self, table: EnhancedDataTable, results: list[dict[str, str | float]] + ) -> None: """Populate the results table with search results.""" for result in results: row_data = self._format_result_row(result) @@ -193,7 +215,11 @@ class SearchScreen(Screen[None]): return str(score) def _update_search_status( - self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable + self, + status: Static, + query: str, + results: list[dict[str, str | float]], + table: EnhancedDataTable, ) -> None: """Update search status and notifications based on results.""" if not results: diff --git a/ingest_pipeline/cli/tui/styles.py b/ingest_pipeline/cli/tui/styles.py index 0095d1e..ac6b237 100644 --- a/ingest_pipeline/cli/tui/styles.py +++ b/ingest_pipeline/cli/tui/styles.py @@ -13,6 +13,8 @@ TextualApp = App[object] class AppProtocol(Protocol): 
"""Protocol for apps that support CSS and refresh.""" + CSS: str + def refresh(self) -> None: """Refresh the app.""" ... @@ -1122,16 +1124,16 @@ def get_css_for_theme(theme_type: ThemeType) -> str: def apply_theme_to_app(app: TextualApp | AppProtocol, theme_type: ThemeType) -> None: """Apply a theme to a Textual app instance.""" try: - css = set_theme(theme_type) - # Set CSS using the standard Textual approach - if hasattr(app, "CSS") or isinstance(app, App): - setattr(app, "CSS", css) - # Refresh the app to apply new CSS - if hasattr(app, "refresh"): - app.refresh() + # Note: CSS class variable cannot be changed at runtime + # This function would need to be called during app initialization + # or implement a different approach for dynamic theming + _ = set_theme(theme_type) # Keep for future implementation + if hasattr(app, "refresh"): + app.refresh() except Exception as e: # Graceful fallback - log but don't crash the UI import logging + logging.debug(f"Failed to apply theme to app: {e}") @@ -1185,11 +1187,11 @@ class ThemeSwitcher: # Responsive breakpoints for dynamic layout adaptation RESPONSIVE_BREAKPOINTS = { - "xs": 40, # Extra small terminals - "sm": 60, # Small terminals - "md": 100, # Medium terminals - "lg": 140, # Large terminals - "xl": 180, # Extra large terminals + "xs": 40, # Extra small terminals + "sm": 60, # Small terminals + "md": 100, # Medium terminals + "lg": 140, # Large terminals + "xl": 180, # Extra large terminals } diff --git a/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc index 970fef0..bdfca59 100644 Binary files a/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/tui/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc b/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc index cd7f620..993c1c4 100644 Binary files a/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc and b/ingest_pipeline/cli/tui/utils/__pycache__/runners.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc b/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc index e62a89e..8bc2d6f 100644 Binary files a/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc and b/ingest_pipeline/cli/tui/utils/__pycache__/storage_manager.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/utils/runners.py b/ingest_pipeline/cli/tui/utils/runners.py index fc629ea..5e6dc47 100644 --- a/ingest_pipeline/cli/tui/utils/runners.py +++ b/ingest_pipeline/cli/tui/utils/runners.py @@ -10,8 +10,9 @@ from pathlib import Path from queue import Queue from typing import NamedTuple +import platformdirs + from ....config import configure_prefect, get_settings -from ....core.models import StorageBackend from .storage_manager import StorageManager @@ -53,21 +54,36 @@ def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext: log_file: Path | None = None try: + # Try current directory first for development log_dir = Path.cwd() / "logs" log_dir.mkdir(parents=True, exist_ok=True) log_file = log_dir / "tui.log" - file_handler = RotatingFileHandler( - log_file, - maxBytes=2_000_000, - backupCount=5, - encoding="utf-8", - ) - file_handler.setLevel(resolved_level) - file_handler.setFormatter(formatter) - root_logger.addHandler(file_handler) - except OSError as exc: # pragma: no cover - 
filesystem specific - fallback = logging.getLogger(__name__) - fallback.warning("Failed to configure file logging for TUI: %s", exc) + except OSError: + # Fall back to user log directory + try: + log_dir = Path(platformdirs.user_log_dir("ingest-pipeline", "ingest-pipeline")) + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / "tui.log" + except OSError as exc: + fallback = logging.getLogger(__name__) + fallback.warning("Failed to create log directory, file logging disabled: %s", exc) + log_file = None + + if log_file: + try: + file_handler = RotatingFileHandler( + log_file, + maxBytes=2_000_000, + backupCount=5, + encoding="utf-8", + ) + file_handler.setLevel(resolved_level) + file_handler.setFormatter(formatter) + root_logger.addHandler(file_handler) + except OSError as exc: + fallback = logging.getLogger(__name__) + fallback.warning("Failed to configure file logging for TUI: %s", exc) + log_file = None _logging_context = _TuiLoggingContext(log_queue, formatter, log_file) return _logging_context @@ -93,6 +109,7 @@ async def run_textual_tui() -> None: # Import here to avoid circular import from ..app import CollectionManagementApp + app = CollectionManagementApp( storage_manager, None, # weaviate - will be available after initialization diff --git a/ingest_pipeline/cli/tui/utils/storage_manager.py b/ingest_pipeline/cli/tui/utils/storage_manager.py index 42313e2..031c9f5 100644 --- a/ingest_pipeline/cli/tui/utils/storage_manager.py +++ b/ingest_pipeline/cli/tui/utils/storage_manager.py @@ -1,6 +1,5 @@ """Storage management utilities for TUI applications.""" - from __future__ import annotations import asyncio @@ -11,12 +10,11 @@ from pydantic import SecretStr from ....core.exceptions import StorageError from ....core.models import Document, StorageBackend, StorageConfig -from ..models import CollectionInfo, StorageCapabilities - from ....storage.base import BaseStorage from ....storage.openwebui import OpenWebUIStorage from ....storage.r2r.storage import R2RStorage from ....storage.weaviate import WeaviateStorage +from ..models import CollectionInfo, StorageCapabilities if TYPE_CHECKING: from ....config.settings import Settings @@ -39,7 +37,6 @@ class StorageBackendProtocol(Protocol): async def close(self) -> None: ... 
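For context, the two-tier log-directory fallback introduced in `runners.py` above (try `./logs` for development, then fall back to the platform user log directory, then disable file logging) can be sketched roughly as below. This is a minimal sketch, assuming `platformdirs` is available; `APP_NAME` and `_resolve_log_file` are illustrative names, not identifiers from the codebase:

```python
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path

import platformdirs

APP_NAME = "ingest-pipeline"  # illustrative constant, not from the codebase


def _resolve_log_file() -> Path | None:
    """Prefer ./logs for development, else the per-user log directory."""
    for candidate in (
        Path.cwd() / "logs",
        Path(platformdirs.user_log_dir(APP_NAME, APP_NAME)),
    ):
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            return candidate / "tui.log"
        except OSError:
            continue  # location not writable; try the next one
    return None  # no writable location: file logging stays disabled


if (log_file := _resolve_log_file()) is not None:
    handler = RotatingFileHandler(log_file, maxBytes=2_000_000, backupCount=5, encoding="utf-8")
    logging.getLogger().addHandler(handler)
```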
- class MultiStorageAdapter(BaseStorage): """Mirror writes to multiple storage backends.""" @@ -70,7 +67,10 @@ class MultiStorageAdapter(BaseStorage): # Replicate to secondary backends concurrently if len(self._storages) > 1: - async def replicate_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]: + + async def replicate_to_backend( + storage: BaseStorage, + ) -> tuple[BaseStorage, bool, Exception | None]: try: await storage.store(document, collection_name=collection_name) return storage, True, None @@ -106,11 +106,16 @@ class MultiStorageAdapter(BaseStorage): self, documents: list[Document], *, collection_name: str | None = None ) -> list[str]: # Store in primary backend first - primary_ids: list[str] = await self._primary.store_batch(documents, collection_name=collection_name) + primary_ids: list[str] = await self._primary.store_batch( + documents, collection_name=collection_name + ) # Replicate to secondary backends concurrently if len(self._storages) > 1: - async def replicate_batch_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]: + + async def replicate_batch_to_backend( + storage: BaseStorage, + ) -> tuple[BaseStorage, bool, Exception | None]: try: await storage.store_batch(documents, collection_name=collection_name) return storage, True, None @@ -135,7 +140,9 @@ class MultiStorageAdapter(BaseStorage): if failures: backends = ", ".join(failures) - primary_error = errors[0] if errors else Exception("Unknown batch replication error") + primary_error = ( + errors[0] if errors else Exception("Unknown batch replication error") + ) raise StorageError( f"Batch stored in primary backend but replication failed for: {backends}" ) from primary_error @@ -144,11 +151,16 @@ class MultiStorageAdapter(BaseStorage): async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: # Delete from primary backend first - primary_deleted: bool = await self._primary.delete(document_id, collection_name=collection_name) + primary_deleted: bool = await self._primary.delete( + document_id, collection_name=collection_name + ) # Delete from secondary backends concurrently if len(self._storages) > 1: - async def delete_from_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]: + + async def delete_from_backend( + storage: BaseStorage, + ) -> tuple[BaseStorage, bool, Exception | None]: try: await storage.delete(document_id, collection_name=collection_name) return storage, True, None @@ -222,7 +234,6 @@ class MultiStorageAdapter(BaseStorage): return class_name - class StorageManager: """Centralized manager for all storage backend operations.""" @@ -237,7 +248,9 @@ class StorageManager: """Initialize all available storage backends with timeout protection.""" results: dict[StorageBackend, bool] = {} - async def init_backend(backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]) -> bool: + async def init_backend( + backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage] + ) -> bool: """Initialize a single backend with timeout.""" try: storage = storage_class(config) @@ -261,10 +274,17 @@ class StorageManager: config = StorageConfig( backend=StorageBackend.WEAVIATE, endpoint=self.settings.weaviate_endpoint, - api_key=SecretStr(self.settings.weaviate_api_key) if self.settings.weaviate_api_key else None, + api_key=SecretStr(self.settings.weaviate_api_key) + if self.settings.weaviate_api_key + else None, collection_name="default", ) - 
tasks.append((StorageBackend.WEAVIATE, init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage))) + tasks.append( + ( + StorageBackend.WEAVIATE, + init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage), + ) + ) else: results[StorageBackend.WEAVIATE] = False @@ -273,10 +293,17 @@ class StorageManager: config = StorageConfig( backend=StorageBackend.OPEN_WEBUI, endpoint=self.settings.openwebui_endpoint, - api_key=SecretStr(self.settings.openwebui_api_key) if self.settings.openwebui_api_key else None, + api_key=SecretStr(self.settings.openwebui_api_key) + if self.settings.openwebui_api_key + else None, collection_name="default", ) - tasks.append((StorageBackend.OPEN_WEBUI, init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage))) + tasks.append( + ( + StorageBackend.OPEN_WEBUI, + init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage), + ) + ) else: results[StorageBackend.OPEN_WEBUI] = False @@ -295,7 +322,9 @@ class StorageManager: # Execute initialization tasks concurrently if tasks: backend_types, task_coroutines = zip(*tasks, strict=False) - task_results: Sequence[bool | BaseException] = await asyncio.gather(*task_coroutines, return_exceptions=True) + task_results: Sequence[bool | BaseException] = await asyncio.gather( + *task_coroutines, return_exceptions=True + ) for backend_type, task_result in zip(backend_types, task_results, strict=False): results[backend_type] = task_result if isinstance(task_result, bool) else False @@ -312,7 +341,9 @@ class StorageManager: storages: list[BaseStorage] = [] seen: set[StorageBackend] = set() for backend in backends: - backend_enum = backend if isinstance(backend, StorageBackend) else StorageBackend(backend) + backend_enum = ( + backend if isinstance(backend, StorageBackend) else StorageBackend(backend) + ) if backend_enum in seen: continue seen.add(backend_enum) @@ -350,12 +381,18 @@ class StorageManager: except StorageError as e: # Storage-specific errors - log and use 0 count import logging - logging.warning(f"Failed to get count for {collection_name} on {backend_type.value}: {e}") + + logging.warning( + f"Failed to get count for {collection_name} on {backend_type.value}: {e}" + ) count = 0 except Exception as e: # Unexpected errors - log and skip this collection from this backend import logging - logging.warning(f"Unexpected error counting {collection_name} on {backend_type.value}: {e}") + + logging.warning( + f"Unexpected error counting {collection_name} on {backend_type.value}: {e}" + ) continue size_mb = count * 0.01 # Rough estimate: 10KB per document @@ -446,7 +483,9 @@ class StorageManager: storage = self.backends.get(StorageBackend.R2R) return storage if isinstance(storage, R2RStorage) else None - async def get_backend_status(self) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]: + async def get_backend_status( + self, + ) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]: """Get comprehensive status for all backends.""" status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {} diff --git a/ingest_pipeline/cli/tui/widgets/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/cli/tui/widgets/__pycache__/__init__.cpython-312.pyc index 1a8e3a6..49bedc6 100644 Binary files a/ingest_pipeline/cli/tui/widgets/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/cli/tui/widgets/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/widgets/__pycache__/cards.cpython-312.pyc 
b/ingest_pipeline/cli/tui/widgets/__pycache__/cards.cpython-312.pyc index affdeb1..775a6da 100644 Binary files a/ingest_pipeline/cli/tui/widgets/__pycache__/cards.cpython-312.pyc and b/ingest_pipeline/cli/tui/widgets/__pycache__/cards.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/widgets/__pycache__/indicators.cpython-312.pyc b/ingest_pipeline/cli/tui/widgets/__pycache__/indicators.cpython-312.pyc index 901bdc2..1b3408d 100644 Binary files a/ingest_pipeline/cli/tui/widgets/__pycache__/indicators.cpython-312.pyc and b/ingest_pipeline/cli/tui/widgets/__pycache__/indicators.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/widgets/__pycache__/tables.cpython-312.pyc b/ingest_pipeline/cli/tui/widgets/__pycache__/tables.cpython-312.pyc index 3738e5d..8b956c4 100644 Binary files a/ingest_pipeline/cli/tui/widgets/__pycache__/tables.cpython-312.pyc and b/ingest_pipeline/cli/tui/widgets/__pycache__/tables.cpython-312.pyc differ diff --git a/ingest_pipeline/cli/tui/widgets/firecrawl_config.py b/ingest_pipeline/cli/tui/widgets/firecrawl_config.py index f80aa0d..a1cc9fe 100644 --- a/ingest_pipeline/cli/tui/widgets/firecrawl_config.py +++ b/ingest_pipeline/cli/tui/widgets/firecrawl_config.py @@ -158,9 +158,7 @@ class ScrapeOptionsForm(Widget): formats.append("screenshot") options: dict[str, object] = { "formats": formats, - "only_main_content": self.query_one( - "#only_main_content", Switch - ).value, + "only_main_content": self.query_one("#only_main_content", Switch).value, } include_tags_input = self.query_one("#include_tags", Input).value if include_tags_input.strip(): @@ -195,11 +193,15 @@ class ScrapeOptionsForm(Widget): if include_tags := options.get("include_tags", []): include_list = include_tags if isinstance(include_tags, list) else [] - self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list) + self.query_one("#include_tags", Input).value = ", ".join( + str(tag) for tag in include_list + ) if exclude_tags := options.get("exclude_tags", []): exclude_list = exclude_tags if isinstance(exclude_tags, list) else [] - self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list) + self.query_one("#exclude_tags", Input).value = ", ".join( + str(tag) for tag in exclude_list + ) # Set performance wait_for = options.get("wait_for") @@ -503,7 +505,9 @@ class ExtractOptionsForm(Widget): "content": "string", "tags": ["string"] }""" - prompt_widget.text = "Extract article title, author, publication date, main content, and associated tags" + prompt_widget.text = ( + "Extract article title, author, publication date, main content, and associated tags" + ) elif event.button.id == "preset_product": schema_widget.text = """{ @@ -513,7 +517,9 @@ class ExtractOptionsForm(Widget): "category": "string", "availability": "string" }""" - prompt_widget.text = "Extract product name, price, description, category, and availability status" + prompt_widget.text = ( + "Extract product name, price, description, category, and availability status" + ) elif event.button.id == "preset_contact": schema_widget.text = """{ @@ -523,7 +529,9 @@ class ExtractOptionsForm(Widget): "company": "string", "position": "string" }""" - prompt_widget.text = "Extract contact information including name, email, phone, company, and position" + prompt_widget.text = ( + "Extract contact information including name, email, phone, company, and position" + ) elif event.button.id == "preset_data": schema_widget.text = """{ diff --git 
a/ingest_pipeline/cli/tui/widgets/indicators.py b/ingest_pipeline/cli/tui/widgets/indicators.py index 268b9e3..7209569 100644 --- a/ingest_pipeline/cli/tui/widgets/indicators.py +++ b/ingest_pipeline/cli/tui/widgets/indicators.py @@ -26,8 +26,12 @@ class StatusIndicator(Static): status_lower = status.lower() - if (status_lower in {"active", "online", "connected", "✓ active"} or - status_lower.endswith("active") or "✓" in status_lower and "active" in status_lower): + if ( + status_lower in {"active", "online", "connected", "✓ active"} + or status_lower.endswith("active") + or "✓" in status_lower + and "active" in status_lower + ): self.add_class("status-active") self.add_class("glow") self.update(f"🟢 {status}") diff --git a/ingest_pipeline/cli/tui/widgets/r2r_widgets.py b/ingest_pipeline/cli/tui/widgets/r2r_widgets.py index b571302..f865936 100644 --- a/ingest_pipeline/cli/tui/widgets/r2r_widgets.py +++ b/ingest_pipeline/cli/tui/widgets/r2r_widgets.py @@ -99,10 +99,17 @@ class ChunkViewer(Widget): "id": str(chunk_data.get("id", "")), "document_id": self.document_id, "content": str(chunk_data.get("text", "")), - "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)), - "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)), + "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)( + chunk_data.get("start_index", 0) + ), + "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)( + chunk_data.get("end_index", 0) + ), "metadata": ( - dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {} + dict(metadata_val) + if (metadata_val := chunk_data.get("metadata")) + and isinstance(metadata_val, dict) + else {} ), } self.chunks.append(chunk_info) @@ -275,10 +282,10 @@ class EntityGraph(Widget): """Show detailed information about an entity.""" details_widget = self.query_one("#entity_details", Static) - details_text = f"""**Entity:** {entity['name']} -**Type:** {entity['type']} -**Confidence:** {entity['confidence']:.2%} -**ID:** {entity['id']} + details_text = f"""**Entity:** {entity["name"]} +**Type:** {entity["type"]} +**Confidence:** {entity["confidence"]:.2%} +**ID:** {entity["id"]} **Metadata:** """ diff --git a/ingest_pipeline/config/__init__.py b/ingest_pipeline/config/__init__.py index 782f56a..ccd473e 100644 --- a/ingest_pipeline/config/__init__.py +++ b/ingest_pipeline/config/__init__.py @@ -26,9 +26,15 @@ def _setup_prefect_settings() -> tuple[object, object, object]: if registry is not None: Setting = getattr(ps, "Setting", None) if Setting is not None: - api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None) - api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None) - work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None) + api_key = registry.get("PREFECT_API_KEY") or Setting( + "PREFECT_API_KEY", type_=str, default=None + ) + api_url = registry.get("PREFECT_API_URL") or Setting( + "PREFECT_API_URL", type_=str, default=None + ) + work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting( + "PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None + ) return api_key, api_url, work_pool except ImportError: @@ -37,6 +43,7 @@ def _setup_prefect_settings() -> tuple[object, object, object]: # Ultimate fallback return None, None, None + 
PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings() # Import after Prefect settings setup to avoid circular dependencies @@ -53,20 +60,30 @@ def configure_prefect(settings: Settings) -> None: overrides: dict[Setting, str] = {} - if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting): + if ( + settings.prefect_api_url is not None + and PREFECT_API_URL is not None + and isinstance(PREFECT_API_URL, Setting) + ): overrides[PREFECT_API_URL] = str(settings.prefect_api_url) - if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting): + if ( + settings.prefect_api_key + and PREFECT_API_KEY is not None + and isinstance(PREFECT_API_KEY, Setting) + ): overrides[PREFECT_API_KEY] = settings.prefect_api_key - if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting): + if ( + settings.prefect_work_pool + and PREFECT_DEFAULT_WORK_POOL_NAME is not None + and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting) + ): overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool if not overrides: return filtered_overrides = { - setting: value - for setting, value in overrides.items() - if setting.value() != value + setting: value for setting, value in overrides.items() if setting.value() != value } if not filtered_overrides: diff --git a/ingest_pipeline/config/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/config/__pycache__/__init__.cpython-312.pyc index 0c86130..575cddd 100644 Binary files a/ingest_pipeline/config/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/config/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/config/__pycache__/settings.cpython-312.pyc b/ingest_pipeline/config/__pycache__/settings.cpython-312.pyc index d4d2d4c..d599e53 100644 Binary files a/ingest_pipeline/config/__pycache__/settings.cpython-312.pyc and b/ingest_pipeline/config/__pycache__/settings.cpython-312.pyc differ diff --git a/ingest_pipeline/config/settings.py b/ingest_pipeline/config/settings.py index dd3fdeb..f27877e 100644 --- a/ingest_pipeline/config/settings.py +++ b/ingest_pipeline/config/settings.py @@ -139,10 +139,11 @@ class Settings(BaseSettings): key_name, key_value = required_keys[backend] if not key_value: import warnings + warnings.warn( f"{key_name} not set - authentication may fail for {backend} backend", UserWarning, - stacklevel=2 + stacklevel=2, ) return self @@ -165,16 +166,24 @@ class PrefectVariableConfig: def __init__(self) -> None: self._settings: Settings = get_settings() self._variable_names: list[str] = [ - "default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages", - "default_storage_backend", "default_collection_prefix", "max_concurrent_tasks", - "request_timeout", "default_schedule_interval" + "default_batch_size", + "max_file_size", + "max_crawl_depth", + "max_crawl_pages", + "default_storage_backend", + "default_collection_prefix", + "max_concurrent_tasks", + "request_timeout", + "default_schedule_interval", ] def _get_fallback_value(self, name: str, default_value: object = None) -> object: """Get fallback value from settings or default.""" return default_value or getattr(self._settings, name, default_value) - def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None: + def get_with_fallback( + self, name: str, default_value: str | int | float | None = None 
+ ) -> str | int | float | None: """Get variable value with fallback synchronously.""" fallback = self._get_fallback_value(name, default_value) # Ensure fallback is a type that Variable expects @@ -195,7 +204,9 @@ class PrefectVariableConfig: return fallback return str(fallback) if fallback is not None else None - async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None: + async def get_with_fallback_async( + self, name: str, default_value: str | int | float | None = None + ) -> str | int | float | None: """Get variable value with fallback asynchronously.""" fallback = self._get_fallback_value(name, default_value) variable_fallback = str(fallback) if fallback is not None else None diff --git a/ingest_pipeline/core/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/core/__pycache__/__init__.cpython-312.pyc index 8e4579b..ecdb6dc 100644 Binary files a/ingest_pipeline/core/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/core/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/core/__pycache__/exceptions.cpython-312.pyc b/ingest_pipeline/core/__pycache__/exceptions.cpython-312.pyc index a991174..f82254b 100644 Binary files a/ingest_pipeline/core/__pycache__/exceptions.cpython-312.pyc and b/ingest_pipeline/core/__pycache__/exceptions.cpython-312.pyc differ diff --git a/ingest_pipeline/core/__pycache__/models.cpython-312.pyc b/ingest_pipeline/core/__pycache__/models.cpython-312.pyc index 51f00b4..accf339 100644 Binary files a/ingest_pipeline/core/__pycache__/models.cpython-312.pyc and b/ingest_pipeline/core/__pycache__/models.cpython-312.pyc differ diff --git a/ingest_pipeline/core/models.py b/ingest_pipeline/core/models.py index d2af2b2..e2ca503 100644 --- a/ingest_pipeline/core/models.py +++ b/ingest_pipeline/core/models.py @@ -12,35 +12,36 @@ from ..config import get_settings def _default_embedding_model() -> str: - return get_settings().embedding_model + return str(get_settings().embedding_model) def _default_embedding_endpoint() -> HttpUrl: - return get_settings().llm_endpoint + endpoint = get_settings().llm_endpoint + return endpoint if isinstance(endpoint, HttpUrl) else HttpUrl(str(endpoint)) def _default_embedding_dimension() -> int: - return get_settings().embedding_dimension + return int(get_settings().embedding_dimension) def _default_batch_size() -> int: - return get_settings().default_batch_size + return int(get_settings().default_batch_size) def _default_collection_name() -> str: - return get_settings().default_collection_prefix + return str(get_settings().default_collection_prefix) def _default_max_crawl_depth() -> int: - return get_settings().max_crawl_depth + return int(get_settings().max_crawl_depth) def _default_max_crawl_pages() -> int: - return get_settings().max_crawl_pages + return int(get_settings().max_crawl_pages) def _default_max_file_size() -> int: - return get_settings().max_file_size + return int(get_settings().max_file_size) class IngestionStatus(str, Enum): @@ -84,13 +85,16 @@ class StorageConfig(Block): _block_type_name: ClassVar[str | None] = "Storage Configuration" _block_type_slug: ClassVar[str | None] = "storage-config" - _description: ClassVar[str | None] = "Configures storage backend connections and settings for document ingestion" + _description: ClassVar[str | None] = ( + "Configures storage backend connections and settings for document ingestion" + ) backend: StorageBackend endpoint: HttpUrl api_key: SecretStr | None = Field(default=None) 
collection_name: str = Field(default_factory=_default_collection_name) batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size) + grpc_port: int | None = Field(default=None, description="gRPC port for Weaviate connections") class FirecrawlConfig(Block): @@ -112,7 +116,9 @@ class RepomixConfig(Block): _block_type_name: ClassVar[str | None] = "Repomix Configuration" _block_type_slug: ClassVar[str | None] = "repomix-config" - _description: ClassVar[str | None] = "Configures repository ingestion patterns and file processing settings" + _description: ClassVar[str | None] = ( + "Configures repository ingestion patterns and file processing settings" + ) include_patterns: list[str] = Field( default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"] @@ -129,7 +135,9 @@ class R2RConfig(Block): _block_type_name: ClassVar[str | None] = "R2R Configuration" _block_type_slug: ClassVar[str | None] = "r2r-config" - _description: ClassVar[str | None] = "Configures R2R-specific ingestion settings including chunking and graph enrichment" + _description: ClassVar[str | None] = ( + "Configures R2R-specific ingestion settings including chunking and graph enrichment" + ) chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000 chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200 @@ -139,6 +147,7 @@ class R2RConfig(Block): class DocumentMetadataRequired(TypedDict): """Required metadata fields for a document.""" + source_url: str timestamp: datetime content_type: str diff --git a/ingest_pipeline/flows/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/flows/__pycache__/__init__.cpython-312.pyc index bdb82c6..f01c833 100644 Binary files a/ingest_pipeline/flows/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/flows/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/flows/__pycache__/ingestion.cpython-312.pyc b/ingest_pipeline/flows/__pycache__/ingestion.cpython-312.pyc index 029aa8d..4384a6b 100644 Binary files a/ingest_pipeline/flows/__pycache__/ingestion.cpython-312.pyc and b/ingest_pipeline/flows/__pycache__/ingestion.cpython-312.pyc differ diff --git a/ingest_pipeline/flows/__pycache__/scheduler.cpython-312.pyc b/ingest_pipeline/flows/__pycache__/scheduler.cpython-312.pyc index 0550c9b..ec27d15 100644 Binary files a/ingest_pipeline/flows/__pycache__/scheduler.cpython-312.pyc and b/ingest_pipeline/flows/__pycache__/scheduler.cpython-312.pyc differ diff --git a/ingest_pipeline/flows/ingestion.py b/ingest_pipeline/flows/ingestion.py index 53157bd..e18709e 100644 --- a/ingest_pipeline/flows/ingestion.py +++ b/ingest_pipeline/flows/ingestion.py @@ -103,8 +103,13 @@ async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage: return storage -@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"], - cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url")) +@task( + name="map_firecrawl_site", + retries=2, + retry_delay_seconds=15, + tags=["firecrawl", "map"], + cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"), +) async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]: """Map a site using Firecrawl and return discovered URLs.""" # Load block if string provided @@ -118,8 +123,13 @@ async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str return mapped or [source_url] -@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"], - 
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list +@task( + name="filter_existing_documents", + retries=1, + retry_delay_seconds=5, + tags=["dedup"], + cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls"), +) # Cache based on URL list async def filter_existing_documents_task( urls: list[str], storage_client: BaseStorage, @@ -128,19 +138,40 @@ async def filter_existing_documents_task( collection_name: str | None = None, ) -> list[str]: """Filter URLs to only those that need scraping (missing or stale in storage).""" + import asyncio + logger = get_run_logger() - eligible: list[str] = [] - for url in urls: - document_id = str(FirecrawlIngestor.compute_document_id(url)) - exists = await storage_client.check_exists( - document_id, - collection_name=collection_name, - stale_after_days=stale_after_days - ) + # Use semaphore to limit concurrent existence checks + semaphore = asyncio.Semaphore(20) - if not exists: - eligible.append(url) + async def check_url_exists(url: str) -> tuple[str, bool]: + async with semaphore: + try: + document_id = str(FirecrawlIngestor.compute_document_id(url)) + exists = await storage_client.check_exists( + document_id, collection_name=collection_name, stale_after_days=stale_after_days + ) + return url, exists + except Exception as e: + logger.warning("Error checking existence for URL %s: %s", url, e) + # Assume doesn't exist on error to ensure we scrape it + return url, False + + # Check all URLs in parallel - use return_exceptions=True for partial failure handling + results = await asyncio.gather(*[check_url_exists(url) for url in urls], return_exceptions=True) + + # Collect URLs that need scraping, handling any exceptions + eligible = [] + for result in results: + if isinstance(result, Exception): + logger.error("Unexpected error in parallel existence check: %s", result) + continue + # Type narrowing: result is now known to be tuple[str, bool] + if isinstance(result, tuple) and len(result) == 2: + url, exists = result + if not exists: + eligible.append(url) skipped = len(urls) - len(eligible) if skipped > 0: @@ -260,7 +291,9 @@ async def ingest_documents_task( if progress_callback: progress_callback(40, "Starting document processing...") - return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback) + return await _process_documents( + ingestor, storage, job, batch_size, collection_name, progress_callback + ) async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor: @@ -287,7 +320,9 @@ async def _create_ingestor(job: IngestionJob, config_block_name: str | None = No raise ValueError(f"Unsupported source: {job.source_type}") -async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage: +async def _create_storage( + job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None +) -> BaseStorage: """Create and initialize storage client.""" if collection_name is None: # Use variable for default collection prefix @@ -303,6 +338,7 @@ async def _create_storage(job: IngestionJob, collection_name: str | None, storag else: # Fallback to building config from settings from ..config import get_settings + settings = get_settings() storage_config = _build_storage_config(job, settings, collection_name) @@ -391,7 +427,7 @@ async def _process_documents( progress_callback(45, "Ingesting documents from source...") # Use smart ingestion with 
deduplication if storage supports it - if hasattr(storage, 'check_exists'): + if hasattr(storage, "check_exists"): try: # Try to use the smart ingestion method document_generator = ingestor.ingest_with_dedup( @@ -412,7 +448,7 @@ async def _process_documents( if progress_callback: progress_callback( 45 + min(35, (batch_count * 10)), - f"Processing batch {batch_count} ({total_documents} documents so far)..." + f"Processing batch {batch_count} ({total_documents} documents so far)...", ) batch_processed, batch_failed = await _store_batch(storage, batch, collection_name) @@ -482,7 +518,9 @@ async def _store_batch( log_prints=True, ) async def firecrawl_to_r2r_flow( - job: IngestionJob, collection_name: str | None = None, progress_callback: Callable[[int, str], None] | None = None + job: IngestionJob, + collection_name: str | None = None, + progress_callback: Callable[[int, str], None] | None = None, ) -> tuple[int, int]: """Specialized flow for Firecrawl ingestion into R2R.""" logger = get_run_logger() @@ -549,10 +587,8 @@ async def firecrawl_to_r2r_flow( # Use asyncio.gather for concurrent scraping import asyncio - scrape_tasks = [ - scrape_firecrawl_batch_task(batch, firecrawl_config) - for batch in url_batches - ] + + scrape_tasks = [scrape_firecrawl_batch_task(batch, firecrawl_config) for batch in url_batches] batch_results = await asyncio.gather(*scrape_tasks) scraped_pages: list[FirecrawlPage] = [] @@ -680,9 +716,13 @@ async def create_ingestion_flow( progress_callback(30, "Starting document ingestion...") print("Ingesting documents...") if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R: - processed, failed = await firecrawl_to_r2r_flow(job, collection_name, progress_callback=progress_callback) + processed, failed = await firecrawl_to_r2r_flow( + job, collection_name, progress_callback=progress_callback + ) else: - processed, failed = await ingest_documents_task(job, collection_name, progress_callback=progress_callback) + processed, failed = await ingest_documents_task( + job, collection_name, progress_callback=progress_callback + ) if progress_callback: progress_callback(90, "Finalizing ingestion...") diff --git a/ingest_pipeline/ingestors/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/ingestors/__pycache__/__init__.cpython-312.pyc index 60c5cf8..e589461 100644 Binary files a/ingest_pipeline/ingestors/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/ingestors/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/ingestors/__pycache__/base.cpython-312.pyc b/ingest_pipeline/ingestors/__pycache__/base.cpython-312.pyc index 90833af..8f257b5 100644 Binary files a/ingest_pipeline/ingestors/__pycache__/base.cpython-312.pyc and b/ingest_pipeline/ingestors/__pycache__/base.cpython-312.pyc differ diff --git a/ingest_pipeline/ingestors/__pycache__/firecrawl.cpython-312.pyc b/ingest_pipeline/ingestors/__pycache__/firecrawl.cpython-312.pyc index 20330d4..1ef9e56 100644 Binary files a/ingest_pipeline/ingestors/__pycache__/firecrawl.cpython-312.pyc and b/ingest_pipeline/ingestors/__pycache__/firecrawl.cpython-312.pyc differ diff --git a/ingest_pipeline/ingestors/__pycache__/repomix.cpython-312.pyc b/ingest_pipeline/ingestors/__pycache__/repomix.cpython-312.pyc index 11c2959..e4260d9 100644 Binary files a/ingest_pipeline/ingestors/__pycache__/repomix.cpython-312.pyc and b/ingest_pipeline/ingestors/__pycache__/repomix.cpython-312.pyc differ diff --git a/ingest_pipeline/ingestors/firecrawl.py 
b/ingest_pipeline/ingestors/firecrawl.py index 7e0275b..1db5cee 100644 --- a/ingest_pipeline/ingestors/firecrawl.py +++ b/ingest_pipeline/ingestors/firecrawl.py @@ -191,9 +191,10 @@ class FirecrawlIngestor(BaseIngestor): "http://localhost" ): # Self-hosted instance - try with api_url if supported - self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl( - api_key=api_key, api_url=str(settings.firecrawl_endpoint) - )) + self.client = cast( + AsyncFirecrawlClient, + AsyncFirecrawl(api_key=api_key, api_url=str(settings.firecrawl_endpoint)), + ) else: # Cloud instance - use standard initialization self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key)) @@ -256,9 +257,7 @@ class FirecrawlIngestor(BaseIngestor): for check_url in site_map: document_id = str(self.compute_document_id(check_url)) exists = await storage_client.check_exists( - document_id, - collection_name=collection_name, - stale_after_days=stale_after_days + document_id, collection_name=collection_name, stale_after_days=stale_after_days ) if not exists: eligible_urls.append(check_url) @@ -394,7 +393,9 @@ class FirecrawlIngestor(BaseIngestor): # Extract basic metadata title: str | None = getattr(metadata, "title", None) if metadata else None - description: str | None = getattr(metadata, "description", None) if metadata else None + description: str | None = ( + getattr(metadata, "description", None) if metadata else None + ) # Extract enhanced metadata if available author: str | None = getattr(metadata, "author", None) if metadata else None @@ -402,26 +403,38 @@ class FirecrawlIngestor(BaseIngestor): sitemap_last_modified: str | None = ( getattr(metadata, "sitemap_last_modified", None) if metadata else None ) - source_url: str | None = getattr(metadata, "sourceURL", None) if metadata else None - keywords: str | list[str] | None = getattr(metadata, "keywords", None) if metadata else None + source_url: str | None = ( + getattr(metadata, "sourceURL", None) if metadata else None + ) + keywords: str | list[str] | None = ( + getattr(metadata, "keywords", None) if metadata else None + ) robots: str | None = getattr(metadata, "robots", None) if metadata else None # Open Graph metadata og_title: str | None = getattr(metadata, "ogTitle", None) if metadata else None - og_description: str | None = getattr(metadata, "ogDescription", None) if metadata else None + og_description: str | None = ( + getattr(metadata, "ogDescription", None) if metadata else None + ) og_url: str | None = getattr(metadata, "ogUrl", None) if metadata else None og_image: str | None = getattr(metadata, "ogImage", None) if metadata else None # Twitter metadata - twitter_card: str | None = getattr(metadata, "twitterCard", None) if metadata else None - twitter_site: str | None = getattr(metadata, "twitterSite", None) if metadata else None + twitter_card: str | None = ( + getattr(metadata, "twitterCard", None) if metadata else None + ) + twitter_site: str | None = ( + getattr(metadata, "twitterSite", None) if metadata else None + ) twitter_creator: str | None = ( getattr(metadata, "twitterCreator", None) if metadata else None ) # Additional metadata favicon: str | None = getattr(metadata, "favicon", None) if metadata else None - status_code: int | None = getattr(metadata, "statusCode", None) if metadata else None + status_code: int | None = ( + getattr(metadata, "statusCode", None) if metadata else None + ) return FirecrawlPage( url=url, @@ -594,10 +607,16 @@ class FirecrawlIngestor(BaseIngestor): "site_name": domain_info["site_name"], # Document structure 
"heading_hierarchy": ( - list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else [] + list(hierarchy_val) + if (hierarchy_val := structure_info.get("heading_hierarchy")) + and isinstance(hierarchy_val, (list, tuple)) + else [] ), "section_depth": ( - int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0 + int(depth_val) + if (depth_val := structure_info.get("section_depth")) + and isinstance(depth_val, (int, str)) + else 0 ), "has_code_blocks": bool(structure_info.get("has_code_blocks", False)), "has_images": bool(structure_info.get("has_images", False)), @@ -632,7 +651,11 @@ class FirecrawlIngestor(BaseIngestor): await self.client.close() except Exception as e: logging.debug(f"Error closing Firecrawl client: {e}") - elif hasattr(self.client, "_session") and self.client._session and hasattr(self.client._session, "close"): + elif ( + hasattr(self.client, "_session") + and self.client._session + and hasattr(self.client._session, "close") + ): try: await self.client._session.close() except Exception as e: diff --git a/ingest_pipeline/ingestors/repomix.py b/ingest_pipeline/ingestors/repomix.py index 281a952..0448599 100644 --- a/ingest_pipeline/ingestors/repomix.py +++ b/ingest_pipeline/ingestors/repomix.py @@ -80,6 +80,7 @@ class RepomixIngestor(BaseIngestor): return result.returncode == 0 except Exception as e: import logging + logging.warning(f"Failed to validate repository {source_url}: {e}") return False @@ -97,9 +98,7 @@ class RepomixIngestor(BaseIngestor): try: with tempfile.TemporaryDirectory() as temp_dir: # Shallow clone to get file count - repo_path = await self._clone_repository( - source_url, temp_dir, shallow=True - ) + repo_path = await self._clone_repository(source_url, temp_dir, shallow=True) # Count files matching patterns file_count = 0 @@ -111,6 +110,7 @@ class RepomixIngestor(BaseIngestor): return file_count except Exception as e: import logging + logging.warning(f"Failed to estimate size for repository {source_url}: {e}") return 0 @@ -179,9 +179,7 @@ class RepomixIngestor(BaseIngestor): return output_file - async def _parse_repomix_output( - self, output_file: Path, job: IngestionJob - ) -> list[Document]: + async def _parse_repomix_output(self, output_file: Path, job: IngestionJob) -> list[Document]: """ Parse repomix output into documents. 
@@ -210,14 +208,17 @@ class RepomixIngestor(BaseIngestor): chunks = self._chunk_content(file_content) for i, chunk in enumerate(chunks): doc = self._create_document( - file_path, chunk, job, chunk_index=i, - git_metadata=git_metadata, repo_info=repo_info + file_path, + chunk, + job, + chunk_index=i, + git_metadata=git_metadata, + repo_info=repo_info, ) documents.append(doc) else: doc = self._create_document( - file_path, file_content, job, - git_metadata=git_metadata, repo_info=repo_info + file_path, file_content, job, git_metadata=git_metadata, repo_info=repo_info ) documents.append(doc) @@ -300,64 +301,64 @@ class RepomixIngestor(BaseIngestor): # Map common extensions to languages ext_map = { - '.py': 'python', - '.js': 'javascript', - '.ts': 'typescript', - '.jsx': 'javascript', - '.tsx': 'typescript', - '.java': 'java', - '.c': 'c', - '.cpp': 'cpp', - '.cc': 'cpp', - '.cxx': 'cpp', - '.h': 'c', - '.hpp': 'cpp', - '.cs': 'csharp', - '.go': 'go', - '.rs': 'rust', - '.php': 'php', - '.rb': 'ruby', - '.swift': 'swift', - '.kt': 'kotlin', - '.scala': 'scala', - '.sh': 'shell', - '.bash': 'shell', - '.zsh': 'shell', - '.sql': 'sql', - '.html': 'html', - '.css': 'css', - '.scss': 'scss', - '.less': 'less', - '.yaml': 'yaml', - '.yml': 'yaml', - '.json': 'json', - '.xml': 'xml', - '.md': 'markdown', - '.txt': 'text', - '.cfg': 'config', - '.ini': 'config', - '.toml': 'toml', + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".jsx": "javascript", + ".tsx": "typescript", + ".java": "java", + ".c": "c", + ".cpp": "cpp", + ".cc": "cpp", + ".cxx": "cpp", + ".h": "c", + ".hpp": "cpp", + ".cs": "csharp", + ".go": "go", + ".rs": "rust", + ".php": "php", + ".rb": "ruby", + ".swift": "swift", + ".kt": "kotlin", + ".scala": "scala", + ".sh": "shell", + ".bash": "shell", + ".zsh": "shell", + ".sql": "sql", + ".html": "html", + ".css": "css", + ".scss": "scss", + ".less": "less", + ".yaml": "yaml", + ".yml": "yaml", + ".json": "json", + ".xml": "xml", + ".md": "markdown", + ".txt": "text", + ".cfg": "config", + ".ini": "config", + ".toml": "toml", } if extension in ext_map: return ext_map[extension] # Try to detect from shebang - if content.startswith('#!'): - first_line = content.split('\n')[0] - if 'python' in first_line: - return 'python' - elif 'node' in first_line or 'javascript' in first_line: - return 'javascript' - elif 'bash' in first_line or 'sh' in first_line: - return 'shell' + if content.startswith("#!"): + first_line = content.split("\n")[0] + if "python" in first_line: + return "python" + elif "node" in first_line or "javascript" in first_line: + return "javascript" + elif "bash" in first_line or "sh" in first_line: + return "shell" return None @staticmethod def _analyze_code_structure(content: str, language: str | None) -> dict[str, object]: """Analyze code structure and extract metadata.""" - lines = content.split('\n') + lines = content.split("\n") # Basic metrics has_functions = False @@ -366,29 +367,35 @@ class RepomixIngestor(BaseIngestor): has_comments = False # Language-specific patterns - if language == 'python': - has_functions = bool(re.search(r'^\s*def\s+\w+', content, re.MULTILINE)) - has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE)) - has_imports = bool(re.search(r'^\s*(import|from)\s+', content, re.MULTILINE)) - has_comments = bool(re.search(r'^\s*#', content, re.MULTILINE)) - elif language in ['javascript', 'typescript']: - has_functions = bool(re.search(r'(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)', content, re.MULTILINE)) - 
has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE)) - has_imports = bool(re.search(r'^\s*(import|require)', content, re.MULTILINE)) - has_comments = bool(re.search(r'//|/\*', content)) - elif language == 'java': - has_functions = bool(re.search(r'(public|private|protected).*\w+\s*\(', content, re.MULTILINE)) - has_classes = bool(re.search(r'(public|private)?\s*class\s+\w+', content, re.MULTILINE)) - has_imports = bool(re.search(r'^\s*import\s+', content, re.MULTILINE)) - has_comments = bool(re.search(r'//|/\*', content)) + if language == "python": + has_functions = bool(re.search(r"^\s*def\s+\w+", content, re.MULTILINE)) + has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE)) + has_imports = bool(re.search(r"^\s*(import|from)\s+", content, re.MULTILINE)) + has_comments = bool(re.search(r"^\s*#", content, re.MULTILINE)) + elif language in ["javascript", "typescript"]: + has_functions = bool( + re.search( + r"(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)", content, re.MULTILINE + ) + ) + has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE)) + has_imports = bool(re.search(r"^\s*(import|require)", content, re.MULTILINE)) + has_comments = bool(re.search(r"//|/\*", content)) + elif language == "java": + has_functions = bool( + re.search(r"(public|private|protected).*\w+\s*\(", content, re.MULTILINE) + ) + has_classes = bool(re.search(r"(public|private)?\s*class\s+\w+", content, re.MULTILINE)) + has_imports = bool(re.search(r"^\s*import\s+", content, re.MULTILINE)) + has_comments = bool(re.search(r"//|/\*", content)) return { - 'has_functions': has_functions, - 'has_classes': has_classes, - 'has_imports': has_imports, - 'has_comments': has_comments, - 'line_count': len(lines), - 'non_empty_lines': len([line for line in lines if line.strip()]), + "has_functions": has_functions, + "has_classes": has_classes, + "has_imports": has_imports, + "has_comments": has_comments, + "line_count": len(lines), + "non_empty_lines": len([line for line in lines if line.strip()]), } @staticmethod @@ -399,18 +406,16 @@ class RepomixIngestor(BaseIngestor): org_name = None # Handle different URL formats - if 'github.com' in repo_url or 'gitlab.com' in repo_url: - if path_match := re.search( - r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url - ): + if "github.com" in repo_url or "gitlab.com" in repo_url: + if path_match := re.search(r"/([^/]+)/([^/]+?)(?:\.git)?/?$", repo_url): org_name = path_match[1] repo_name = path_match[2] - elif path_match := re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url): + elif path_match := re.search(r"/([^/]+?)(?:\.git)?/?$", repo_url): repo_name = path_match[1] return { - 'repository_name': repo_name or 'unknown', - 'organization': org_name or 'unknown', + "repository_name": repo_name or "unknown", + "organization": org_name or "unknown", } async def _get_git_metadata(self, repo_path: Path) -> dict[str, str | None]: @@ -418,31 +423,35 @@ class RepomixIngestor(BaseIngestor): try: # Get current branch branch_result = await self._run_command( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - cwd=str(repo_path), - timeout=5 + ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=str(repo_path), timeout=5 + ) + branch_name = ( + branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None ) - branch_name = branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None # Get current commit hash commit_result = await self._run_command( - ['git', 'rev-parse', 'HEAD'], - cwd=str(repo_path), - timeout=5 + 
["git", "rev-parse", "HEAD"], cwd=str(repo_path), timeout=5 + ) + commit_hash = ( + commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None ) - commit_hash = commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None return { - 'branch_name': branch_name, - 'commit_hash': commit_hash[:8] if commit_hash else None, # Short hash + "branch_name": branch_name, + "commit_hash": commit_hash[:8] if commit_hash else None, # Short hash } except Exception: - return {'branch_name': None, 'commit_hash': None} + return {"branch_name": None, "commit_hash": None} def _create_document( - self, file_path: str, content: str, job: IngestionJob, chunk_index: int = 0, + self, + file_path: str, + content: str, + job: IngestionJob, + chunk_index: int = 0, git_metadata: dict[str, str | None] | None = None, - repo_info: dict[str, str] | None = None + repo_info: dict[str, str] | None = None, ) -> Document: """ Create a Document from repository content with enriched metadata. @@ -466,19 +475,19 @@ class RepomixIngestor(BaseIngestor): # Determine content type based on language content_type_map = { - 'python': 'text/x-python', - 'javascript': 'text/javascript', - 'typescript': 'text/typescript', - 'java': 'text/x-java-source', - 'html': 'text/html', - 'css': 'text/css', - 'json': 'application/json', - 'yaml': 'text/yaml', - 'markdown': 'text/markdown', - 'shell': 'text/x-shellscript', - 'sql': 'text/x-sql', + "python": "text/x-python", + "javascript": "text/javascript", + "typescript": "text/typescript", + "java": "text/x-java-source", + "html": "text/html", + "css": "text/css", + "json": "application/json", + "yaml": "text/yaml", + "markdown": "text/markdown", + "shell": "text/x-shellscript", + "sql": "text/x-sql", } - content_type = content_type_map.get(programming_language or 'text', 'text/plain') + content_type = content_type_map.get(programming_language or "text", "text/plain") # Build rich metadata metadata: DocumentMetadata = { @@ -487,8 +496,7 @@ class RepomixIngestor(BaseIngestor): "content_type": content_type, "word_count": len(content.split()), "char_count": len(content), - "title": f"{file_path}" - + (f" (chunk {chunk_index})" if chunk_index > 0 else ""), + "title": f"{file_path}" + (f" (chunk {chunk_index})" if chunk_index > 0 else ""), "description": f"Repository file: {file_path}", "category": "source_code" if programming_language else "documentation", "language": programming_language or "text", @@ -500,19 +508,23 @@ class RepomixIngestor(BaseIngestor): # Add repository info if available if repo_info: - metadata["repository_name"] = repo_info.get('repository_name') + metadata["repository_name"] = repo_info.get("repository_name") # Add git metadata if available if git_metadata: - metadata["branch_name"] = git_metadata.get('branch_name') - metadata["commit_hash"] = git_metadata.get('commit_hash') + metadata["branch_name"] = git_metadata.get("branch_name") + metadata["commit_hash"] = git_metadata.get("commit_hash") # Add code-specific metadata for programming files if programming_language and structure_info: # Calculate code quality score - total_lines = structure_info.get('line_count', 1) - non_empty_lines = structure_info.get('non_empty_lines', 0) - if isinstance(total_lines, int) and isinstance(non_empty_lines, int) and total_lines > 0: + total_lines = structure_info.get("line_count", 1) + non_empty_lines = structure_info.get("non_empty_lines", 0) + if ( + isinstance(total_lines, int) + and isinstance(non_empty_lines, int) + and total_lines > 0 + ): 
completeness_score = (non_empty_lines / total_lines) * 100 metadata["completeness_score"] = completeness_score @@ -566,5 +578,6 @@ class RepomixIngestor(BaseIngestor): await asyncio.wait_for(proc.wait(), timeout=2.0) except TimeoutError: import logging + logging.warning(f"Process {proc.pid} did not terminate cleanly") raise IngestionError(f"Command timed out: {' '.join(cmd)}") from e diff --git a/ingest_pipeline/storage/__init__.py b/ingest_pipeline/storage/__init__.py index 2e1bbfc..a580b0e 100644 --- a/ingest_pipeline/storage/__init__.py +++ b/ingest_pipeline/storage/__init__.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: try: from .r2r.storage import R2RStorage as _RuntimeR2RStorage + R2RStorage: type[BaseStorage] | None = _RuntimeR2RStorage except ImportError: R2RStorage = None diff --git a/ingest_pipeline/storage/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/storage/__pycache__/__init__.cpython-312.pyc index ad92182..ca112e2 100644 Binary files a/ingest_pipeline/storage/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/storage/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/storage/__pycache__/base.cpython-312.pyc b/ingest_pipeline/storage/__pycache__/base.cpython-312.pyc index 0e0534a..47f351e 100644 Binary files a/ingest_pipeline/storage/__pycache__/base.cpython-312.pyc and b/ingest_pipeline/storage/__pycache__/base.cpython-312.pyc differ diff --git a/ingest_pipeline/storage/__pycache__/openwebui.cpython-312.pyc b/ingest_pipeline/storage/__pycache__/openwebui.cpython-312.pyc index bfec7cf..d88ec2f 100644 Binary files a/ingest_pipeline/storage/__pycache__/openwebui.cpython-312.pyc and b/ingest_pipeline/storage/__pycache__/openwebui.cpython-312.pyc differ diff --git a/ingest_pipeline/storage/__pycache__/weaviate.cpython-312.pyc b/ingest_pipeline/storage/__pycache__/weaviate.cpython-312.pyc index 4b20569..c98b719 100644 Binary files a/ingest_pipeline/storage/__pycache__/weaviate.cpython-312.pyc and b/ingest_pipeline/storage/__pycache__/weaviate.cpython-312.pyc differ diff --git a/ingest_pipeline/storage/base.py b/ingest_pipeline/storage/base.py index 2e3e8fa..34a5b36 100644 --- a/ingest_pipeline/storage/base.py +++ b/ingest_pipeline/storage/base.py @@ -1,10 +1,12 @@ """Base storage interface.""" +import asyncio import logging +import random from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from typing import Final from types import TracebackType +from typing import Final import httpx from pydantic import SecretStr @@ -37,6 +39,8 @@ class TypedHttpClient: api_key: SecretStr | None = None, timeout: float = 30.0, headers: dict[str, str] | None = None, + max_connections: int = 100, + max_keepalive_connections: int = 20, ): """ Initialize the typed HTTP client. 
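The `storage/base.py` hunks that follow add retry behaviour to `TypedHttpClient.request`; a rough standalone sketch of that exponential-backoff-with-jitter pattern is shown below. It is a simplified illustration, not the real method, which additionally special-cases 404 handling and wraps failures in `StorageError`:

```python
import asyncio
import random

import httpx


async def request_with_retries(
    client: httpx.AsyncClient,
    method: str,
    path: str,
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> httpx.Response:
    """Retry transient HTTP failures with exponential backoff plus jitter."""
    last_exc: Exception | None = None
    for attempt in range(max_retries + 1):
        try:
            response = await client.request(method, path)
            response.raise_for_status()
            return response
        except httpx.HTTPStatusError as exc:
            status = exc.response.status_code
            if 400 <= status < 500 and status not in (408, 429):
                raise  # client errors other than timeout/rate-limit are not retryable
            last_exc = exc
        except httpx.HTTPError as exc:  # transport-level errors (connect, read, etc.)
            last_exc = exc
        if attempt < max_retries:
            await asyncio.sleep(base_delay * (2**attempt) + random.uniform(0, 1))
    raise RuntimeError(f"request failed after {max_retries + 1} attempts") from last_exc
```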
@@ -46,6 +50,8 @@ class TypedHttpClient: api_key: Optional API key for authentication timeout: Request timeout in seconds headers: Additional headers to include with requests + max_connections: Maximum total connections in pool + max_keepalive_connections: Maximum keepalive connections """ self._base_url = base_url @@ -54,14 +60,14 @@ class TypedHttpClient: if api_key: client_headers["Authorization"] = f"Bearer {api_key.get_secret_value()}" - # Note: Pylance incorrectly reports "No parameter named 'base_url'" - # but base_url is a valid AsyncClient parameter (see HTTPX docs) - client_kwargs: dict[str, str | dict[str, str] | float] = { - "base_url": base_url, - "headers": client_headers, - "timeout": timeout, - } - self.client = httpx.AsyncClient(**client_kwargs) # type: ignore + # Create typed client configuration with connection pooling + limits = httpx.Limits( + max_connections=max_connections, max_keepalive_connections=max_keepalive_connections + ) + timeout_config = httpx.Timeout(connect=5.0, read=timeout, write=30.0, pool=10.0) + self.client = httpx.AsyncClient( + base_url=base_url, headers=client_headers, timeout=timeout_config, limits=limits + ) async def request( self, @@ -73,44 +79,92 @@ class TypedHttpClient: data: dict[str, object] | None = None, files: dict[str, tuple[str, bytes, str]] | None = None, params: dict[str, str | bool] | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, ) -> httpx.Response | None: """ - Perform an HTTP request with consistent error handling. + Perform an HTTP request with consistent error handling and retries. Args: method: HTTP method (GET, POST, DELETE, etc.) path: URL path relative to base_url allow_404: If True, return None for 404 responses instead of raising - **kwargs: Arguments passed to httpx request + json: JSON data to send + data: Form data to send + files: Files to upload + params: Query parameters + max_retries: Maximum number of retry attempts + retry_delay: Base delay between retries in seconds Returns: HTTP response object, or None if allow_404=True and status is 404 Raises: - StorageError: If request fails + StorageError: If request fails after retries """ - try: - response = await self.client.request( # type: ignore - method, path, json=json, data=data, files=files, params=params - ) - response.raise_for_status() # type: ignore - return response # type: ignore - except Exception as e: - # Handle 404 as special case if requested - if allow_404 and hasattr(e, 'response') and getattr(e.response, 'status_code', None) == 404: # type: ignore - LOGGER.debug("Resource not found (404): %s %s", method, path) - return None + last_exception: Exception | None = None - # Convert all HTTP-related exceptions to StorageError - error_name = e.__class__.__name__ - if 'HTTP' in error_name or 'Connect' in error_name or 'Request' in error_name: - if hasattr(e, 'response') and hasattr(e.response, 'status_code'): # type: ignore - status_code = getattr(e.response, 'status_code', 'unknown') # type: ignore - raise StorageError(f"HTTP {status_code} error from {self._base_url}: {e}") from e - else: - raise StorageError(f"Request failed to {self._base_url}: {e}") from e - # Re-raise non-HTTP exceptions - raise + for attempt in range(max_retries + 1): + try: + response = await self.client.request( + method, path, json=json, data=data, files=files, params=params + ) + response.raise_for_status() + return response + except httpx.HTTPStatusError as e: + # Handle 404 as special case if requested + if allow_404 and e.response.status_code == 404: + 
LOGGER.debug("Resource not found (404): %s %s", method, path) + return None + + # Don't retry client errors (4xx except for specific cases) + if 400 <= e.response.status_code < 500 and e.response.status_code not in [429, 408]: + raise StorageError( + f"HTTP {e.response.status_code} error from {self._base_url}: {e}" + ) from e + + last_exception = e + if attempt < max_retries: + # Exponential backoff with jitter for retryable errors + delay = retry_delay * (2**attempt) + random.uniform(0, 1) + LOGGER.warning( + "HTTP %d error on attempt %d/%d, retrying in %.2fs: %s", + e.response.status_code, + attempt + 1, + max_retries + 1, + delay, + e, + ) + await asyncio.sleep(delay) + + except httpx.HTTPError as e: + last_exception = e + if attempt < max_retries: + # Retry transport errors with backoff + delay = retry_delay * (2**attempt) + random.uniform(0, 1) + LOGGER.warning( + "HTTP transport error on attempt %d/%d, retrying in %.2fs: %s", + attempt + 1, + max_retries + 1, + delay, + e, + ) + await asyncio.sleep(delay) + + # All retries exhausted - last_exception should always be set if we reach here + if last_exception is None: + raise StorageError( + f"Request to {self._base_url} failed after {max_retries + 1} attempts with unknown error" + ) + + if isinstance(last_exception, httpx.HTTPStatusError): + raise StorageError( + f"HTTP {last_exception.response.status_code} error from {self._base_url} after {max_retries + 1} attempts: {last_exception}" + ) from last_exception + else: + raise StorageError( + f"HTTP transport error to {self._base_url} after {max_retries + 1} attempts: {last_exception}" + ) from last_exception async def close(self) -> None: """Close the HTTP client and cleanup resources.""" @@ -127,7 +181,7 @@ class TypedHttpClient: self, exc_type: type[BaseException] | None, exc_val: BaseException | None, - exc_tb: TracebackType | None + exc_tb: TracebackType | None, ) -> None: """Async context manager exit.""" await self.close() @@ -224,11 +278,30 @@ class BaseStorage(ABC): # Check staleness if timestamp is available if "timestamp" in document.metadata: from datetime import UTC, datetime, timedelta + timestamp_obj = document.metadata["timestamp"] + + # Handle both datetime objects and ISO strings if isinstance(timestamp_obj, datetime): timestamp = timestamp_obj - cutoff = datetime.now(UTC) - timedelta(days=stale_after_days) - return timestamp >= cutoff + # Ensure timezone awareness + if timestamp.tzinfo is None: + timestamp = timestamp.replace(tzinfo=UTC) + elif isinstance(timestamp_obj, str): + try: + timestamp = datetime.fromisoformat(timestamp_obj) + # Ensure timezone awareness + if timestamp.tzinfo is None: + timestamp = timestamp.replace(tzinfo=UTC) + except ValueError: + # If parsing fails, assume document is stale + return False + else: + # Unknown timestamp format, assume stale + return False + + cutoff = datetime.now(UTC) - timedelta(days=stale_after_days) + return timestamp >= cutoff # If no timestamp, assume it exists and is valid return True diff --git a/ingest_pipeline/storage/openwebui.py b/ingest_pipeline/storage/openwebui.py index 8e42d0c..6b0b8e2 100644 --- a/ingest_pipeline/storage/openwebui.py +++ b/ingest_pipeline/storage/openwebui.py @@ -1,10 +1,10 @@ """Open WebUI storage adapter.""" - import asyncio import contextlib import logging -from typing import Final, TypedDict, cast +import time +from typing import Final, NamedTuple, TypedDict from typing_extensions import override @@ -18,6 +18,7 @@ LOGGER: Final[logging.Logger] = logging.getLogger(__name__) class 
OpenWebUIFileResponse(TypedDict, total=False): """OpenWebUI API file response structure.""" + id: str filename: str name: str @@ -29,6 +30,7 @@ class OpenWebUIFileResponse(TypedDict, total=False): class OpenWebUIKnowledgeBase(TypedDict, total=False): """OpenWebUI knowledge base response structure.""" + id: str name: str description: str @@ -38,13 +40,19 @@ class OpenWebUIKnowledgeBase(TypedDict, total=False): updated_at: str +class CacheEntry(NamedTuple): + """Cache entry with value and expiration time.""" + + value: str + expires_at: float class OpenWebUIStorage(BaseStorage): """Storage adapter for Open WebUI knowledge endpoints.""" http_client: TypedHttpClient - _knowledge_cache: dict[str, str] + _knowledge_cache: dict[str, CacheEntry] + _cache_ttl: float def __init__(self, config: StorageConfig): """ @@ -61,6 +69,7 @@ class OpenWebUIStorage(BaseStorage): timeout=30.0, ) self._knowledge_cache = {} + self._cache_ttl = 300.0 # 5 minutes TTL @override async def initialize(self) -> None: @@ -106,9 +115,25 @@ class OpenWebUIStorage(BaseStorage): return [] normalized: list[OpenWebUIKnowledgeBase] = [] for item in data: - if isinstance(item, dict): - # Cast to our expected structure - kb_item = cast(OpenWebUIKnowledgeBase, item) + if ( + isinstance(item, dict) + and "id" in item + and "name" in item + and isinstance(item["id"], str) + and isinstance(item["name"], str) + ): + # Create a new dict with known structure + kb_item: OpenWebUIKnowledgeBase = { + "id": item["id"], + "name": item["name"], + "description": item.get("description", ""), + "created_at": item.get("created_at", ""), + "updated_at": item.get("updated_at", ""), + } + if "files" in item and isinstance(item["files"], list): + kb_item["files"] = item["files"] + if "data" in item and isinstance(item["data"], dict): + kb_item["data"] = item["data"] normalized.append(kb_item) return normalized @@ -124,22 +149,29 @@ class OpenWebUIStorage(BaseStorage): if not target: raise StorageError("Knowledge base name is required") - if cached := self._knowledge_cache.get(target): - return cached + # Check cache with TTL + if cached_entry := self._knowledge_cache.get(target): + if time.time() < cached_entry.expires_at: + return cached_entry.value + else: + # Entry expired, remove it + del self._knowledge_cache[target] knowledge_bases = await self._fetch_knowledge_bases() for kb in knowledge_bases: if kb.get("name") == target: kb_id = kb.get("id") if isinstance(kb_id, str): - self._knowledge_cache[target] = kb_id + expires_at = time.time() + self._cache_ttl + self._knowledge_cache[target] = CacheEntry(kb_id, expires_at) return kb_id if not create: return None knowledge_id = await self._create_collection(target) - self._knowledge_cache[target] = knowledge_id + expires_at = time.time() + self._cache_ttl + self._knowledge_cache[target] = CacheEntry(knowledge_id, expires_at) return knowledge_id @override @@ -165,7 +197,7 @@ class OpenWebUIStorage(BaseStorage): # Use document title from metadata if available, otherwise fall back to ID filename = document.metadata.get("title") or f"doc_{document.id}" # Ensure filename has proper extension - if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')): + if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")): filename = f"{filename}.txt" files = {"file": (filename, document.content.encode(), "text/plain")} response = await self.http_client.request( @@ -185,11 +217,9 @@ class OpenWebUIStorage(BaseStorage): # Step 2: Add file to knowledge base response = await self.http_client.request( - 
"POST", - f"/api/v1/knowledge/{knowledge_id}/file/add", - json={"file_id": file_id} + "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id} ) - + return str(file_id) except Exception as e: @@ -220,7 +250,7 @@ class OpenWebUIStorage(BaseStorage): # Use document title from metadata if available, otherwise fall back to ID filename = doc.metadata.get("title") or f"doc_{doc.id}" # Ensure filename has proper extension - if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')): + if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")): filename = f"{filename}.txt" files = {"file": (filename, doc.content.encode(), "text/plain")} upload_response = await self.http_client.request( @@ -230,7 +260,9 @@ class OpenWebUIStorage(BaseStorage): params={"process": True, "process_in_background": False}, ) if upload_response is None: - raise StorageError(f"Unexpected None response from file upload for document {doc.id}") + raise StorageError( + f"Unexpected None response from file upload for document {doc.id}" + ) file_data = upload_response.json() file_id = file_data.get("id") @@ -241,9 +273,7 @@ class OpenWebUIStorage(BaseStorage): ) await self.http_client.request( - "POST", - f"/api/v1/knowledge/{knowledge_id}/file/add", - json={"file_id": file_id} + "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id} ) return str(file_id) @@ -259,7 +289,8 @@ class OpenWebUIStorage(BaseStorage): if isinstance(result, Exception): failures.append(f"{doc.id}: {result}") else: - file_ids.append(cast(str, result)) + if isinstance(result, str): + file_ids.append(result) if failures: LOGGER.warning( @@ -289,10 +320,86 @@ class OpenWebUIStorage(BaseStorage): """ _ = document_id, collection_name # Mark as used # OpenWebUI uses file-based storage without direct document retrieval - # This will cause the base check_exists method to return False, - # which means documents will always be re-scraped for OpenWebUI raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID") + @override + async def check_exists( + self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30 + ) -> bool: + """ + Check if a document exists in OpenWebUI knowledge base by searching files. 
+ + Args: + document_id: Document ID to check (usually based on source URL) + collection_name: Knowledge base name + stale_after_days: Consider document stale after this many days + + Returns: + True if document exists and is not stale, False otherwise + """ + try: + from datetime import UTC, datetime, timedelta + + # Get knowledge base + knowledge_id = await self._get_knowledge_id(collection_name, create=False) + if not knowledge_id: + return False + + # Get detailed knowledge base info to access files + response = await self.http_client.request("GET", f"/api/v1/knowledge/{knowledge_id}") + if response is None: + return False + + kb_data = response.json() + files = kb_data.get("files", []) + + # Look for file with matching document ID or source URL in metadata + cutoff = datetime.now(UTC) - timedelta(days=stale_after_days) + + def _parse_openwebui_timestamp(timestamp_str: str) -> datetime | None: + """Parse OpenWebUI timestamp with proper timezone handling.""" + try: + # Handle both 'Z' suffix and explicit timezone + normalized = timestamp_str.replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + # Ensure timezone awareness + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=UTC) + return parsed + except (ValueError, AttributeError): + return None + + def _check_file_freshness(file_info: dict[str, object]) -> bool: + """Check if file is fresh enough based on creation date.""" + created_at = file_info.get("created_at") + if not isinstance(created_at, str): + # No date info available, consider stale to be safe + return False + + file_date = _parse_openwebui_timestamp(created_at) + return file_date is not None and file_date >= cutoff + + for file_info in files: + if not isinstance(file_info, dict): + continue + + file_id = file_info.get("id") + if str(file_id) == document_id: + return _check_file_freshness(file_info) + + # Also check meta.source_url if available for URL-based document IDs + meta = file_info.get("meta", {}) + if isinstance(meta, dict): + source_url = meta.get("source_url") + if source_url and document_id in str(source_url): + return _check_file_freshness(file_info) + + return False + + except Exception as e: + LOGGER.debug("Error checking document existence in OpenWebUI: %s", e) + return False + @override async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: """ @@ -316,14 +423,10 @@ class OpenWebUIStorage(BaseStorage): await self.http_client.request( "POST", f"/api/v1/knowledge/{knowledge_id}/file/remove", - json={"file_id": document_id} + json={"file_id": document_id}, ) - await self.http_client.request( - "DELETE", - f"/api/v1/files/{document_id}", - allow_404=True - ) + await self.http_client.request("DELETE", f"/api/v1/files/{document_id}", allow_404=True) return True except Exception as exc: LOGGER.error("Error deleting file %s from OpenWebUI", document_id, exc_info=exc) @@ -366,9 +469,7 @@ class OpenWebUIStorage(BaseStorage): # Delete the knowledge base using the OpenWebUI API await self.http_client.request( - "DELETE", - f"/api/v1/knowledge/{knowledge_id}/delete", - allow_404=True + "DELETE", f"/api/v1/knowledge/{knowledge_id}/delete", allow_404=True ) # Remove from cache if it exists @@ -379,13 +480,16 @@ class OpenWebUIStorage(BaseStorage): return True except Exception as e: - if hasattr(e, 'response'): - response_attr = getattr(e, 'response', None) - if response_attr is not None and hasattr(response_attr, 'status_code'): + if hasattr(e, "response"): + response_attr = getattr(e, "response", None) + if 
response_attr is not None and hasattr(response_attr, "status_code"): with contextlib.suppress(Exception): - status_code = response_attr.status_code # type: ignore[attr-defined] + status_code = response_attr.status_code if status_code == 404: - LOGGER.info("Knowledge base %s was already deleted or not found", collection_name) + LOGGER.info( + "Knowledge base %s was already deleted or not found", + collection_name, + ) return True LOGGER.error( "Error deleting knowledge base %s from OpenWebUI", @@ -394,8 +498,6 @@ class OpenWebUIStorage(BaseStorage): ) return False - - async def _get_knowledge_base_count(self, kb: OpenWebUIKnowledgeBase) -> int: """Get the file count for a knowledge base.""" kb_id = kb.get("id") @@ -411,14 +513,13 @@ class OpenWebUIStorage(BaseStorage): files = kb.get("files", []) return len(files) if isinstance(files, list) and files is not None else 0 - async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase) -> int: + async def _count_files_from_detailed_info( + self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase + ) -> int: """Count files by fetching detailed knowledge base info.""" try: LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}") - detail_response = await self.http_client.request( - "GET", - f"/api/v1/knowledge/{kb_id}" - ) + detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}") if detail_response is None: LOGGER.warning(f"Knowledge base '{name}' (ID: {kb_id}) not found") return self._count_files_from_basic_info(kb) @@ -489,10 +590,7 @@ class OpenWebUIStorage(BaseStorage): return 0 # Get detailed knowledge base information to get accurate file count - detail_response = await self.http_client.request( - "GET", - f"/api/v1/knowledge/{kb_id}" - ) + detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}") if detail_response is None: LOGGER.warning(f"Knowledge base '{collection_name}' (ID: {kb_id}) not found") return self._count_files_from_basic_info(kb) @@ -524,14 +622,28 @@ class OpenWebUIStorage(BaseStorage): return None knowledge_bases = response.json() - return next( - ( - cast(OpenWebUIKnowledgeBase, kb) - for kb in knowledge_bases - if isinstance(kb, dict) and kb.get("name") == name - ), - None, - ) + # Find and properly type the knowledge base + for kb in knowledge_bases: + if ( + isinstance(kb, dict) + and kb.get("name") == name + and "id" in kb + and isinstance(kb["id"], str) + ): + # Create properly typed response + result: OpenWebUIKnowledgeBase = { + "id": kb["id"], + "name": str(kb["name"]), + "description": kb.get("description", ""), + "created_at": kb.get("created_at", ""), + "updated_at": kb.get("updated_at", ""), + } + if "files" in kb and isinstance(kb["files"], list): + result["files"] = kb["files"] + if "data" in kb and isinstance(kb["data"], dict): + result["data"] = kb["data"] + return result + return None except Exception as e: raise StorageError(f"Failed to get knowledge base by name: {e}") from e @@ -623,7 +735,9 @@ class OpenWebUIStorage(BaseStorage): elif isinstance(file_info.get("meta"), dict): meta = file_info.get("meta") if isinstance(meta, dict): - filename = meta.get("name") + filename_value = meta.get("name") + if isinstance(filename_value, str): + filename = filename_value # Final fallback if not filename: @@ -635,9 +749,11 @@ class OpenWebUIStorage(BaseStorage): size = 0 meta = file_info.get("meta") if isinstance(meta, dict): - size = meta.get("size", 0) + size_value = meta.get("size", 0) + 
size = int(size_value) if isinstance(size_value, (int, float)) else 0 else: - size = file_info.get("size", 0) + size_value = file_info.get("size", 0) + size = int(size_value) if isinstance(size_value, (int, float)) else 0 # Estimate word count from file size (very rough approximation) word_count = max(1, int(size / 6)) if isinstance(size, (int, float)) else 0 @@ -650,9 +766,7 @@ class OpenWebUIStorage(BaseStorage): "content_type": str(file_info.get("content_type", "text/plain")), "content_preview": f"File uploaded to OpenWebUI: {filename}", "word_count": word_count, - "timestamp": str( - file_info.get("created_at") or file_info.get("timestamp", "") - ), + "timestamp": str(file_info.get("created_at") or file_info.get("timestamp", "")), } documents.append(doc_info) diff --git a/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc index b4c1153..bf0ba16 100644 Binary files a/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/storage/r2r/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc b/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc index bfe63ed..a3e6e82 100644 Binary files a/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc and b/ingest_pipeline/storage/r2r/__pycache__/storage.cpython-312.pyc differ diff --git a/ingest_pipeline/storage/r2r/storage.py b/ingest_pipeline/storage/r2r/storage.py index 1c43302..bbe0719 100644 --- a/ingest_pipeline/storage/r2r/storage.py +++ b/ingest_pipeline/storage/r2r/storage.py @@ -4,9 +4,10 @@ from __future__ import annotations import asyncio import contextlib +import logging from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from datetime import UTC, datetime -from typing import Self, TypeVar, cast +from typing import Final, Self, TypeVar, cast from uuid import UUID, uuid4 # Direct imports for runtime and type checking @@ -19,6 +20,8 @@ from ...core.models import Document, DocumentMetadata, IngestionSource, StorageC from ..base import BaseStorage from ..types import DocumentInfo +LOGGER: Final[logging.Logger] = logging.getLogger(__name__) + T = TypeVar("T") @@ -183,8 +186,10 @@ class R2RStorage(BaseStorage): ) -> list[str]: """Store multiple documents efficiently with connection reuse.""" collection_id = await self._resolve_collection_id(collection_name) - print( - f"Using collection ID: {collection_id} for collection: {collection_name or self.config.collection_name}" + LOGGER.info( + "Using collection ID: %s for collection: %s", + collection_id, + collection_name or self.config.collection_name, ) # Filter valid documents upfront @@ -218,7 +223,7 @@ class R2RStorage(BaseStorage): if isinstance(result, str): stored_ids.append(result) elif isinstance(result, Exception): - print(f"Document upload failed: {result}") + LOGGER.error("Document upload failed: %s", result) return stored_ids @@ -239,12 +244,14 @@ class R2RStorage(BaseStorage): requested_id = str(document.id) if not document.content or not document.content.strip(): - print(f"Skipping document {requested_id}: empty content") + LOGGER.warning("Skipping document %s: empty content", requested_id) return False if len(document.content) > 1_000_000: # 1MB limit - print( - f"Skipping document {requested_id}: content too large ({len(document.content)} chars)" + LOGGER.warning( + "Skipping document %s: content too large (%d chars)", + requested_id, + len(document.content), ) 
return False @@ -263,7 +270,7 @@ class R2RStorage(BaseStorage): ) -> str | None: """Store a single document with retry logic using provided HTTP client.""" requested_id = str(document.id) - print(f"Creating document with ID: {requested_id}") + LOGGER.debug("Creating document with ID: %s", requested_id) max_retries = 3 retry_delay = 1.0 @@ -313,7 +320,7 @@ class R2RStorage(BaseStorage): requested_id = str(document.id) metadata = self._build_metadata(document) - print(f"Built metadata for document {requested_id}: {metadata}") + LOGGER.debug("Built metadata for document %s: %s", requested_id, metadata) files = { "raw_text": (None, document.content), @@ -324,10 +331,12 @@ class R2RStorage(BaseStorage): if collection_id: files["collection_ids"] = (None, json.dumps([collection_id])) - print(f"Creating document {requested_id} with collection_ids: [{collection_id}]") + LOGGER.debug( + "Creating document %s with collection_ids: [%s]", requested_id, collection_id + ) - print(f"Sending to R2R - files keys: {list(files.keys())}") - print(f"Metadata JSON: {files['metadata'][1]}") + LOGGER.debug("Sending to R2R - files keys: %s", list(files.keys())) + LOGGER.debug("Metadata JSON: %s", files["metadata"][1]) response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) # type: ignore[call-arg] @@ -346,15 +355,17 @@ class R2RStorage(BaseStorage): error_detail = ( getattr(response, "json", lambda: {})() if hasattr(response, "json") else {} ) - print(f"R2R validation error for document {requested_id}: {error_detail}") - print(f"Document metadata sent: {metadata}") - print(f"Response status: {getattr(response, 'status_code', 'unknown')}") - print(f"Response headers: {dict(getattr(response, 'headers', {}))}") + LOGGER.error("R2R validation error for document %s: %s", requested_id, error_detail) + LOGGER.error("Document metadata sent: %s", metadata) + LOGGER.error("Response status: %s", getattr(response, "status_code", "unknown")) + LOGGER.error("Response headers: %s", dict(getattr(response, "headers", {}))) except Exception: - print( - f"R2R validation error for document {requested_id}: {getattr(response, 'text', 'unknown error')}" + LOGGER.error( + "R2R validation error for document %s: %s", + requested_id, + getattr(response, "text", "unknown error"), ) - print(f"Document metadata sent: {metadata}") + LOGGER.error("Document metadata sent: %s", metadata) def _process_document_response( self, doc_response: dict[str, object], requested_id: str, collection_id: str @@ -363,14 +374,16 @@ class R2RStorage(BaseStorage): response_payload = doc_response.get("results", doc_response) doc_id = _extract_id(response_payload, requested_id) - print(f"R2R returned document ID: {doc_id}") + LOGGER.info("R2R returned document ID: %s", doc_id) if doc_id != requested_id: - print(f"Warning: Requested ID {requested_id} but got {doc_id}") + LOGGER.warning("Requested ID %s but got %s", requested_id, doc_id) if collection_id: - print( - f"Document {doc_id} should be assigned to collection {collection_id} via creation API" + LOGGER.info( + "Document %s should be assigned to collection %s via creation API", + doc_id, + collection_id, ) return doc_id @@ -387,7 +400,7 @@ class R2RStorage(BaseStorage): if attempt >= max_retries - 1: return False - print(f"Timeout for document {requested_id}, retrying in {retry_delay}s...") + LOGGER.warning("Timeout for document %s, retrying in %ss...", requested_id, retry_delay) await asyncio.sleep(retry_delay) return True @@ -404,23 +417,26 @@ class R2RStorage(BaseStorage): if 
status_code < 500 or attempt >= max_retries - 1: return False - print( - f"Server error {status_code} for document {requested_id}, retrying in {retry_delay}s..." + LOGGER.warning( + "Server error %s for document %s, retrying in %ss...", + status_code, + requested_id, + retry_delay, ) await asyncio.sleep(retry_delay) return True def _log_document_error(self, document_id: object, exc: Exception) -> None: """Log document storage errors with specific categorization.""" - print(f"Failed to store document {document_id}: {exc}") + LOGGER.error("Failed to store document %s: %s", document_id, exc) exc_str = str(exc) if "422" in exc_str: - print(" → Data validation issue - check document content and metadata format") + LOGGER.error(" → Data validation issue - check document content and metadata format") elif "timeout" in exc_str.lower(): - print(" → Network timeout - R2R may be overloaded") + LOGGER.error(" → Network timeout - R2R may be overloaded") elif "500" in exc_str: - print(" → Server error - R2R internal issue") + LOGGER.error(" → Server error - R2R internal issue") else: import traceback diff --git a/ingest_pipeline/storage/types.py b/ingest_pipeline/storage/types.py index 5e5a4f4..62ad0c8 100644 --- a/ingest_pipeline/storage/types.py +++ b/ingest_pipeline/storage/types.py @@ -5,6 +5,7 @@ from typing import TypedDict class CollectionSummary(TypedDict): """Collection metadata for describe_collections.""" + name: str count: int size_mb: float @@ -12,6 +13,7 @@ class CollectionSummary(TypedDict): class DocumentInfo(TypedDict): """Document information for list_documents.""" + id: str title: str source_url: str @@ -19,4 +21,4 @@ class DocumentInfo(TypedDict): content_type: str content_preview: str word_count: int - timestamp: str \ No newline at end of file + timestamp: str diff --git a/ingest_pipeline/storage/weaviate.py b/ingest_pipeline/storage/weaviate.py index 2ee92ec..1cdfbc5 100644 --- a/ingest_pipeline/storage/weaviate.py +++ b/ingest_pipeline/storage/weaviate.py @@ -1,8 +1,9 @@ """Weaviate storage adapter.""" -from collections.abc import AsyncGenerator, Mapping, Sequence +import asyncio +from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from datetime import UTC, datetime -from typing import Literal, Self, TypeAlias, cast, overload +from typing import Literal, Self, TypeAlias, TypeVar, cast, overload from uuid import UUID import weaviate @@ -24,6 +25,7 @@ from .base import BaseStorage from .types import CollectionSummary, DocumentInfo VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None +T = TypeVar("T") class WeaviateStorage(BaseStorage): @@ -45,6 +47,28 @@ class WeaviateStorage(BaseStorage): self.vectorizer = Vectorizer(config) self._default_collection = self._normalize_collection_name(config.collection_name) + async def _run_sync(self, func: Callable[..., T], *args: object, **kwargs: object) -> T: + """ + Run synchronous Weaviate operations in thread pool to avoid blocking event loop. 
+ + Args: + func: Synchronous function to run + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Result of the function call + + Raises: + StorageError: If the operation fails + """ + try: + return await asyncio.to_thread(func, *args, **kwargs) + except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e: + raise StorageError(f"Weaviate operation failed: {e}") from e + except Exception as e: + raise StorageError(f"Unexpected error in Weaviate operation: {e}") from e + @override async def initialize(self) -> None: """Initialize Weaviate client and create collection if needed.""" @@ -53,7 +77,7 @@ class WeaviateStorage(BaseStorage): self.client = weaviate.WeaviateClient( connection_params=weaviate.connect.ConnectionParams.from_url( url=str(self.config.endpoint), - grpc_port=50051, # Default gRPC port + grpc_port=self.config.grpc_port or 50051, ), additional_config=weaviate.classes.init.AdditionalConfig( timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120), @@ -61,7 +85,7 @@ class WeaviateStorage(BaseStorage): ) # Connect to the client - self.client.connect() + await self._run_sync(self.client.connect) # Ensure the default collection exists await self._ensure_collection(self._default_collection) @@ -76,8 +100,8 @@ class WeaviateStorage(BaseStorage): if not self.client: raise StorageError("Weaviate client not initialized") try: - client = cast(weaviate.WeaviateClient, self.client) - client.collections.create( + await self._run_sync( + self.client.collections.create, name=collection_name, properties=[ Property( @@ -106,7 +130,7 @@ class WeaviateStorage(BaseStorage): ], vectorizer_config=Configure.Vectorizer.none(), ) - except Exception as e: + except (WeaviateConnectionError, WeaviateBatchError) as e: raise StorageError(f"Failed to create collection: {e}") from e @staticmethod @@ -114,13 +138,9 @@ class WeaviateStorage(BaseStorage): """Normalize vector payloads returned by Weaviate into a float list.""" if isinstance(vector_raw, Mapping): default_vector = vector_raw.get("default") - return WeaviateStorage._extract_vector( - cast(VectorContainer, default_vector) - ) + return WeaviateStorage._extract_vector(cast(VectorContainer, default_vector)) - if not isinstance(vector_raw, Sequence) or isinstance( - vector_raw, (str, bytes, bytearray) - ): + if not isinstance(vector_raw, Sequence) or isinstance(vector_raw, (str, bytes, bytearray)): return None items = list(vector_raw) @@ -135,9 +155,7 @@ class WeaviateStorage(BaseStorage): except (TypeError, ValueError): return None - if isinstance(first_item, Sequence) and not isinstance( - first_item, (str, bytes, bytearray) - ): + if isinstance(first_item, Sequence) and not isinstance(first_item, (str, bytes, bytearray)): inner_items = list(first_item) if all(isinstance(item, (int, float)) for item in inner_items): try: @@ -168,8 +186,7 @@ class WeaviateStorage(BaseStorage): properties: object, *, context: str, - ) -> Mapping[str, object]: - ... + ) -> Mapping[str, object]: ... @staticmethod @overload @@ -178,8 +195,7 @@ class WeaviateStorage(BaseStorage): *, context: str, allow_missing: Literal[False], - ) -> Mapping[str, object]: - ... + ) -> Mapping[str, object]: ... @staticmethod @overload @@ -188,8 +204,7 @@ class WeaviateStorage(BaseStorage): *, context: str, allow_missing: Literal[True], - ) -> Mapping[str, object] | None: - ... + ) -> Mapping[str, object] | None: ... 
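# Illustrative sketch (not from the patch above): _run_sync offloads the
# synchronous Weaviate client to a worker thread via asyncio.to_thread so a
# slow call cannot stall the event loop. The same pattern in isolation, with
# time.sleep standing in for a blocking client round trip:
import asyncio
import time


def blocking_query(collection: str) -> list[str]:
    time.sleep(0.5)  # stand-in for a synchronous network call
    return [f"{collection}-doc-1", f"{collection}-doc-2"]


async def _to_thread_demo() -> None:
    # The event loop keeps making progress while the blocking call runs
    # in a thread; both awaits below overlap in time.
    results, _ = await asyncio.gather(
        asyncio.to_thread(blocking_query, "Documents"),
        asyncio.sleep(0.1),
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(_to_thread_demo())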
@staticmethod def _coerce_properties( @@ -211,6 +226,29 @@ class WeaviateStorage(BaseStorage): return cast(Mapping[str, object], properties) + @staticmethod + def _build_document_properties(doc: Document) -> dict[str, object]: + """ + Build Weaviate properties dict from document. + + Args: + doc: Document to build properties for + + Returns: + Properties dict suitable for Weaviate + """ + return { + "content": doc.content, + "source_url": doc.metadata["source_url"], + "title": doc.metadata.get("title", ""), + "description": doc.metadata.get("description", ""), + "timestamp": doc.metadata["timestamp"].isoformat(), + "content_type": doc.metadata["content_type"], + "word_count": doc.metadata["word_count"], + "char_count": doc.metadata["char_count"], + "source": doc.source.value, + } + def _normalize_collection_name(self, collection_name: str | None) -> str: """Return a canonicalized collection name, defaulting to configured value.""" candidate = collection_name or self.config.collection_name @@ -227,7 +265,7 @@ class WeaviateStorage(BaseStorage): if not self.client: raise StorageError("Weaviate client not initialized") - client = cast(weaviate.WeaviateClient, self.client) + client = self.client existing = client.collections.list_all() if collection_name not in existing: await self._create_collection(collection_name) @@ -247,7 +285,7 @@ class WeaviateStorage(BaseStorage): if ensure_exists: await self._ensure_collection(normalized) - client = cast(weaviate.WeaviateClient, self.client) + client = self.client return client.collections.get(normalized), normalized @override @@ -271,26 +309,19 @@ class WeaviateStorage(BaseStorage): ) # Prepare properties - properties = { - "content": document.content, - "source_url": document.metadata["source_url"], - "title": document.metadata.get("title", ""), - "description": document.metadata.get("description", ""), - "timestamp": document.metadata["timestamp"].isoformat(), - "content_type": document.metadata["content_type"], - "word_count": document.metadata["word_count"], - "char_count": document.metadata["char_count"], - "source": document.source.value, - } + properties = self._build_document_properties(document) # Insert with vector - result = collection.data.insert( - properties=properties, vector=document.vector, uuid=str(document.id) + result = await self._run_sync( + collection.data.insert, + properties=properties, + vector=document.vector, + uuid=str(document.id), ) return str(result) - except Exception as e: + except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e: raise StorageError(f"Failed to store document: {e}") from e @override @@ -311,32 +342,24 @@ class WeaviateStorage(BaseStorage): collection_name, ensure_exists=True ) - # Vectorize documents without vectors - for doc in documents: - if doc.vector is None: - doc.vector = await self.vectorizer.vectorize(doc.content) + # Vectorize documents without vectors using batch processing + to_vectorize = [(i, doc) for i, doc in enumerate(documents) if doc.vector is None] + if to_vectorize: + contents = [doc.content for _, doc in to_vectorize] + vectors = await self.vectorizer.vectorize_batch(contents) + for (idx, _), vector in zip(to_vectorize, vectors, strict=False): + documents[idx].vector = vector # Prepare batch data for insert_many batch_objects = [] for doc in documents: - properties = { - "content": doc.content, - "source_url": doc.metadata["source_url"], - "title": doc.metadata.get("title", ""), - "description": doc.metadata.get("description", ""), - "timestamp": 
doc.metadata["timestamp"].isoformat(), - "content_type": doc.metadata["content_type"], - "word_count": doc.metadata["word_count"], - "char_count": doc.metadata["char_count"], - "source": doc.source.value, - } - + properties = self._build_document_properties(doc) batch_objects.append( DataObject(properties=properties, vector=doc.vector, uuid=str(doc.id)) ) # Insert batch using insert_many - response = collection.data.insert_many(batch_objects) + response = await self._run_sync(collection.data.insert_many, batch_objects) successful_ids: list[str] = [] error_indices = set(response.errors.keys()) if response else set() @@ -361,11 +384,7 @@ class WeaviateStorage(BaseStorage): return successful_ids - except WeaviateBatchError as e: - raise StorageError(f"Batch operation failed: {e}") from e - except WeaviateConnectionError as e: - raise StorageError(f"Connection to Weaviate failed: {e}") from e - except Exception as e: + except (WeaviateBatchError, WeaviateConnectionError, WeaviateQueryError) as e: raise StorageError(f"Failed to store batch: {e}") from e @override @@ -385,7 +404,7 @@ class WeaviateStorage(BaseStorage): collection, resolved_name = await self._prepare_collection( collection_name, ensure_exists=False ) - result = collection.query.fetch_object_by_id(document_id) + result = await self._run_sync(collection.query.fetch_object_by_id, document_id) if not result: return None @@ -395,13 +414,30 @@ class WeaviateStorage(BaseStorage): result.properties, context="fetch_object_by_id", ) + # Parse timestamp to datetime for consistent metadata format + from datetime import UTC, datetime + + timestamp_raw = props.get("timestamp") + timestamp_parsed: datetime + try: + if isinstance(timestamp_raw, str): + timestamp_parsed = datetime.fromisoformat(timestamp_raw) + if timestamp_parsed.tzinfo is None: + timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC) + elif isinstance(timestamp_raw, datetime): + timestamp_parsed = timestamp_raw + if timestamp_parsed.tzinfo is None: + timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC) + else: + timestamp_parsed = datetime.now(UTC) + except (ValueError, TypeError): + timestamp_parsed = datetime.now(UTC) + metadata_dict = { "source_url": str(props["source_url"]), "title": str(props.get("title")) if props.get("title") else None, - "description": str(props.get("description")) - if props.get("description") - else None, - "timestamp": str(props["timestamp"]), + "description": str(props.get("description")) if props.get("description") else None, + "timestamp": timestamp_parsed, "content_type": str(props["content_type"]), "word_count": int(str(props["word_count"])), "char_count": int(str(props["char_count"])), @@ -424,11 +460,13 @@ class WeaviateStorage(BaseStorage): except WeaviateConnectionError as e: # Connection issues should be logged and return None import logging + logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}") return None except Exception as e: # Log unexpected errors for debugging import logging + logging.warning(f"Unexpected error retrieving document {document_id}: {e}") return None @@ -437,9 +475,7 @@ class WeaviateStorage(BaseStorage): metadata_dict = { "source_url": str(props["source_url"]), "title": str(props.get("title")) if props.get("title") else None, - "description": str(props.get("description")) - if props.get("description") - else None, + "description": str(props.get("description")) if props.get("description") else None, "timestamp": str(props["timestamp"]), "content_type": str(props["content_type"]), 
"word_count": int(str(props["word_count"])), @@ -462,6 +498,7 @@ class WeaviateStorage(BaseStorage): return max(0.0, 1.0 - distance_value) except (TypeError, ValueError) as e: import logging + logging.debug(f"Invalid distance value {raw_distance}: {e}") return None @@ -506,37 +543,39 @@ class WeaviateStorage(BaseStorage): collection_name: str | None = None, ) -> AsyncGenerator[Document, None]: """ - Search for documents in Weaviate. + Search for documents in Weaviate using hybrid search. Args: query: Search query limit: Maximum results - threshold: Similarity threshold + threshold: Similarity threshold (not used in hybrid search) Yields: Matching documents """ try: - query_vector = await self.vectorizer.vectorize(query) + if not self.client: + raise StorageError("Weaviate client not initialized") + collection, resolved_name = await self._prepare_collection( collection_name, ensure_exists=False ) - results = collection.query.near_vector( - near_vector=query_vector, - limit=limit, - distance=1 - threshold, - return_metadata=["distance"], - ) + # Try hybrid search first, fall back to BM25 keyword search + try: + response = await self._run_sync( + collection.query.hybrid, query=query, limit=limit, return_metadata=["score"] + ) + except (WeaviateQueryError, StorageError): + # Fall back to BM25 if hybrid search is not supported or fails + response = await self._run_sync( + collection.query.bm25, query=query, limit=limit, return_metadata=["score"] + ) - for result in results.objects: - yield self._build_search_document(result, resolved_name) + for obj in response.objects: + yield self._build_document_from_search(obj, resolved_name) - except WeaviateQueryError as e: - raise StorageError(f"Search query failed: {e}") from e - except WeaviateConnectionError as e: - raise StorageError(f"Connection to Weaviate failed during search: {e}") from e - except Exception as e: + except (WeaviateQueryError, WeaviateConnectionError) as e: raise StorageError(f"Search failed: {e}") from e @override @@ -552,7 +591,7 @@ class WeaviateStorage(BaseStorage): """ try: collection, _ = await self._prepare_collection(collection_name, ensure_exists=False) - collection.data.delete_by_id(document_id) + await self._run_sync(collection.data.delete_by_id, document_id) return True except WeaviateQueryError as e: raise StorageError(f"Delete operation failed: {e}") from e @@ -589,10 +628,10 @@ class WeaviateStorage(BaseStorage): if not self.client: raise StorageError("Weaviate client not initialized") - client = cast(weaviate.WeaviateClient, self.client) + client = self.client return list(client.collections.list_all()) - except Exception as e: + except WeaviateConnectionError as e: raise StorageError(f"Failed to list collections: {e}") from e async def describe_collections(self) -> list[CollectionSummary]: @@ -601,7 +640,7 @@ class WeaviateStorage(BaseStorage): raise StorageError("Weaviate client not initialized") try: - client = cast(weaviate.WeaviateClient, self.client) + client = self.client collections: list[CollectionSummary] = [] for name in client.collections.list_all(): collection_obj = client.collections.get(name) @@ -639,7 +678,7 @@ class WeaviateStorage(BaseStorage): ) # Query for sample documents - response = collection.query.fetch_objects(limit=limit) + response = await self._run_sync(collection.query.fetch_objects, limit=limit) documents = [] for obj in response.objects: @@ -712,9 +751,7 @@ class WeaviateStorage(BaseStorage): return { "source_url": str(props.get("source_url", "")), "title": str(props.get("title", "")) 
if props.get("title") else None, - "description": str(props.get("description", "")) - if props.get("description") - else None, + "description": str(props.get("description", "")) if props.get("description") else None, "timestamp": datetime.fromisoformat( str(props.get("timestamp", datetime.now(UTC).isoformat())) ), @@ -737,6 +774,7 @@ class WeaviateStorage(BaseStorage): return float(raw_score) except (TypeError, ValueError) as e: import logging + logging.debug(f"Invalid score value {raw_score}: {e}") return None @@ -780,31 +818,11 @@ class WeaviateStorage(BaseStorage): Returns: List of matching documents """ - try: - if not self.client: - raise StorageError("Weaviate client not initialized") - - collection, resolved_name = await self._prepare_collection( - collection_name, ensure_exists=False - ) - - # Try hybrid search first, fall back to BM25 keyword search - try: - response = collection.query.hybrid( - query=query, limit=limit, return_metadata=["score"] - ) - except Exception: - response = collection.query.bm25( - query=query, limit=limit, return_metadata=["score"] - ) - - return [ - self._build_document_from_search(obj, resolved_name) - for obj in response.objects - ] - - except Exception as e: - raise StorageError(f"Failed to search documents: {e}") from e + # Delegate to the unified search method + results = [] + async for document in self.search(query, limit=limit, collection_name=collection_name): + results.append(document) + return results async def list_documents( self, @@ -830,8 +848,11 @@ class WeaviateStorage(BaseStorage): collection, _ = await self._prepare_collection(collection_name, ensure_exists=False) # Query documents with pagination - response = collection.query.fetch_objects( - limit=limit, offset=offset, return_metadata=["creation_time"] + response = await self._run_sync( + collection.query.fetch_objects, + limit=limit, + offset=offset, + return_metadata=["creation_time"], ) documents: list[DocumentInfo] = [] @@ -896,7 +917,9 @@ class WeaviateStorage(BaseStorage): ) delete_filter = Filter.by_id().contains_any(document_ids) - response = collection.data.delete_many(where=delete_filter, verbose=True) + response = await self._run_sync( + collection.data.delete_many, where=delete_filter, verbose=True + ) if objects := getattr(response, "objects", None): for result_obj in objects: @@ -938,20 +961,22 @@ class WeaviateStorage(BaseStorage): # Get documents matching filter if where_filter: - response = collection.query.fetch_objects( + response = await self._run_sync( + collection.query.fetch_objects, filters=where_filter, limit=1000, # Max batch size ) else: - response = collection.query.fetch_objects( - limit=1000 # Max batch size + response = await self._run_sync( + collection.query.fetch_objects, + limit=1000, # Max batch size ) # Delete matching documents deleted_count = 0 for obj in response.objects: try: - collection.data.delete_by_id(obj.uuid) + await self._run_sync(collection.data.delete_by_id, obj.uuid) deleted_count += 1 except Exception: continue @@ -975,7 +1000,7 @@ class WeaviateStorage(BaseStorage): target = self._normalize_collection_name(collection_name) # Delete the collection using the client's collections API - client = cast(weaviate.WeaviateClient, self.client) + client = self.client client.collections.delete(target) return True @@ -997,20 +1022,29 @@ class WeaviateStorage(BaseStorage): await self.close() async def close(self) -> None: - """Close client connection.""" + """Close client connection and vectorizer HTTP client.""" if self.client: try: - client = 
cast(weaviate.WeaviateClient, self.client) + client = self.client client.close() - except Exception as e: + except (WeaviateConnectionError, AttributeError) as e: import logging + logging.warning(f"Error closing Weaviate client: {e}") + # Close vectorizer HTTP client to prevent resource leaks + try: + await self.vectorizer.close() + except (AttributeError, OSError) as e: + import logging + + logging.warning(f"Error closing vectorizer client: {e}") + def __del__(self) -> None: """Clean up client connection as fallback.""" if self.client: try: - client = cast(weaviate.WeaviateClient, self.client) + client = self.client client.close() except Exception: pass # Ignore errors in destructor diff --git a/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc b/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc index 1ecac22..c6334c5 100644 Binary files a/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc and b/ingest_pipeline/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc b/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc index 88ac753..daf5fc0 100644 Binary files a/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc and b/ingest_pipeline/utils/__pycache__/metadata_tagger.cpython-312.pyc differ diff --git a/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc b/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc index 48c83db..d822e83 100644 Binary files a/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc and b/ingest_pipeline/utils/__pycache__/vectorizer.cpython-312.pyc differ diff --git a/ingest_pipeline/utils/async_helpers.py b/ingest_pipeline/utils/async_helpers.py new file mode 100644 index 0000000..d4593fb --- /dev/null +++ b/ingest_pipeline/utils/async_helpers.py @@ -0,0 +1,117 @@ +"""Async utilities for task management and backpressure control.""" + +import asyncio +import logging +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable +from contextlib import asynccontextmanager +from typing import Final, TypeVar + +LOGGER: Final[logging.Logger] = logging.getLogger(__name__) + +T = TypeVar("T") + + +class AsyncTaskManager: + """Manages concurrent tasks with backpressure control.""" + + def __init__(self, max_concurrent: int = 10): + """ + Initialize task manager. + + Args: + max_concurrent: Maximum number of concurrent tasks + """ + self.semaphore = asyncio.Semaphore(max_concurrent) + self.max_concurrent = max_concurrent + + @asynccontextmanager + async def acquire(self) -> AsyncGenerator[None, None]: + """Acquire a slot for task execution.""" + async with self.semaphore: + yield + + async def run_tasks( + self, tasks: Iterable[Awaitable[T]], return_exceptions: bool = False + ) -> list[T | BaseException]: + """ + Run multiple tasks with backpressure control. 
+ + Args: + tasks: Iterable of awaitable tasks + return_exceptions: Whether to return exceptions or raise them + + Returns: + List of task results or exceptions + """ + + async def _controlled_task(task: Awaitable[T]) -> T: + async with self.acquire(): + return await task + + controlled_tasks = [_controlled_task(task) for task in tasks] + + if return_exceptions: + results = await asyncio.gather(*controlled_tasks, return_exceptions=True) + return list(results) + else: + results = await asyncio.gather(*controlled_tasks) + return list(results) + + async def map_async( + self, func: Callable[[T], Awaitable[T]], items: Iterable[T], return_exceptions: bool = False + ) -> list[T | BaseException]: + """ + Apply async function to items with backpressure control. + + Args: + func: Async function to apply + items: Items to process + return_exceptions: Whether to return exceptions or raise them + + Returns: + List of processed results or exceptions + """ + tasks = [func(item) for item in items] + return await self.run_tasks(tasks, return_exceptions=return_exceptions) + + +async def run_with_semaphore(semaphore: asyncio.Semaphore, coro: Awaitable[T]) -> T: + """Run coroutine with semaphore-controlled concurrency.""" + async with semaphore: + return await coro + + +async def batch_process( + items: list[T], + processor: Callable[[T], Awaitable[T]], + batch_size: int = 50, + max_concurrent: int = 5, +) -> list[T]: + """ + Process items in batches with controlled concurrency. + + Args: + items: Items to process + processor: Async function to process each item + batch_size: Number of items per batch + max_concurrent: Maximum concurrent tasks per batch + + Returns: + List of processed results + """ + task_manager = AsyncTaskManager(max_concurrent) + results: list[T] = [] + + for i in range(0, len(items), batch_size): + batch = items[i : i + batch_size] + LOGGER.debug( + "Processing batch %d-%d of %d items", i, min(i + batch_size, len(items)), len(items) + ) + + batch_results = await task_manager.map_async(processor, batch, return_exceptions=False) + # If return_exceptions=False, exceptions would have been raised, so all results are successful + # Type checker doesn't know this, so we need to cast + successful_results: list[T] = [r for r in batch_results if not isinstance(r, BaseException)] + results.extend(successful_results) + + return results diff --git a/ingest_pipeline/utils/metadata_tagger.py b/ingest_pipeline/utils/metadata_tagger.py index 9beb2d2..ebc9318 100644 --- a/ingest_pipeline/utils/metadata_tagger.py +++ b/ingest_pipeline/utils/metadata_tagger.py @@ -6,12 +6,12 @@ from typing import Final, Protocol, TypedDict, cast import httpx +from ..config import get_settings from ..core.exceptions import IngestionError from ..core.models import Document JSON_CONTENT_TYPE: Final[str] = "application/json" AUTHORIZATION_HEADER: Final[str] = "Authorization" -from ..config import get_settings class HttpResponse(Protocol): @@ -24,12 +24,7 @@ class HttpResponse(Protocol): class AsyncHttpClient(Protocol): """Protocol for async HTTP client.""" - async def post( - self, - url: str, - *, - json: dict[str, object] | None = None - ) -> HttpResponse: ... + async def post(self, url: str, *, json: dict[str, object] | None = None) -> HttpResponse: ... async def aclose(self) -> None: ... 
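# Illustrative sketch (not from the patch above): async_helpers.py bounds
# concurrency with an asyncio.Semaphore for backpressure. A self-contained
# version of the same pattern follows; the item values, delay, and the
# concurrency limit of 5 are made up for the demo.
import asyncio


async def process(item: int, semaphore: asyncio.Semaphore) -> int:
    async with semaphore:  # at most N bodies execute concurrently
        await asyncio.sleep(0.1)  # stand-in for an upload / API call
        return item * 2


async def _backpressure_demo() -> None:
    semaphore = asyncio.Semaphore(5)
    results = await asyncio.gather(*(process(i, semaphore) for i in range(20)))
    print(results)


if __name__ == "__main__":
    asyncio.run(_backpressure_demo())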
@@ -45,16 +40,19 @@ class AsyncHttpClient(Protocol): class LlmResponse(TypedDict): """Type for LLM API response structure.""" + choices: list[dict[str, object]] class LlmChoice(TypedDict): """Type for individual choice in LLM response.""" + message: dict[str, object] class LlmMessage(TypedDict): """Type for message in LLM choice.""" + content: str @@ -96,7 +94,7 @@ class MetadataTagger: """ settings = get_settings() endpoint_value = llm_endpoint or str(settings.llm_endpoint) - self.endpoint = endpoint_value.rstrip('/') + self.endpoint = endpoint_value.rstrip("/") self.model = model or settings.metadata_model resolved_timeout = timeout if timeout is not None else float(settings.request_timeout) @@ -161,12 +159,16 @@ class MetadataTagger: # Build a proper DocumentMetadata instance with only valid keys new_metadata: CoreDocumentMetadata = { "source_url": str(updated_metadata.get("source_url", "")), - "timestamp": ( - lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC) - )(updated_metadata.get("timestamp", datetime.now(UTC))), + "timestamp": (lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC))( + updated_metadata.get("timestamp", datetime.now(UTC)) + ), "content_type": str(updated_metadata.get("content_type", "text/plain")), - "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)), - "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)), + "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)( + updated_metadata.get("word_count", 0) + ), + "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)( + updated_metadata.get("char_count", 0) + ), } # Add optional fields if they exist diff --git a/ingest_pipeline/utils/vectorizer.py b/ingest_pipeline/utils/vectorizer.py index adc63f6..87e597f 100644 --- a/ingest_pipeline/utils/vectorizer.py +++ b/ingest_pipeline/utils/vectorizer.py @@ -1,20 +1,79 @@ """Vectorizer utility for generating embeddings.""" +import asyncio from types import TracebackType -from typing import Final, Self, cast +from typing import Final, NotRequired, Self, TypedDict import httpx -from typings import EmbeddingResponse - +from ..config import get_settings from ..core.exceptions import VectorizationError from ..core.models import StorageConfig, VectorConfig -from ..config import get_settings JSON_CONTENT_TYPE: Final[str] = "application/json" AUTHORIZATION_HEADER: Final[str] = "Authorization" +class EmbeddingData(TypedDict): + """Structure for embedding data from providers.""" + + embedding: list[float] + index: NotRequired[int] + object: NotRequired[str] + + +class EmbeddingResponse(TypedDict): + """Embedding response format for multiple providers.""" + + data: list[EmbeddingData] + model: NotRequired[str] + object: NotRequired[str] + usage: NotRequired[dict[str, int]] + # Alternative formats + embedding: NotRequired[list[float]] + vector: NotRequired[list[float]] + embeddings: NotRequired[list[list[float]]] + + +def _extract_embedding_from_response(response_data: dict[str, object]) -> list[float]: + """Extract embedding vector from provider response.""" + # OpenAI/Ollama format: {"data": [{"embedding": [...]}]} + if "data" in response_data: + data_list = response_data["data"] + if isinstance(data_list, list) and data_list: + first_item = data_list[0] + if isinstance(first_item, dict) and "embedding" in first_item: + embedding = first_item["embedding"] + if isinstance(embedding, list) and all( + isinstance(x, 
(int, float)) for x in embedding + ): + return [float(x) for x in embedding] + + # Direct embedding format: {"embedding": [...]} + if "embedding" in response_data: + embedding = response_data["embedding"] + if isinstance(embedding, list) and all(isinstance(x, (int, float)) for x in embedding): + return [float(x) for x in embedding] + + # Vector format: {"vector": [...]} + if "vector" in response_data: + vector = response_data["vector"] + if isinstance(vector, list) and all(isinstance(x, (int, float)) for x in vector): + return [float(x) for x in vector] + + # Embeddings array format: {"embeddings": [[...]]} + if "embeddings" in response_data: + embeddings = response_data["embeddings"] + if isinstance(embeddings, list) and embeddings: + first_embedding = embeddings[0] + if isinstance(first_embedding, list) and all( + isinstance(x, (int, float)) for x in first_embedding + ): + return [float(x) for x in first_embedding] + + raise VectorizationError("Unrecognized embedding response format") + + class Vectorizer: """Handles text vectorization using LLM endpoints.""" @@ -72,21 +131,34 @@ class Vectorizer: async def vectorize_batch(self, texts: list[str]) -> list[list[float]]: """ - Generate embeddings for multiple texts. + Generate embeddings for multiple texts in parallel. Args: texts: List of texts to vectorize Returns: List of embedding vectors + + Raises: + VectorizationError: If any vectorization fails """ - vectors: list[list[float]] = [] - for text in texts: - vector = await self.vectorize(text) - vectors.append(vector) + if not texts: + return [] - return vectors + # Use semaphore to limit concurrent requests and prevent overwhelming the endpoint + semaphore = asyncio.Semaphore(20) + + async def vectorize_with_semaphore(text: str) -> list[float]: + async with semaphore: + return await self.vectorize(text) + + try: + # Execute all vectorization requests concurrently + vectors = await asyncio.gather(*[vectorize_with_semaphore(text) for text in texts]) + return list(vectors) + except Exception as e: + raise VectorizationError(f"Batch vectorization failed: {e}") from e async def _ollama_embed(self, text: str) -> list[float]: """ @@ -112,23 +184,12 @@ class Vectorizer: _ = response.raise_for_status() response_json = response.json() - # Response is expected to be dict[str, object] from our type stub + if not isinstance(response_json, dict): + raise VectorizationError("Invalid JSON response format") - response_data = cast(EmbeddingResponse, cast(object, response_json)) + # Extract embedding using type-safe helper + embedding = _extract_embedding_from_response(response_json) - # Parse OpenAI-compatible response format - embeddings_list = response_data.get("data", []) - if not embeddings_list: - raise VectorizationError("No embeddings returned") - - first_embedding = embeddings_list[0] - embedding_raw = first_embedding.get("embedding") - if not embedding_raw: - raise VectorizationError("Invalid embedding format") - - # Convert to float list and validate - embedding: list[float] = [] - embedding.extend(float(item) for item in embedding_raw) # Ensure correct dimension if len(embedding) != self.dimension: raise VectorizationError( @@ -157,22 +218,12 @@ class Vectorizer: _ = response.raise_for_status() response_json = response.json() - # Response is expected to be dict[str, object] from our type stub + if not isinstance(response_json, dict): + raise VectorizationError("Invalid JSON response format") - response_data = cast(EmbeddingResponse, cast(object, response_json)) + # Extract embedding using 
type-safe helper + embedding = _extract_embedding_from_response(response_json) - embeddings_list = response_data.get("data", []) - if not embeddings_list: - raise VectorizationError("No embeddings returned") - - first_embedding = embeddings_list[0] - embedding_raw = first_embedding.get("embedding") - if not embedding_raw: - raise VectorizationError("Invalid embedding format") - - # Convert to float list and validate - embedding: list[float] = [] - embedding.extend(float(item) for item in embedding_raw) # Ensure correct dimension if len(embedding) != self.dimension: raise VectorizationError( @@ -185,6 +236,14 @@ class Vectorizer: """Async context manager entry.""" return self + async def close(self) -> None: + """Close the HTTP client connection.""" + try: + await self.client.aclose() + except Exception: + # Already closed or connection lost + pass + async def __aexit__( self, exc_type: type[BaseException] | None, @@ -192,4 +251,4 @@ class Vectorizer: exc_tb: TracebackType | None, ) -> None: """Async context manager exit.""" - await self.client.aclose() + await self.close() diff --git a/notebooks/welcome.py b/notebooks/welcome.py new file mode 100644 index 0000000..4b7ad91 --- /dev/null +++ b/notebooks/welcome.py @@ -0,0 +1,27 @@ +import marimo + +__generated_with = "0.16.0" +app = marimo.App() + + +@app.cell +def __(): + import marimo as mo + + return (mo,) + + +@app.cell +def __(mo): + mo.md("# Welcome to Marimo!") + return + + +@app.cell +def __(mo): + mo.md("This is your interactive notebook environment.") + return + + +if __name__ == "__main__": + app.run() diff --git a/repomix-output.xml b/repomix-output.xml index afa1c62..35e2616 100644 --- a/repomix-output.xml +++ b/repomix-output.xml @@ -4,7 +4,7 @@ This file is a merged representation of a subset of the codebase, containing spe This section contains a summary of this file. -This file contains a packed representation of the entire repository's contents. +This file contains a packed representation of a subset of the repository's contents that is considered the most important context. It is designed to be easily consumable by AI systems for analysis, code review, or other automated processes. @@ -14,7 +14,8 @@ The content is organized as follows: 1. This summary section 2. Repository information 3. Directory structure -4. Repository files, each consisting of: +4. Repository files (if enabled) +5. Multiple file entries, each consisting of: - File path as an attribute - Full contents of the file @@ -37,10 +38,6 @@ The content is organized as follows: - Files are sorted by Git change count (files with more changes are at the bottom) - - - - @@ -99,9 +96,11 @@ ingest_pipeline/ __init__.py base.py openwebui.py + types.py weaviate.py utils/ __init__.py + async_helpers.py metadata_tagger.py vectorizer.py __main__.py @@ -110,6 +109,126 @@ ingest_pipeline/ This section contains the contents of the repository's files. + +"""Async utilities for task management and backpressure control.""" + +import asyncio +import logging +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable +from contextlib import asynccontextmanager +from typing import Final, TypeVar + +LOGGER: Final[logging.Logger] = logging.getLogger(__name__) + +T = TypeVar("T") + + +class AsyncTaskManager: + """Manages concurrent tasks with backpressure control.""" + + def __init__(self, max_concurrent: int = 10): + """ + Initialize task manager. 
+ + Args: + max_concurrent: Maximum number of concurrent tasks + """ + self.semaphore = asyncio.Semaphore(max_concurrent) + self.max_concurrent = max_concurrent + + @asynccontextmanager + async def acquire(self) -> AsyncGenerator[None, None]: + """Acquire a slot for task execution.""" + async with self.semaphore: + yield + + async def run_tasks( + self, tasks: Iterable[Awaitable[T]], return_exceptions: bool = False + ) -> list[T | BaseException]: + """ + Run multiple tasks with backpressure control. + + Args: + tasks: Iterable of awaitable tasks + return_exceptions: Whether to return exceptions or raise them + + Returns: + List of task results or exceptions + """ + + async def _controlled_task(task: Awaitable[T]) -> T: + async with self.acquire(): + return await task + + controlled_tasks = [_controlled_task(task) for task in tasks] + + if return_exceptions: + results = await asyncio.gather(*controlled_tasks, return_exceptions=True) + return list(results) + else: + results = await asyncio.gather(*controlled_tasks) + return list(results) + + async def map_async( + self, func: Callable[[T], Awaitable[T]], items: Iterable[T], return_exceptions: bool = False + ) -> list[T | BaseException]: + """ + Apply async function to items with backpressure control. + + Args: + func: Async function to apply + items: Items to process + return_exceptions: Whether to return exceptions or raise them + + Returns: + List of processed results or exceptions + """ + tasks = [func(item) for item in items] + return await self.run_tasks(tasks, return_exceptions=return_exceptions) + + +async def run_with_semaphore(semaphore: asyncio.Semaphore, coro: Awaitable[T]) -> T: + """Run coroutine with semaphore-controlled concurrency.""" + async with semaphore: + return await coro + + +async def batch_process( + items: list[T], + processor: Callable[[T], Awaitable[T]], + batch_size: int = 50, + max_concurrent: int = 5, +) -> list[T]: + """ + Process items in batches with controlled concurrency. 
+ + Args: + items: Items to process + processor: Async function to process each item + batch_size: Number of items per batch + max_concurrent: Maximum concurrent tasks per batch + + Returns: + List of processed results + """ + task_manager = AsyncTaskManager(max_concurrent) + results: list[T] = [] + + for i in range(0, len(items), batch_size): + batch = items[i : i + batch_size] + LOGGER.debug( + "Processing batch %d-%d of %d items", i, min(i + batch_size, len(items)), len(items) + ) + + batch_results = await task_manager.map_async(processor, batch, return_exceptions=False) + # If return_exceptions=False, exceptions would have been raised, so all results are successful + # Type checker doesn't know this, so we need to cast + successful_results: list[T] = [r for r in batch_results if not isinstance(r, BaseException)] + results.extend(successful_results) + + return results + + """Prefect Automations for ingestion pipeline monitoring and management.""" @@ -132,7 +251,6 @@ actions: source: inferred enabled: true """, - "retry_failed": """ name: Retry Failed Ingestion Flows description: Retries failed ingestion flows with original parameters @@ -152,7 +270,6 @@ actions: validate_first: false enabled: true """, - "resource_monitoring": """ name: Manage Work Pool Based on Resources description: Pauses work pool when system resources are constrained @@ -519,6 +636,33 @@ from .storage import R2RStorage __all__ = ["R2RStorage"] + +"""Shared types for storage adapters.""" + +from typing import TypedDict + + +class CollectionSummary(TypedDict): + """Collection metadata for describe_collections.""" + + name: str + count: int + size_mb: float + + +class DocumentInfo(TypedDict): + """Document information for list_documents.""" + + id: str + title: str + source_url: str + description: str + content_type: str + content_preview: str + word_count: int + timestamp: str + + """Utility modules.""" @@ -592,7 +736,7 @@ class BaseScreen(Screen[object]): name: str | None = None, id: str | None = None, classes: str | None = None, - **kwargs: object + **kwargs: object, ) -> None: """Initialize base screen.""" super().__init__(name=name, id=id, classes=classes) @@ -618,7 +762,7 @@ class CRUDScreen(BaseScreen, Generic[T]): name: str | None = None, id: str | None = None, classes: str | None = None, - **kwargs: object + **kwargs: object, ) -> None: """Initialize CRUD screen.""" super().__init__(storage_manager, name=name, id=id, classes=classes) @@ -874,7 +1018,7 @@ class FormScreen(ModalScreen[T], Generic[T]): name: str | None = None, id: str | None = None, classes: str | None = None, - **kwargs: object + **kwargs: object, ) -> None: """Initialize form screen.""" super().__init__(name=name, id=id, classes=classes) @@ -947,503 +1091,6 @@ class FormScreen(ModalScreen[T], Generic[T]): self.dismiss(None) - -"""Storage management utilities for TUI applications.""" - - -from __future__ import annotations - -import asyncio -from collections.abc import AsyncGenerator, Sequence -from typing import TYPE_CHECKING, Protocol - -from ....core.exceptions import StorageError -from ....core.models import Document, StorageBackend, StorageConfig -from ..models import CollectionInfo, StorageCapabilities - -from ....storage.base import BaseStorage -from ....storage.openwebui import OpenWebUIStorage -from ....storage.r2r.storage import R2RStorage -from ....storage.weaviate import WeaviateStorage - -if TYPE_CHECKING: - from ....config.settings import Settings - - -class StorageBackendProtocol(Protocol): - """Protocol defining storage backend 
interface.""" - - async def initialize(self) -> None: ... - async def count(self, *, collection_name: str | None = None) -> int: ... - async def list_collections(self) -> list[str]: ... - async def search( - self, - query: str, - limit: int = 10, - threshold: float = 0.7, - *, - collection_name: str | None = None, - ) -> AsyncGenerator[Document, None]: ... - async def close(self) -> None: ... - - - -class MultiStorageAdapter(BaseStorage): - """Mirror writes to multiple storage backends.""" - - def __init__(self, storages: Sequence[BaseStorage]) -> None: - if not storages: - raise ValueError("MultiStorageAdapter requires at least one storage backend") - - unique: list[BaseStorage] = [] - seen_ids: set[int] = set() - for storage in storages: - storage_id = id(storage) - if storage_id in seen_ids: - continue - seen_ids.add(storage_id) - unique.append(storage) - - self._storages = unique - self._primary = unique[0] - super().__init__(self._primary.config) - - async def initialize(self) -> None: - for storage in self._storages: - await storage.initialize() - - async def store(self, document: Document, *, collection_name: str | None = None) -> str: - # Store in primary backend first - primary_id: str = await self._primary.store(document, collection_name=collection_name) - - # Replicate to secondary backends concurrently - if len(self._storages) > 1: - async def replicate_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]: - try: - await storage.store(document, collection_name=collection_name) - return storage, True, None - except Exception as exc: - return storage, False, exc - - tasks = [replicate_to_backend(storage) for storage in self._storages[1:]] - results = await asyncio.gather(*tasks, return_exceptions=True) - - failures: list[str] = [] - errors: list[Exception] = [] - - for result in results: - if isinstance(result, tuple): - storage, success, error = result - if not success and error is not None: - failures.append(self._format_backend_label(storage)) - errors.append(error) - elif isinstance(result, Exception): - failures.append("unknown") - errors.append(result) - - if failures: - backends = ", ".join(failures) - primary_error = errors[0] if errors else Exception("Unknown replication error") - raise StorageError( - f"Document stored in primary backend but replication failed for: {backends}" - ) from primary_error - - return primary_id - - async def store_batch( - self, documents: list[Document], *, collection_name: str | None = None - ) -> list[str]: - # Store in primary backend first - primary_ids: list[str] = await self._primary.store_batch(documents, collection_name=collection_name) - - # Replicate to secondary backends concurrently - if len(self._storages) > 1: - async def replicate_batch_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]: - try: - await storage.store_batch(documents, collection_name=collection_name) - return storage, True, None - except Exception as exc: - return storage, False, exc - - tasks = [replicate_batch_to_backend(storage) for storage in self._storages[1:]] - results = await asyncio.gather(*tasks, return_exceptions=True) - - failures: list[str] = [] - errors: list[Exception] = [] - - for result in results: - if isinstance(result, tuple): - storage, success, error = result - if not success and error is not None: - failures.append(self._format_backend_label(storage)) - errors.append(error) - elif isinstance(result, Exception): - failures.append("unknown") - errors.append(result) - - if failures: - 
backends = ", ".join(failures) - primary_error = errors[0] if errors else Exception("Unknown batch replication error") - raise StorageError( - f"Batch stored in primary backend but replication failed for: {backends}" - ) from primary_error - - return primary_ids - - async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: - # Delete from primary backend first - primary_deleted: bool = await self._primary.delete(document_id, collection_name=collection_name) - - # Delete from secondary backends concurrently - if len(self._storages) > 1: - async def delete_from_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]: - try: - await storage.delete(document_id, collection_name=collection_name) - return storage, True, None - except Exception as exc: - return storage, False, exc - - tasks = [delete_from_backend(storage) for storage in self._storages[1:]] - results = await asyncio.gather(*tasks, return_exceptions=True) - - failures: list[str] = [] - errors: list[Exception] = [] - - for result in results: - if isinstance(result, tuple): - storage, success, error = result - if not success and error is not None: - failures.append(self._format_backend_label(storage)) - errors.append(error) - elif isinstance(result, Exception): - failures.append("unknown") - errors.append(result) - - if failures: - backends = ", ".join(failures) - primary_error = errors[0] if errors else Exception("Unknown deletion error") - raise StorageError( - f"Document deleted from primary backend but failed for: {backends}" - ) from primary_error - - return primary_deleted - - async def count(self, *, collection_name: str | None = None) -> int: - count_result: int = await self._primary.count(collection_name=collection_name) - return count_result - - async def list_collections(self) -> list[str]: - list_fn = getattr(self._primary, "list_collections", None) - if list_fn is None: - return [] - collections_result: list[str] = await list_fn() - return collections_result - - async def search( - self, - query: str, - limit: int = 10, - threshold: float = 0.7, - *, - collection_name: str | None = None, - ) -> AsyncGenerator[Document, None]: - async for item in self._primary.search( - query, - limit=limit, - threshold=threshold, - collection_name=collection_name, - ): - yield item - - async def close(self) -> None: - for storage in self._storages: - close_fn = getattr(storage, "close", None) - if close_fn is not None: - await close_fn() - - def _format_backend_label(self, storage: BaseStorage) -> str: - backend = getattr(storage.config, "backend", None) - if isinstance(backend, StorageBackend): - backend_value: str = backend.value - return backend_value - class_name: str = storage.__class__.__name__ - return class_name - - - -class StorageManager: - """Centralized manager for all storage backend operations.""" - - def __init__(self, settings: Settings) -> None: - """Initialize storage manager with application settings.""" - self.settings = settings - self.backends: dict[StorageBackend, BaseStorage] = {} - self.capabilities: dict[StorageBackend, StorageCapabilities] = {} - self._initialized = False - - async def initialize_all_backends(self) -> dict[StorageBackend, bool]: - """Initialize all available storage backends with timeout protection.""" - results: dict[StorageBackend, bool] = {} - - async def init_backend(backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]) -> bool: - """Initialize a single backend with timeout.""" - try: - storage = 
storage_class(config) - await asyncio.wait_for(storage.initialize(), timeout=30.0) - self.backends[backend_type] = storage - if backend_type == StorageBackend.WEAVIATE: - self.capabilities[backend_type] = StorageCapabilities.VECTOR_SEARCH - elif backend_type == StorageBackend.OPEN_WEBUI: - self.capabilities[backend_type] = StorageCapabilities.KNOWLEDGE_BASE - elif backend_type == StorageBackend.R2R: - self.capabilities[backend_type] = StorageCapabilities.FULL_FEATURED - return True - except (TimeoutError, Exception): - return False - - # Initialize backends concurrently with timeout protection - tasks = [] - - # Try Weaviate - if self.settings.weaviate_endpoint: - config = StorageConfig( - backend=StorageBackend.WEAVIATE, - endpoint=self.settings.weaviate_endpoint, - api_key=self.settings.weaviate_api_key, - collection_name="default", - ) - tasks.append((StorageBackend.WEAVIATE, init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage))) - else: - results[StorageBackend.WEAVIATE] = False - - # Try OpenWebUI - if self.settings.openwebui_endpoint and self.settings.openwebui_api_key: - config = StorageConfig( - backend=StorageBackend.OPEN_WEBUI, - endpoint=self.settings.openwebui_endpoint, - api_key=self.settings.openwebui_api_key, - collection_name="default", - ) - tasks.append((StorageBackend.OPEN_WEBUI, init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage))) - else: - results[StorageBackend.OPEN_WEBUI] = False - - # Try R2R - if self.settings.r2r_endpoint: - config = StorageConfig( - backend=StorageBackend.R2R, - endpoint=self.settings.r2r_endpoint, - api_key=self.settings.r2r_api_key, - collection_name="default", - ) - tasks.append((StorageBackend.R2R, init_backend(StorageBackend.R2R, config, R2RStorage))) - else: - results[StorageBackend.R2R] = False - - # Execute initialization tasks concurrently - if tasks: - backend_types, task_coroutines = zip(*tasks, strict=False) - task_results = await asyncio.gather(*task_coroutines, return_exceptions=True) - - for backend_type, task_result in zip(backend_types, task_results, strict=False): - results[backend_type] = task_result if isinstance(task_result, bool) else False - self._initialized = True - return results - - def get_backend(self, backend_type: StorageBackend) -> BaseStorage | None: - """Get storage backend by type.""" - return self.backends.get(backend_type) - - def build_multi_storage_adapter( - self, backends: Sequence[StorageBackend] - ) -> MultiStorageAdapter: - storages: list[BaseStorage] = [] - seen: set[StorageBackend] = set() - for backend in backends: - backend_enum = backend if isinstance(backend, StorageBackend) else StorageBackend(backend) - if backend_enum in seen: - continue - seen.add(backend_enum) - storage = self.backends.get(backend_enum) - if storage is None: - raise ValueError(f"Storage backend {backend_enum.value} is not initialized") - storages.append(storage) - return MultiStorageAdapter(storages) - - def get_available_backends(self) -> list[StorageBackend]: - """Get list of successfully initialized backends.""" - return list(self.backends.keys()) - - def has_capability(self, backend: StorageBackend, capability: StorageCapabilities) -> bool: - """Check if backend has specific capability.""" - backend_caps = self.capabilities.get(backend, StorageCapabilities.BASIC) - return capability.value <= backend_caps.value - - async def get_all_collections(self) -> list[CollectionInfo]: - """Get collections from all available backends, merging collections with same name.""" - collection_map: dict[str, 
CollectionInfo] = {} - - for backend_type, storage in self.backends.items(): - try: - backend_collections = await storage.list_collections() - for collection_name in backend_collections: - # Validate collection name - if not collection_name or not isinstance(collection_name, str): - continue - - try: - count = await storage.count(collection_name=collection_name) - # Validate count is non-negative - count = max(count, 0) - except StorageError as e: - # Storage-specific errors - log and use 0 count - import logging - logging.warning(f"Failed to get count for {collection_name} on {backend_type.value}: {e}") - count = 0 - except Exception as e: - # Unexpected errors - log and skip this collection from this backend - import logging - logging.warning(f"Unexpected error counting {collection_name} on {backend_type.value}: {e}") - continue - - size_mb = count * 0.01 # Rough estimate: 10KB per document - - if collection_name in collection_map: - # Merge with existing collection - existing = collection_map[collection_name] - existing_backends = existing["backend"] - backend_value = backend_type.value - - if isinstance(existing_backends, str): - existing["backend"] = [existing_backends, backend_value] - elif isinstance(existing_backends, list): - # Prevent duplicates - if backend_value not in existing_backends: - existing_backends.append(backend_value) - - # Aggregate counts and sizes - existing["count"] += count - existing["size_mb"] += size_mb - else: - # Create new collection entry - collection_info: CollectionInfo = { - "name": collection_name, - "type": self._get_collection_type(collection_name, backend_type), - "count": count, - "backend": backend_type.value, - "status": "active", - "last_updated": "2024-01-01T00:00:00Z", - "size_mb": size_mb, - } - collection_map[collection_name] = collection_info - except Exception: - continue - - return list(collection_map.values()) - - def _get_collection_type(self, collection_name: str, backend: StorageBackend) -> str: - """Determine collection type based on name and backend.""" - # Prioritize definitive backend type first - if backend == StorageBackend.R2R: - return "r2r" - elif backend == StorageBackend.WEAVIATE: - return "weaviate" - elif backend == StorageBackend.OPEN_WEBUI: - return "openwebui" - - # Fallback to name-based guessing if backend is not specific - name_lower = collection_name.lower() - if "web" in name_lower or "doc" in name_lower: - return "documentation" - elif "repo" in name_lower or "code" in name_lower: - return "repository" - else: - return "general" - - async def search_across_backends( - self, - query: str, - limit: int = 10, - backends: list[StorageBackend] | None = None, - ) -> dict[StorageBackend, list[Document]]: - """Search across multiple backends and return grouped results.""" - if backends is None: - backends = self.get_available_backends() - - results: dict[StorageBackend, list[Document]] = {} - - async def search_backend(backend_type: StorageBackend) -> None: - storage = self.backends.get(backend_type) - if storage: - try: - documents = [] - async for doc in storage.search(query, limit=limit): - documents.append(doc) - results[backend_type] = documents - except Exception: - results[backend_type] = [] - - # Run searches in parallel - tasks = [search_backend(backend) for backend in backends] - await asyncio.gather(*tasks, return_exceptions=True) - - return results - - def get_r2r_storage(self) -> R2RStorage | None: - """Get R2R storage instance if available.""" - storage = self.backends.get(StorageBackend.R2R) - return 
storage if isinstance(storage, R2RStorage) else None - - async def get_backend_status(self) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]: - """Get comprehensive status for all backends.""" - status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {} - - for backend_type, storage in self.backends.items(): - try: - collections = await storage.list_collections() - total_docs = 0 - for collection in collections: - total_docs += await storage.count(collection_name=collection) - - backend_status = { - "available": True, - "collections": len(collections), - "total_documents": total_docs, - "capabilities": self.capabilities.get(backend_type, StorageCapabilities.BASIC), - "endpoint": getattr(storage.config, "endpoint", "unknown"), - } - status[backend_type] = backend_status - except Exception as e: - status[backend_type] = { - "available": False, - "error": str(e), - "capabilities": StorageCapabilities.NONE, - } - - return status - - async def close_all(self) -> None: - """Close all storage connections.""" - for storage in self.backends.values(): - try: - await storage.close() - except Exception: - pass - - self.backends.clear() - self.capabilities.clear() - self._initialized = False - - @property - def is_initialized(self) -> bool: - """Check if storage manager is initialized.""" - return self._initialized - - def supports_advanced_features(self, backend: StorageBackend) -> bool: - """Check if backend supports advanced features like chunks and entities.""" - return self.has_capability(backend, StorageCapabilities.FULL_FEATURED) - - """Status indicators and progress bars with enhanced visual feedback.""" @@ -1473,8 +1120,12 @@ class StatusIndicator(Static): status_lower = status.lower() - if (status_lower in {"active", "online", "connected", "✓ active"} or - status_lower.endswith("active") or "✓" in status_lower and "active" in status_lower): + if ( + status_lower in {"active", "online", "connected", "✓ active"} + or status_lower.endswith("active") + or "✓" in status_lower + and "active" in status_lower + ): self.add_class("status-active") self.add_class("glow") self.update(f"🟢 {status}") @@ -1540,7 +1191,7 @@ class EnhancedProgressBar(Static): """Data models and TypedDict definitions for the TUI.""" from enum import IntEnum -from typing import Any, TypedDict +from typing import TypedDict class StorageCapabilities(IntEnum): @@ -1586,7 +1237,7 @@ class ChunkInfo(TypedDict): content: str start_index: int end_index: int - metadata: dict[str, Any] + metadata: dict[str, object] class EntityInfo(TypedDict): @@ -1596,7 +1247,7 @@ class EntityInfo(TypedDict): name: str type: str confidence: float - metadata: dict[str, Any] + metadata: dict[str, object] class FirecrawlOptions(TypedDict, total=False): @@ -1616,7 +1267,7 @@ class FirecrawlOptions(TypedDict, total=False): max_depth: int # Extraction options - extract_schema: dict[str, Any] | None + extract_schema: dict[str, object] | None extract_prompt: str | None @@ -2158,6 +1809,7 @@ if TYPE_CHECKING: try: from .r2r.storage import R2RStorage as _RuntimeR2RStorage + R2RStorage: type[BaseStorage] | None = _RuntimeR2RStorage except ImportError: R2RStorage = None @@ -2171,341 +1823,12 @@ __all__ = [ ] - -"""Document management screen with enhanced navigation.""" - -from datetime import datetime - -from textual.app import ComposeResult -from textual.binding import Binding -from textual.containers import Container, Horizontal -from textual.screen import Screen -from textual.widgets import Button, Footer, 
Header, Label, LoadingIndicator, Static -from typing_extensions import override - -from ....storage.base import BaseStorage -from ..models import CollectionInfo, DocumentInfo -from ..widgets import EnhancedDataTable - - -class DocumentManagementScreen(Screen[None]): - """Screen for managing documents within a collection with enhanced keyboard navigation.""" - - collection: CollectionInfo - storage: BaseStorage | None - documents: list[DocumentInfo] - selected_docs: set[str] - current_offset: int - page_size: int - - BINDINGS = [ - Binding("escape", "app.pop_screen", "Back"), - Binding("r", "refresh", "Refresh"), - Binding("delete", "delete_selected", "Delete Selected"), - Binding("a", "select_all", "Select All"), - Binding("ctrl+a", "select_all", "Select All"), - Binding("n", "select_none", "Clear Selection"), - Binding("ctrl+shift+a", "select_none", "Clear Selection"), - Binding("space", "toggle_selection", "Toggle Selection"), - Binding("ctrl+d", "delete_selected", "Delete Selected"), - Binding("pageup", "prev_page", "Previous Page"), - Binding("pagedown", "next_page", "Next Page"), - Binding("home", "first_page", "First Page"), - Binding("end", "last_page", "Last Page"), - ] - - def __init__(self, collection: CollectionInfo, storage: BaseStorage | None): - super().__init__() - self.collection = collection - self.storage = storage - self.documents: list[DocumentInfo] = [] - self.selected_docs: set[str] = set() - self.current_offset = 0 - self.page_size = 50 - - @override - def compose(self) -> ComposeResult: - yield Header() - yield Container( - Static(f"📄 Document Management: {self.collection['name']}", classes="title"), - Static( - f"Total Documents: {self.collection['count']:,} | Use Space to select, Delete to remove", - classes="subtitle", - ), - Label(f"Page size: {self.page_size} documents"), - EnhancedDataTable(id="documents_table", classes="enhanced-table"), - Horizontal( - Button("🔄 Refresh", id="refresh_docs_btn", variant="primary"), - Button("🗑️ Delete Selected", id="delete_selected_btn", variant="error"), - Button("✅ Select All", id="select_all_btn", variant="default"), - Button("❌ Clear Selection", id="clear_selection_btn", variant="default"), - Button("⬅️ Previous Page", id="prev_page_btn", variant="default"), - Button("➡️ Next Page", id="next_page_btn", variant="default"), - classes="button_bar", - ), - Label("", id="selection_status"), - Static("", id="page_info", classes="status-text"), - LoadingIndicator(id="loading"), - classes="main_container", - ) - yield Footer() - - async def on_mount(self) -> None: - """Initialize the screen.""" - self.query_one("#loading").display = False - - # Setup documents table with enhanced columns - table = self.query_one("#documents_table", EnhancedDataTable) - table.add_columns( - "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID" - ) - - # Set up message handling for table events - table.can_focus = True - - await self.load_documents() - - async def load_documents(self) -> None: - """Load documents from the collection.""" - loading = self.query_one("#loading") - loading.display = True - - try: - if self.storage: - # Try to load documents using the storage backend - try: - raw_docs = await self.storage.list_documents( - limit=self.page_size, - offset=self.current_offset, - collection_name=self.collection["name"], - ) - # Cast to proper type with type checking - self.documents = [ - DocumentInfo( - id=str(doc.get("id", f"doc_{i}")), - title=str(doc.get("title", "Untitled Document")), - 
source_url=str(doc.get("source_url", "")), - description=str(doc.get("description", "")), - content_type=str(doc.get("content_type", "text/plain")), - content_preview=str(doc.get("content_preview", "")), - word_count=( - lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0 - )(doc.get("word_count", 0)), - timestamp=str(doc.get("timestamp", "")), - ) - for i, doc in enumerate(raw_docs) - ] - except NotImplementedError: - # For storage backends that don't support document listing, show a message - self.notify( - f"Document listing not supported for {self.storage.__class__.__name__}", - severity="information" - ) - self.documents = [] - - await self.update_table() - self.update_selection_status() - self.update_page_info() - - except Exception as e: - self.notify(f"Error loading documents: {e}", severity="error", markup=False) - finally: - loading.display = False - - async def update_table(self) -> None: - """Update the documents table with enhanced metadata display.""" - table = self.query_one("#documents_table", EnhancedDataTable) - table.clear(columns=True) - - # Add enhanced columns with more metadata - table.add_columns( - "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID" - ) - - # Add rows with enhanced metadata - for doc in self.documents: - selected = "✓" if doc["id"] in self.selected_docs else "" - - # Get additional metadata from the raw docs - description = str(doc.get("description") or "").strip()[:40] - if not description: - description = "[dim]No description[/dim]" - elif len(str(doc.get("description") or "")) > 40: - description += "..." - - # Format content type with appropriate icon - content_type = doc.get("content_type", "text/plain") - if "markdown" in content_type.lower(): - type_display = "📝 md" - elif "html" in content_type.lower(): - type_display = "🌐 html" - elif "text" in content_type.lower(): - type_display = "📄 txt" - else: - type_display = f"📄 {content_type.split('/')[-1][:5]}" - - # Format timestamp to be more readable - timestamp = doc.get("timestamp", "") - if timestamp: - try: - # Parse ISO format timestamp - dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) - timestamp = dt.strftime("%m/%d %H:%M") - except Exception: - timestamp = str(timestamp)[:16] # Fallback - table.add_row( - selected, - doc.get("title", "Untitled")[:40], - doc.get("source_url", "")[:35], - description, - type_display, - str(doc.get("word_count", 0)), - timestamp, - doc["id"][:8] + "...", # Show truncated ID - ) - - def update_selection_status(self) -> None: - """Update the selection status label.""" - status_label = self.query_one("#selection_status", Label) - total_selected = len(self.selected_docs) - status_label.update(f"Selected: {total_selected} documents") - - def update_page_info(self) -> None: - """Update the page information.""" - page_info = self.query_one("#page_info", Static) - total_docs = self.collection["count"] - start = self.current_offset + 1 - end = min(self.current_offset + len(self.documents), total_docs) - page_num = (self.current_offset // self.page_size) + 1 - total_pages = (total_docs + self.page_size - 1) // self.page_size - - page_info.update( - f"Showing {start:,}-{end:,} of {total_docs:,} documents (Page {page_num} of {total_pages})" - ) - - def get_current_document(self) -> DocumentInfo | None: - """Get the currently selected document.""" - table = self.query_one("#documents_table", EnhancedDataTable) - try: - if 0 <= table.cursor_coordinate.row < len(self.documents): - return 
self.documents[table.cursor_coordinate.row] - except (AttributeError, IndexError): - pass - return None - - # Action methods - def action_refresh(self) -> None: - """Refresh the document list.""" - self.run_worker(self.load_documents()) - - def action_toggle_selection(self) -> None: - """Toggle selection of current row.""" - if doc := self.get_current_document(): - doc_id = doc["id"] - if doc_id in self.selected_docs: - self.selected_docs.remove(doc_id) - else: - self.selected_docs.add(doc_id) - - self.run_worker(self.update_table()) - self.update_selection_status() - - def action_select_all(self) -> None: - """Select all documents on current page.""" - for doc in self.documents: - self.selected_docs.add(doc["id"]) - self.run_worker(self.update_table()) - self.update_selection_status() - - def action_select_none(self) -> None: - """Clear all selections.""" - self.selected_docs.clear() - self.run_worker(self.update_table()) - self.update_selection_status() - - def action_delete_selected(self) -> None: - """Delete selected documents.""" - if self.selected_docs: - from .dialogs import ConfirmDocumentDeleteScreen - - self.app.push_screen( - ConfirmDocumentDeleteScreen(list(self.selected_docs), self.collection, self) - ) - else: - self.notify("No documents selected", severity="warning") - - def action_next_page(self) -> None: - """Go to next page.""" - if self.current_offset + self.page_size < self.collection["count"]: - self.current_offset += self.page_size - self.run_worker(self.load_documents()) - - def action_prev_page(self) -> None: - """Go to previous page.""" - if self.current_offset >= self.page_size: - self.current_offset -= self.page_size - self.run_worker(self.load_documents()) - - def action_first_page(self) -> None: - """Go to first page.""" - if self.current_offset > 0: - self.current_offset = 0 - self.run_worker(self.load_documents()) - - def action_last_page(self) -> None: - """Go to last page.""" - total_docs = self.collection["count"] - last_offset = ((total_docs - 1) // self.page_size) * self.page_size - if self.current_offset != last_offset: - self.current_offset = last_offset - self.run_worker(self.load_documents()) - - def on_button_pressed(self, event: Button.Pressed) -> None: - """Handle button presses.""" - if event.button.id == "refresh_docs_btn": - self.action_refresh() - elif event.button.id == "delete_selected_btn": - self.action_delete_selected() - elif event.button.id == "select_all_btn": - self.action_select_all() - elif event.button.id == "clear_selection_btn": - self.action_select_none() - elif event.button.id == "next_page_btn": - self.action_next_page() - elif event.button.id == "prev_page_btn": - self.action_prev_page() - - def on_enhanced_data_table_row_toggled(self, event: EnhancedDataTable.RowToggled) -> None: - """Handle row toggle from enhanced table.""" - if 0 <= event.row_index < len(self.documents): - doc = self.documents[event.row_index] - doc_id = doc["id"] - - if doc_id in self.selected_docs: - self.selected_docs.remove(doc_id) - else: - self.selected_docs.add(doc_id) - - self.run_worker(self.update_table()) - self.update_selection_status() - - def on_enhanced_data_table_select_all(self, event: EnhancedDataTable.SelectAll) -> None: - """Handle select all from enhanced table.""" - self.action_select_all() - - def on_enhanced_data_table_clear_selection( - self, event: EnhancedDataTable.ClearSelection - ) -> None: - """Handle clear selection from enhanced table.""" - self.action_select_none() - - """Enhanced ingestion screen with multi-storage 
support.""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from textual import work from textual.app import ComposeResult @@ -2608,12 +1931,23 @@ class IngestionScreen(ModalScreen[None]): Label("📋 Source Type (Press 1/2/3):", classes="input-label"), Horizontal( Button("🌐 Web (1)", id="web_btn", variant="primary", classes="type-button"), - Button("📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"), - Button("📖 Documentation (3)", id="docs_btn", variant="default", classes="type-button"), + Button( + "📦 Repository (2)", id="repo_btn", variant="default", classes="type-button" + ), + Button( + "📖 Documentation (3)", + id="docs_btn", + variant="default", + classes="type-button", + ), classes="type_buttons", ), Rule(line_style="dashed"), - Label(f"🗄️ Target Storages ({len(self.available_backends)} available):", classes="input-label", id="backend_label"), + Label( + f"🗄️ Target Storages ({len(self.available_backends)} available):", + classes="input-label", + id="backend_label", + ), Container( *self._create_backend_checkbox_widgets(), classes="backend-selection", @@ -2642,7 +1976,6 @@ class IngestionScreen(ModalScreen[None]): yield LoadingIndicator(id="loading", classes="pulse") - def _create_backend_checkbox_widgets(self) -> list[Checkbox]: """Create checkbox widgets for each available backend.""" checkboxes: list[Checkbox] = [ @@ -2722,19 +2055,26 @@ class IngestionScreen(ModalScreen[None]): collection_name = collection_input.value.strip() if not source_url: - self.notify("🔍 Please enter a source URL", severity="error") + cast("CollectionManagementApp", self.app).safe_notify( + "🔍 Please enter a source URL", severity="error" + ) url_input.focus() return # Validate URL format if not self._validate_url(source_url): - self.notify("❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path", severity="error") + cast("CollectionManagementApp", self.app).safe_notify( + "❌ Invalid URL format. 
Please enter a valid HTTP/HTTPS URL or file:// path", + severity="error", + ) url_input.focus() return resolved_backends = self._resolve_selected_backends() if not resolved_backends: - self.notify("⚠️ Select at least one storage backend", severity="warning") + cast("CollectionManagementApp", self.app).safe_notify( + "⚠️ Select at least one storage backend", severity="warning" + ) return self.selected_backends = resolved_backends @@ -2749,18 +2089,16 @@ class IngestionScreen(ModalScreen[None]): url_lower = url.lower() # Allow HTTP/HTTPS URLs - if url_lower.startswith(('http://', 'https://')): + if url_lower.startswith(("http://", "https://")): # Additional validation could be added here return True # Allow file:// URLs for repository paths - if url_lower.startswith('file://'): + if url_lower.startswith("file://"): return True # Allow local file paths that look like repositories - return '/' in url and not url_lower.startswith( - ('javascript:', 'data:', 'vbscript:') - ) + return "/" in url and not url_lower.startswith(("javascript:", "data:", "vbscript:")) def _resolve_selected_backends(self) -> list[StorageBackend]: selected: list[StorageBackend] = [] @@ -2791,13 +2129,14 @@ class IngestionScreen(ModalScreen[None]): if not self.selected_backends: status_widget.update("📋 Selected: None") elif len(self.selected_backends) == 1: - backend_name = BACKEND_LABELS.get(self.selected_backends[0], self.selected_backends[0].value) + backend_name = BACKEND_LABELS.get( + self.selected_backends[0], self.selected_backends[0].value + ) status_widget.update(f"📋 Selected: {backend_name}") else: # Multiple backends selected backend_names = [ - BACKEND_LABELS.get(backend, backend.value) - for backend in self.selected_backends + BACKEND_LABELS.get(backend, backend.value) for backend in self.selected_backends ] if len(backend_names) <= 3: # Show all names if 3 or fewer @@ -2822,7 +2161,10 @@ class IngestionScreen(ModalScreen[None]): for backend in BACKEND_ORDER: if backend not in self.available_backends: continue - if backend.value.lower() == backend_name_lower or backend.name.lower() == backend_name_lower: + if ( + backend.value.lower() == backend_name_lower + or backend.name.lower() == backend_name_lower + ): matched_backends.append(backend) break return matched_backends or [self.available_backends[0]] @@ -2854,6 +2196,7 @@ class IngestionScreen(ModalScreen[None]): loading.display = False except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(_update) def progress_reporter(percent: int, message: str) -> None: @@ -2865,6 +2208,7 @@ class IngestionScreen(ModalScreen[None]): progress_text.update(message) except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(_update_progress) try: @@ -2880,13 +2224,18 @@ class IngestionScreen(ModalScreen[None]): for i, backend in enumerate(backends): progress_percent = 20 + (60 * i) // len(backends) - progress_reporter(progress_percent, f"🔗 Processing {backend.value} backend ({i+1}/{len(backends)})...") + progress_reporter( + progress_percent, + f"🔗 Processing {backend.value} backend ({i + 1}/{len(backends)})...", + ) try: # Run the Prefect flow for this backend using asyncio.run with timeout import asyncio - async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult: + async def run_flow_with_timeout( + current_backend: StorageBackend = backend, + ) -> IngestionResult: return await asyncio.wait_for( create_ingestion_flow( source_url=source_url, @@ -2895,7 +2244,7 @@ class 
IngestionScreen(ModalScreen[None]): collection_name=final_collection_name, progress_callback=progress_reporter, ), - timeout=600.0 # 10 minute timeout + timeout=600.0, # 10 minute timeout ) result = asyncio.run(run_flow_with_timeout()) @@ -2904,25 +2253,33 @@ class IngestionScreen(ModalScreen[None]): total_failed += result.documents_failed if result.error_messages: - flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages]) + flow_errors.extend( + [f"{backend.value}: {err}" for err in result.error_messages] + ) except TimeoutError: error_msg = f"{backend.value}: Timeout after 10 minutes" flow_errors.append(error_msg) progress_reporter(0, f"❌ {backend.value} timed out") - def notify_timeout(msg: str = f"⏰ {backend.value} flow timed out after 10 minutes") -> None: + + def notify_timeout( + msg: str = f"⏰ {backend.value} flow timed out after 10 minutes", + ) -> None: try: self.notify(msg, severity="error", markup=False) except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(notify_timeout) except Exception as exc: flow_errors.append(f"{backend.value}: {exc}") + def notify_error(msg: str = f"❌ {backend.value} flow failed: {exc}") -> None: try: self.notify(msg, severity="error", markup=False) except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(notify_error) successful = total_successful @@ -2948,20 +2305,38 @@ class IngestionScreen(ModalScreen[None]): cast("CollectionManagementApp", self.app).call_from_thread(notify_results) - import time - time.sleep(2) - cast("CollectionManagementApp", self.app).pop_screen() + def _pop() -> None: + try: + self.app.pop_screen() + except Exception: + pass + + # Schedule screen pop via timer instead of blocking + cast("CollectionManagementApp", self.app).call_from_thread( + lambda: self.app.set_timer(2.0, _pop) + ) except Exception as exc: # pragma: no cover - defensive progress_reporter(0, f"❌ Prefect flows error: {exc}") + def notify_error(msg: str = f"❌ Prefect flows failed: {exc}") -> None: try: self.notify(msg, severity="error") except Exception: pass + cast("CollectionManagementApp", self.app).call_from_thread(notify_error) - import time - time.sleep(2) + + def _pop_on_error() -> None: + try: + self.app.pop_screen() + except Exception: + pass + + # Schedule screen pop via timer for error case too + cast("CollectionManagementApp", self.app).call_from_thread( + lambda: self.app.set_timer(2.0, _pop_on_error) + ) finally: update_ui("hide_loading") @@ -3012,12 +2387,24 @@ class SearchScreen(Screen[None]): @override def compose(self) -> ComposeResult: yield Header() + # Check if search is supported for this backend + backends = self.collection["backend"] + if isinstance(backends, str): + backends = [backends] + search_supported = "weaviate" in backends + search_indicator = "✅ Search supported" if search_supported else "❌ Search not supported" + yield Container( Static( - f"🔍 Search in: {self.collection['name']} ({self.collection['backend']})", + f"🔍 Search in: {self.collection['name']} ({', '.join(backends)}) - {search_indicator}", classes="title", ), - Static("Press / or Ctrl+F to focus search, Enter to search", classes="subtitle"), + Static( + "Press / or Ctrl+F to focus search, Enter to search" + if search_supported + else "Search functionality not available for this backend", + classes="subtitle", + ), Input(placeholder="Enter search query... 
(press Enter to search)", id="search_input"), Button("🔍 Search", id="search_btn", variant="primary"), Button("🗑️ Clear Results", id="clear_btn", variant="default"), @@ -3096,7 +2483,9 @@ class SearchScreen(Screen[None]): finally: loading.display = False - def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None: + def _setup_search_ui( + self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str + ) -> None: """Setup the search UI elements.""" loading.display = True status.update(f"🔍 Searching for '{query}'...") @@ -3108,10 +2497,18 @@ class SearchScreen(Screen[None]): if self.collection["type"] == "weaviate" and self.weaviate: return await self.search_weaviate(query) elif self.collection["type"] == "openwebui" and self.openwebui: - return await self.search_openwebui(query) + # OpenWebUI search is not yet implemented + self.notify("Search not supported for OpenWebUI collections", severity="warning") + return [] + elif self.collection["type"] == "r2r": + # R2R search would go here when implemented + self.notify("Search not supported for R2R collections", severity="warning") + return [] return [] - def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None: + def _populate_results_table( + self, table: EnhancedDataTable, results: list[dict[str, str | float]] + ) -> None: """Populate the results table with search results.""" for result in results: row_data = self._format_result_row(result) @@ -3162,7 +2559,11 @@ class SearchScreen(Screen[None]): return str(score) def _update_search_status( - self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable + self, + status: Static, + query: str, + results: list[dict[str, str | float]], + table: EnhancedDataTable, ) -> None: """Update search status and notifications based on results.""" if not results: @@ -3232,153 +2633,542 @@ class SearchScreen(Screen[None]): return [] - -"""TUI runner functions and initialization.""" + +"""Storage management utilities for TUI applications.""" from __future__ import annotations import asyncio -import logging -from logging import Logger -from logging.handlers import QueueHandler, RotatingFileHandler -from pathlib import Path -from queue import Queue -from typing import NamedTuple +from collections.abc import AsyncGenerator, Coroutine, Sequence +from typing import TYPE_CHECKING, Protocol -from ....config import configure_prefect, get_settings -from ....core.models import StorageBackend -from .storage_manager import StorageManager +from pydantic import SecretStr + +from ....core.exceptions import StorageError +from ....core.models import Document, StorageBackend, StorageConfig +from ....storage.base import BaseStorage +from ....storage.openwebui import OpenWebUIStorage +from ....storage.r2r.storage import R2RStorage +from ....storage.weaviate import WeaviateStorage +from ..models import CollectionInfo, StorageCapabilities + +if TYPE_CHECKING: + from ....config.settings import Settings -class _TuiLoggingContext(NamedTuple): - """Container describing configured logging outputs for the TUI.""" +class StorageBackendProtocol(Protocol): + """Protocol defining storage backend interface.""" - queue: Queue[logging.LogRecord] - formatter: logging.Formatter - log_file: Path | None + async def initialize(self) -> None: ... + async def count(self, *, collection_name: str | None = None) -> int: ... + async def list_collections(self) -> list[str]: ... 
+ async def search( + self, + query: str, + limit: int = 10, + threshold: float = 0.7, + *, + collection_name: str | None = None, + ) -> AsyncGenerator[Document, None]: ... + async def close(self) -> None: ... -_logging_context: _TuiLoggingContext | None = None +class MultiStorageAdapter(BaseStorage): + """Mirror writes to multiple storage backends.""" + def __init__(self, storages: Sequence[BaseStorage]) -> None: + if not storages: + raise ValueError("MultiStorageAdapter requires at least one storage backend") -def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext: - """Configure logging so that messages do not break the TUI output.""" + unique: list[BaseStorage] = [] + seen_ids: set[int] = set() + for storage in storages: + storage_id = id(storage) + if storage_id in seen_ids: + continue + seen_ids.add(storage_id) + unique.append(storage) - global _logging_context - if _logging_context is not None: - return _logging_context + self._storages: list[BaseStorage] = unique + self._primary: BaseStorage = unique[0] + super().__init__(self._primary.config) - resolved_level = getattr(logging, log_level.upper(), logging.INFO) - log_queue: Queue[logging.LogRecord] = Queue() - formatter = logging.Formatter( - fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) + async def initialize(self) -> None: + for storage in self._storages: + await storage.initialize() - root_logger = logging.getLogger() - root_logger.setLevel(resolved_level) + async def store(self, document: Document, *, collection_name: str | None = None) -> str: + # Store in primary backend first + primary_id: str = await self._primary.store(document, collection_name=collection_name) - # Remove existing stream handlers to prevent console flicker inside the TUI - for handler in list(root_logger.handlers): - root_logger.removeHandler(handler) + # Replicate to secondary backends concurrently + if len(self._storages) > 1: - queue_handler = QueueHandler(log_queue) - queue_handler.setLevel(resolved_level) - root_logger.addHandler(queue_handler) + async def replicate_to_backend( + storage: BaseStorage, + ) -> tuple[BaseStorage, bool, Exception | None]: + try: + await storage.store(document, collection_name=collection_name) + return storage, True, None + except Exception as exc: + return storage, False, exc - log_file: Path | None = None - try: - log_dir = Path.cwd() / "logs" - log_dir.mkdir(parents=True, exist_ok=True) - log_file = log_dir / "tui.log" - file_handler = RotatingFileHandler( - log_file, - maxBytes=2_000_000, - backupCount=5, - encoding="utf-8", + tasks = [replicate_to_backend(storage) for storage in self._storages[1:]] + results = await asyncio.gather(*tasks, return_exceptions=True) + + failures: list[str] = [] + errors: list[Exception] = [] + + for result in results: + if isinstance(result, tuple): + storage, success, error = result + if not success and error is not None: + failures.append(self._format_backend_label(storage)) + errors.append(error) + elif isinstance(result, Exception): + failures.append("unknown") + errors.append(result) + + if failures: + backends = ", ".join(failures) + primary_error = errors[0] if errors else Exception("Unknown replication error") + raise StorageError( + f"Document stored in primary backend but replication failed for: {backends}" + ) from primary_error + + return primary_id + + async def store_batch( + self, documents: list[Document], *, collection_name: str | None = None + ) -> list[str]: + # Store in primary backend first + primary_ids: 
list[str] = await self._primary.store_batch( + documents, collection_name=collection_name ) - file_handler.setLevel(resolved_level) - file_handler.setFormatter(formatter) - root_logger.addHandler(file_handler) - except OSError as exc: # pragma: no cover - filesystem specific - fallback = logging.getLogger(__name__) - fallback.warning("Failed to configure file logging for TUI: %s", exc) - _logging_context = _TuiLoggingContext(log_queue, formatter, log_file) - return _logging_context + # Replicate to secondary backends concurrently + if len(self._storages) > 1: + + async def replicate_batch_to_backend( + storage: BaseStorage, + ) -> tuple[BaseStorage, bool, Exception | None]: + try: + await storage.store_batch(documents, collection_name=collection_name) + return storage, True, None + except Exception as exc: + return storage, False, exc + + tasks = [replicate_batch_to_backend(storage) for storage in self._storages[1:]] + results = await asyncio.gather(*tasks, return_exceptions=True) + + failures: list[str] = [] + errors: list[Exception] = [] + + for result in results: + if isinstance(result, tuple): + storage, success, error = result + if not success and error is not None: + failures.append(self._format_backend_label(storage)) + errors.append(error) + elif isinstance(result, Exception): + failures.append("unknown") + errors.append(result) + + if failures: + backends = ", ".join(failures) + primary_error = ( + errors[0] if errors else Exception("Unknown batch replication error") + ) + raise StorageError( + f"Batch stored in primary backend but replication failed for: {backends}" + ) from primary_error + + return primary_ids + + async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: + # Delete from primary backend first + primary_deleted: bool = await self._primary.delete( + document_id, collection_name=collection_name + ) + + # Delete from secondary backends concurrently + if len(self._storages) > 1: + + async def delete_from_backend( + storage: BaseStorage, + ) -> tuple[BaseStorage, bool, Exception | None]: + try: + await storage.delete(document_id, collection_name=collection_name) + return storage, True, None + except Exception as exc: + return storage, False, exc + + tasks = [delete_from_backend(storage) for storage in self._storages[1:]] + results = await asyncio.gather(*tasks, return_exceptions=True) + + failures: list[str] = [] + errors: list[Exception] = [] + + for result in results: + if isinstance(result, tuple): + storage, success, error = result + if not success and error is not None: + failures.append(self._format_backend_label(storage)) + errors.append(error) + elif isinstance(result, Exception): + failures.append("unknown") + errors.append(result) + + if failures: + backends = ", ".join(failures) + primary_error = errors[0] if errors else Exception("Unknown deletion error") + raise StorageError( + f"Document deleted from primary backend but failed for: {backends}" + ) from primary_error + + return primary_deleted + + async def count(self, *, collection_name: str | None = None) -> int: + count_result: int = await self._primary.count(collection_name=collection_name) + return count_result + + async def list_collections(self) -> list[str]: + list_fn = getattr(self._primary, "list_collections", None) + if list_fn is None: + return [] + collections_result: list[str] = await list_fn() + return collections_result + + async def search( + self, + query: str, + limit: int = 10, + threshold: float = 0.7, + *, + collection_name: str | None = None, + ) -> 
AsyncGenerator[Document, None]: + async for item in self._primary.search( + query, + limit=limit, + threshold=threshold, + collection_name=collection_name, + ): + yield item + + async def close(self) -> None: + for storage in self._storages: + close_fn = getattr(storage, "close", None) + if close_fn is not None: + await close_fn() + + def _format_backend_label(self, storage: BaseStorage) -> str: + backend = getattr(storage.config, "backend", None) + if isinstance(backend, StorageBackend): + backend_value: str = backend.value + return backend_value + class_name: str = storage.__class__.__name__ + return class_name -LOGGER: Logger = logging.getLogger(__name__) +class StorageManager: + """Centralized manager for all storage backend operations.""" + def __init__(self, settings: Settings) -> None: + """Initialize storage manager with application settings.""" + self.settings: Settings = settings + self.backends: dict[StorageBackend, BaseStorage] = {} + self.capabilities: dict[StorageBackend, StorageCapabilities] = {} + self._initialized: bool = False -async def run_textual_tui() -> None: - """Run the enhanced modern TUI with better error handling and initialization.""" - settings = get_settings() - configure_prefect(settings) + async def initialize_all_backends(self) -> dict[StorageBackend, bool]: + """Initialize all available storage backends with timeout protection.""" + results: dict[StorageBackend, bool] = {} - logging_context = _configure_tui_logging(log_level=settings.log_level) + async def init_backend( + backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage] + ) -> bool: + """Initialize a single backend with timeout.""" + try: + storage = storage_class(config) + await asyncio.wait_for(storage.initialize(), timeout=30.0) + self.backends[backend_type] = storage + if backend_type == StorageBackend.WEAVIATE: + self.capabilities[backend_type] = StorageCapabilities.VECTOR_SEARCH + elif backend_type == StorageBackend.OPEN_WEBUI: + self.capabilities[backend_type] = StorageCapabilities.KNOWLEDGE_BASE + elif backend_type == StorageBackend.R2R: + self.capabilities[backend_type] = StorageCapabilities.FULL_FEATURED + return True + except (TimeoutError, Exception): + return False - LOGGER.info("Initializing collection management TUI") - LOGGER.info("Scanning available storage backends") + # Initialize backends concurrently with timeout protection + tasks: list[tuple[StorageBackend, Coroutine[None, None, bool]]] = [] - # Initialize storage manager - storage_manager = StorageManager(settings) - backend_status = await storage_manager.initialize_all_backends() - - # Report initialization results - for backend, success in backend_status.items(): - if success: - LOGGER.info("%s connected successfully", backend.value) + # Try Weaviate + if self.settings.weaviate_endpoint: + config = StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint=self.settings.weaviate_endpoint, + api_key=SecretStr(self.settings.weaviate_api_key) + if self.settings.weaviate_api_key + else None, + collection_name="default", + ) + tasks.append( + ( + StorageBackend.WEAVIATE, + init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage), + ) + ) else: - LOGGER.warning("%s connection failed", backend.value) + results[StorageBackend.WEAVIATE] = False - available_backends = storage_manager.get_available_backends() - if not available_backends: - LOGGER.error("Could not connect to any storage backend") - LOGGER.info("Please check your configuration and try again") - LOGGER.info("Supported backends: 
Weaviate, OpenWebUI, R2R") - return + # Try OpenWebUI + if self.settings.openwebui_endpoint and self.settings.openwebui_api_key: + config = StorageConfig( + backend=StorageBackend.OPEN_WEBUI, + endpoint=self.settings.openwebui_endpoint, + api_key=SecretStr(self.settings.openwebui_api_key) + if self.settings.openwebui_api_key + else None, + collection_name="default", + ) + tasks.append( + ( + StorageBackend.OPEN_WEBUI, + init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage), + ) + ) + else: + results[StorageBackend.OPEN_WEBUI] = False - LOGGER.info( - "Launching TUI with %d backend(s): %s", - len(available_backends), - ", ".join(backend.value for backend in available_backends), - ) + # Try R2R + if self.settings.r2r_endpoint: + config = StorageConfig( + backend=StorageBackend.R2R, + endpoint=self.settings.r2r_endpoint, + api_key=SecretStr(self.settings.r2r_api_key) if self.settings.r2r_api_key else None, + collection_name="default", + ) + tasks.append((StorageBackend.R2R, init_backend(StorageBackend.R2R, config, R2RStorage))) + else: + results[StorageBackend.R2R] = False - # Get individual storage instances for backward compatibility - from ....storage.openwebui import OpenWebUIStorage - from ....storage.weaviate import WeaviateStorage + # Execute initialization tasks concurrently + if tasks: + backend_types, task_coroutines = zip(*tasks, strict=False) + task_results: Sequence[bool | BaseException] = await asyncio.gather( + *task_coroutines, return_exceptions=True + ) - weaviate_backend = storage_manager.get_backend(StorageBackend.WEAVIATE) - openwebui_backend = storage_manager.get_backend(StorageBackend.OPEN_WEBUI) - r2r_backend = storage_manager.get_backend(StorageBackend.R2R) + for backend_type, task_result in zip(backend_types, task_results, strict=False): + results[backend_type] = task_result if isinstance(task_result, bool) else False + self._initialized = True + return results - # Type-safe casting to specific storage types - weaviate = weaviate_backend if isinstance(weaviate_backend, WeaviateStorage) else None - openwebui = openwebui_backend if isinstance(openwebui_backend, OpenWebUIStorage) else None + def get_backend(self, backend_type: StorageBackend) -> BaseStorage | None: + """Get storage backend by type.""" + return self.backends.get(backend_type) - # Import here to avoid circular import - from ..app import CollectionManagementApp - app = CollectionManagementApp( - storage_manager, - weaviate, - openwebui, - r2r_backend, - log_queue=logging_context.queue, - log_formatter=logging_context.formatter, - log_file=logging_context.log_file, - ) - try: - await app.run_async() - finally: - LOGGER.info("Shutting down storage connections") - await storage_manager.close_all() - LOGGER.info("All storage connections closed gracefully") + def build_multi_storage_adapter( + self, backends: Sequence[StorageBackend] + ) -> MultiStorageAdapter: + storages: list[BaseStorage] = [] + seen: set[StorageBackend] = set() + for backend in backends: + backend_enum = ( + backend if isinstance(backend, StorageBackend) else StorageBackend(backend) + ) + if backend_enum in seen: + continue + seen.add(backend_enum) + storage = self.backends.get(backend_enum) + if storage is None: + raise ValueError(f"Storage backend {backend_enum.value} is not initialized") + storages.append(storage) + return MultiStorageAdapter(storages) + def get_available_backends(self) -> list[StorageBackend]: + """Get list of successfully initialized backends.""" + return list(self.backends.keys()) -def dashboard() -> None: - 
"""Launch the modern collection dashboard.""" - asyncio.run(run_textual_tui()) + def has_capability(self, backend: StorageBackend, capability: StorageCapabilities) -> bool: + """Check if backend has specific capability.""" + backend_caps = self.capabilities.get(backend, StorageCapabilities.BASIC) + return capability.value <= backend_caps.value + + async def get_all_collections(self) -> list[CollectionInfo]: + """Get collections from all available backends, merging collections with same name.""" + collection_map: dict[str, CollectionInfo] = {} + + for backend_type, storage in self.backends.items(): + try: + backend_collections = await storage.list_collections() + for collection_name in backend_collections: + # Validate collection name + if not collection_name or not isinstance(collection_name, str): + continue + + try: + count = await storage.count(collection_name=collection_name) + # Validate count is non-negative + count = max(count, 0) + except StorageError as e: + # Storage-specific errors - log and use 0 count + import logging + + logging.warning( + f"Failed to get count for {collection_name} on {backend_type.value}: {e}" + ) + count = 0 + except Exception as e: + # Unexpected errors - log and skip this collection from this backend + import logging + + logging.warning( + f"Unexpected error counting {collection_name} on {backend_type.value}: {e}" + ) + continue + + size_mb = count * 0.01 # Rough estimate: 10KB per document + + if collection_name in collection_map: + # Merge with existing collection + existing = collection_map[collection_name] + existing_backends = existing["backend"] + backend_value = backend_type.value + + if isinstance(existing_backends, str): + existing["backend"] = [existing_backends, backend_value] + elif isinstance(existing_backends, list): + # Prevent duplicates + if backend_value not in existing_backends: + existing_backends.append(backend_value) + + # Aggregate counts and sizes + existing["count"] += count + existing["size_mb"] += size_mb + else: + # Create new collection entry + collection_info: CollectionInfo = { + "name": collection_name, + "type": self._get_collection_type(collection_name, backend_type), + "count": count, + "backend": backend_type.value, + "status": "active", + "last_updated": "2024-01-01T00:00:00Z", + "size_mb": size_mb, + } + collection_map[collection_name] = collection_info + except Exception: + continue + + return list(collection_map.values()) + + def _get_collection_type(self, collection_name: str, backend: StorageBackend) -> str: + """Determine collection type based on name and backend.""" + # Prioritize definitive backend type first + if backend == StorageBackend.R2R: + return "r2r" + elif backend == StorageBackend.WEAVIATE: + return "weaviate" + elif backend == StorageBackend.OPEN_WEBUI: + return "openwebui" + + # Fallback to name-based guessing if backend is not specific + name_lower = collection_name.lower() + if "web" in name_lower or "doc" in name_lower: + return "documentation" + elif "repo" in name_lower or "code" in name_lower: + return "repository" + else: + return "general" + + async def search_across_backends( + self, + query: str, + limit: int = 10, + backends: list[StorageBackend] | None = None, + ) -> dict[StorageBackend, list[Document]]: + """Search across multiple backends and return grouped results.""" + if backends is None: + backends = self.get_available_backends() + + results: dict[StorageBackend, list[Document]] = {} + + async def search_backend(backend_type: StorageBackend) -> None: + storage = 
self.backends.get(backend_type) + if storage: + try: + documents: list[Document] = [] + async for doc in storage.search(query, limit=limit): + documents.append(doc) + results[backend_type] = documents + except Exception: + results[backend_type] = [] + + # Run searches in parallel + tasks = [search_backend(backend) for backend in backends] + await asyncio.gather(*tasks, return_exceptions=True) + + return results + + def get_r2r_storage(self) -> R2RStorage | None: + """Get R2R storage instance if available.""" + storage = self.backends.get(StorageBackend.R2R) + return storage if isinstance(storage, R2RStorage) else None + + async def get_backend_status( + self, + ) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]: + """Get comprehensive status for all backends.""" + status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {} + + for backend_type, storage in self.backends.items(): + try: + collections = await storage.list_collections() + total_docs = 0 + for collection in collections: + total_docs += await storage.count(collection_name=collection) + + backend_status: dict[str, str | int | bool | StorageCapabilities] = { + "available": True, + "collections": len(collections), + "total_documents": total_docs, + "capabilities": self.capabilities.get(backend_type, StorageCapabilities.BASIC), + "endpoint": getattr(storage.config, "endpoint", "unknown"), + } + status[backend_type] = backend_status + except Exception as e: + status[backend_type] = { + "available": False, + "error": str(e), + "capabilities": StorageCapabilities.NONE, + } + + return status + + async def close_all(self) -> None: + """Close all storage connections.""" + for storage in self.backends.values(): + try: + await storage.close() + except Exception: + pass + + self.backends.clear() + self.capabilities.clear() + self._initialized = False + + @property + def is_initialized(self) -> bool: + """Check if storage manager is initialized.""" + return self._initialized + + def supports_advanced_features(self, backend: StorageBackend) -> bool: + """Check if backend supports advanced features like chunks and entities.""" + return self.has_capability(backend, StorageCapabilities.FULL_FEATURED) @@ -3542,9 +3332,7 @@ class ScrapeOptionsForm(Widget): formats.append("screenshot") options: dict[str, object] = { "formats": formats, - "only_main_content": self.query_one( - "#only_main_content", Switch - ).value, + "only_main_content": self.query_one("#only_main_content", Switch).value, } include_tags_input = self.query_one("#include_tags", Input).value if include_tags_input.strip(): @@ -3579,11 +3367,15 @@ class ScrapeOptionsForm(Widget): if include_tags := options.get("include_tags", []): include_list = include_tags if isinstance(include_tags, list) else [] - self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list) + self.query_one("#include_tags", Input).value = ", ".join( + str(tag) for tag in include_list + ) if exclude_tags := options.get("exclude_tags", []): exclude_list = exclude_tags if isinstance(exclude_tags, list) else [] - self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list) + self.query_one("#exclude_tags", Input).value = ", ".join( + str(tag) for tag in exclude_list + ) # Set performance wait_for = options.get("wait_for") @@ -3887,7 +3679,9 @@ class ExtractOptionsForm(Widget): "content": "string", "tags": ["string"] }""" - prompt_widget.text = "Extract article title, author, publication date, main content, and 
associated tags" + prompt_widget.text = ( + "Extract article title, author, publication date, main content, and associated tags" + ) elif event.button.id == "preset_product": schema_widget.text = """{ @@ -3897,7 +3691,9 @@ class ExtractOptionsForm(Widget): "category": "string", "availability": "string" }""" - prompt_widget.text = "Extract product name, price, description, category, and availability status" + prompt_widget.text = ( + "Extract product name, price, description, category, and availability status" + ) elif event.button.id == "preset_contact": schema_widget.text = """{ @@ -3907,7 +3703,9 @@ class ExtractOptionsForm(Widget): "company": "string", "position": "string" }""" - prompt_widget.text = "Extract contact information including name, email, phone, company, and position" + prompt_widget.text = ( + "Extract contact information including name, email, phone, company, and position" + ) elif event.button.id == "preset_data": schema_widget.text = """{ @@ -4160,10 +3958,17 @@ class ChunkViewer(Widget): "id": str(chunk_data.get("id", "")), "document_id": self.document_id, "content": str(chunk_data.get("text", "")), - "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)), - "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)), + "start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)( + chunk_data.get("start_index", 0) + ), + "end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)( + chunk_data.get("end_index", 0) + ), "metadata": ( - dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {} + dict(metadata_val) + if (metadata_val := chunk_data.get("metadata")) + and isinstance(metadata_val, dict) + else {} ), } self.chunks.append(chunk_info) @@ -4336,10 +4141,10 @@ class EntityGraph(Widget): """Show detailed information about an entity.""" details_widget = self.query_one("#entity_details", Static) - details_text = f"""**Entity:** {entity['name']} -**Type:** {entity['type']} -**Confidence:** {entity['confidence']:.2%} -**ID:** {entity['id']} + details_text = f"""**Entity:** {entity["name"]} +**Type:** {entity["type"]} +**Confidence:** {entity["confidence"]:.2%} +**ID:** {entity["id"]} **Metadata:** """ @@ -4610,7 +4415,6 @@ else: # pragma: no cover - optional dependency fallback R2RStorage = BaseStorage - class CollectionManagementApp(App[None]): """Enhanced modern Textual application with comprehensive keyboard navigation.""" @@ -4953,7 +4757,9 @@ class ResponsiveGrid(Container): markup: bool = True, ) -> None: """Initialize responsive grid.""" - super().__init__(*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup) + super().__init__( + *children, name=name, id=id, classes=classes, disabled=disabled, markup=markup + ) self._columns: int = columns self._auto_fit: bool = auto_fit self._compact: bool = compact @@ -5223,7 +5029,9 @@ class CardLayout(ResponsiveGrid): ) -> None: """Initialize card layout with default settings for cards.""" # Default to auto-fit cards with minimum width - super().__init__(auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup) + super().__init__( + auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup + ) class SplitPane(Container): @@ -5292,7 +5100,9 @@ class SplitPane(Container): if self._vertical: cast(Widget, self).add_class("vertical") - pane_classes = ("top-pane", "bottom-pane") if 
self._vertical else ("left-pane", "right-pane") + pane_classes = ( + ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane") + ) yield Container(self._left_content, classes=pane_classes[0]) yield Static("", classes="splitter") @@ -5309,12 +5119,1457 @@ class SplitPane(Container): widget.query_one(".right-pane").styles.width = f"{(1 - self._split_ratio) * 100}%" + +"""CLI interface for ingestion pipeline.""" + +import asyncio +from typing import Annotated + +import typer +from pydantic import SecretStr +from rich.console import Console +from rich.panel import Panel +from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn +from rich.table import Table + +from ..config import configure_prefect, get_settings +from ..core.models import ( + IngestionResult, + IngestionSource, + StorageBackend, + StorageConfig, +) +from ..flows.ingestion import create_ingestion_flow +from ..flows.scheduler import create_scheduled_deployment, serve_deployments + +app = typer.Typer( + name="ingest", + help="🚀 Modern Document Ingestion Pipeline - Advanced web and repository processing", + rich_markup_mode="rich", + add_completion=False, +) +console = Console() + + +@app.callback() +def main( + version: Annotated[ + bool, typer.Option("--version", "-v", help="Show version information") + ] = False, +) -> None: + """ + 🚀 Modern Document Ingestion Pipeline + + [bold cyan]Advanced document processing and management platform[/bold cyan] + + Features: + • 🌐 Web scraping and crawling with Firecrawl + • 📦 Repository ingestion with Repomix + • 🗄️ Multiple storage backends (Weaviate, OpenWebUI, R2R) + • 📊 Modern TUI for collection management + • ⚡ Async processing with Prefect orchestration + • 🎨 Rich CLI with enhanced visuals + """ + settings = get_settings() + configure_prefect(settings) + + if version: + console.print( + Panel( + ( + "[bold magenta]Ingest Pipeline v0.1.0[/bold magenta]\n" + "[dim]Modern Document Ingestion & Management System[/dim]" + ), + title="🚀 Version Info", + border_style="magenta", + ) + ) + raise typer.Exit() + + +@app.command() +def ingest( + source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")], + source_type: Annotated[ + IngestionSource, typer.Option("--type", "-t", help="Type of source") + ] = IngestionSource.WEB, + storage: Annotated[ + StorageBackend, typer.Option("--storage", "-s", help="Storage backend") + ] = StorageBackend.WEAVIATE, + collection: Annotated[ + str | None, + typer.Option( + "--collection", "-c", help="Target collection name (auto-generated if not specified)" + ), + ] = None, + validate: Annotated[ + bool, typer.Option("--validate/--no-validate", help="Validate source before ingesting") + ] = True, +) -> None: + """ + 🚀 Run a one-time ingestion job with enhanced progress tracking. + + This command processes documents from various sources and stores them in + your chosen backend with full progress visualization. 
+ """ + # Enhanced startup message + console.print( + Panel( + ( + f"[bold cyan]🚀 Starting Modern Ingestion[/bold cyan]\n\n" + f"[yellow]Source:[/yellow] {source_url}\n" + f"[yellow]Type:[/yellow] {source_type.value.title()}\n" + f"[yellow]Storage:[/yellow] {storage.value.replace('_', ' ').title()}\n" + f"[yellow]Collection:[/yellow] {collection or '[dim]Auto-generated[/dim]'}" + ), + title="🔥 Ingestion Configuration", + border_style="cyan", + ) + ) + + async def run_with_progress() -> IngestionResult: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + ) as progress: + task = progress.add_task("🔄 Processing documents...", total=100) + + # Simulate progress updates during ingestion + progress.update(task, advance=20, description="🔗 Connecting to services...") + await asyncio.sleep(0.5) + + progress.update(task, advance=30, description="📄 Fetching documents...") + result = await run_ingestion( + url=source_url, + source_type=source_type, + storage_backend=storage, + collection_name=collection, + validate_first=validate, + ) + + progress.update(task, advance=50, description="✅ Ingestion complete!") + return result + + # Use asyncio.run() with proper event loop handling + try: + result = asyncio.run(run_with_progress()) + except RuntimeError as e: + if "asyncio.run() cannot be called from a running event loop" in str(e): + # If we're already in an event loop (e.g., in Jupyter), use nest_asyncio + try: + import nest_asyncio + + nest_asyncio.apply() + result = asyncio.run(run_with_progress()) + except ImportError: + # Fallback: get the current loop and run the coroutine + loop = asyncio.get_event_loop() + result = loop.run_until_complete(run_with_progress()) + else: + raise + + # Enhanced results display + status_color = "green" if result.status.value == "completed" else "red" + + # Create results table with enhanced styling + table = Table( + title="📊 Ingestion Results", + title_style="bold magenta", + border_style="cyan", + header_style="bold blue", + ) + table.add_column("📋 Metric", style="cyan", no_wrap=True) + table.add_column("📈 Value", style=status_color, justify="right") + + # Add enhanced status icon + status_icon = "✅" if result.status.value == "completed" else "❌" + table.add_row("Status", f"{status_icon} {result.status.value.title()}") + + table.add_row("Documents Processed", f"📄 {result.documents_processed:,}") + table.add_row("Documents Failed", f"⚠️ {result.documents_failed:,}") + table.add_row("Duration", f"⏱️ {result.duration_seconds:.2f}s") + + if result.error_messages: + error_text = "\n".join(f"❌ {error}" for error in result.error_messages[:3]) + if len(result.error_messages) > 3: + error_text += f"\n... 
and {len(result.error_messages) - 3} more errors" + table.add_row("Errors", error_text) + + console.print(table) + + # Success celebration or error guidance + if result.status.value == "completed" and result.documents_processed > 0: + console.print( + Panel( + ( + f"🎉 [bold green]Success![/bold green] {result.documents_processed} documents ingested\n\n" + f"💡 [dim]Try '[bold cyan]ingest modern[/bold cyan]' to explore your collections![/dim]" + ), + title="✨ Ingestion Complete", + border_style="green", + ) + ) + elif result.error_messages: + console.print( + Panel( + ( + "❌ [bold red]Ingestion encountered errors[/bold red]\n\n" + "💡 [dim]Check your configuration and try again[/dim]" + ), + title="⚠️ Issues Detected", + border_style="red", + ) + ) + + +@app.command() +def schedule( + name: Annotated[str, typer.Argument(help="Deployment name")], + source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")], + source_type: Annotated[ + IngestionSource, typer.Option("--type", "-t", help="Type of source") + ] = IngestionSource.WEB, + storage: Annotated[ + StorageBackend, typer.Option("--storage", "-s", help="Storage backend") + ] = StorageBackend.WEAVIATE, + cron: Annotated[ + str | None, typer.Option("--cron", "-c", help="Cron expression for scheduling") + ] = None, + interval: Annotated[int, typer.Option("--interval", "-i", help="Interval in minutes")] = 60, + serve_now: Annotated[ + bool, typer.Option("--serve/--no-serve", help="Start serving immediately") + ] = False, +) -> None: + """ + Create a scheduled deployment for recurring ingestion. + """ + console.print(f"[bold blue]Creating deployment: {name}[/bold blue]") + + deployment = create_scheduled_deployment( + name=name, + source_url=source_url, + source_type=source_type, + storage_backend=storage, + schedule_type="cron" if cron else "interval", + cron_expression=cron, + interval_minutes=interval, + ) + + console.print(f"[green]✓ Deployment '{name}' created[/green]") + + if serve_now: + console.print("[yellow]Starting deployment server...[/yellow]") + serve_deployments([deployment]) + + +@app.command() +def serve( + config_file: Annotated[ + str | None, typer.Option("--config", "-c", help="Path to deployments config file") + ] = None, + ui: Annotated[ + str | None, typer.Option("--ui", help="Launch user interface (options: tui, web)") + ] = None, +) -> None: + """ + 🚀 Serve configured deployments with optional UI interface. + + Launch the deployment server to run scheduled ingestion jobs, + optionally with a modern Terminal User Interface (TUI) or web interface. + """ + # Handle UI mode first + if ui == "tui": + console.print( + Panel( + ( + "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n" + "[yellow]Features:[/yellow]\n" + "• 📊 Interactive collection management\n" + "• ⌨️ Enhanced keyboard navigation\n" + "• 🎨 Modern design with focus indicators\n" + "• 📄 Document browsing and search\n" + "• 🔄 Real-time status updates" + ), + title="🎉 TUI Mode", + border_style="cyan", + ) + ) + from .tui import dashboard + + dashboard() + return + elif ui == "web": + console.print("[red]Web UI not yet implemented. 
Use --ui tui for Terminal UI.[/red]") + return + elif ui: + console.print(f"[red]Unknown UI option: {ui}[/red]") + console.print("[yellow]Available options: tui, web[/yellow]") + return + + # Normal deployment server mode + if config_file: + # Load deployments from config + console.print(f"[yellow]Loading deployments from {config_file}[/yellow]") + # Implementation would load YAML/JSON config + else: + # Create example deployments + deployments = [ + create_scheduled_deployment( + name="docs-daily", + source_url="https://docs.example.com", + source_type="documentation", + storage_backend="weaviate", + schedule_type="cron", + cron_expression="0 2 * * *", # Daily at 2 AM + ), + create_scheduled_deployment( + name="repo-hourly", + source_url="https://github.com/example/repo", + source_type="repository", + storage_backend="open_webui", + schedule_type="interval", + interval_minutes=60, + ), + ] + + console.print( + "[bold green]Starting deployment server with example deployments[/bold green]" + ) + serve_deployments(deployments) + + +@app.command() +def tui() -> None: + """ + 🚀 Launch the enhanced Terminal User Interface. + + Quick shortcut for 'serve --ui tui' with modern keyboard navigation, + interactive collection management, and real-time status updates. + """ + console.print( + Panel( + ( + "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n" + "[yellow]Features:[/yellow]\n" + "• 📊 Interactive collection management\n" + "• ⌨️ Enhanced keyboard navigation\n" + "• 🎨 Modern design with focus indicators\n" + "• 📄 Document browsing and search\n" + "• 🔄 Real-time status updates" + ), + title="🎉 TUI Mode", + border_style="cyan", + ) + ) + from .tui import dashboard + + dashboard() + + +@app.command() +def config() -> None: + """ + 📋 Display current configuration with enhanced formatting. + + Shows all configured endpoints, models, and settings in a beautiful + table format with status indicators. 
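+
+    A "❌ Missing" status means the corresponding API key has not been set in your
+    .env file. Example:
+
+        ingest config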
+ """ + settings = get_settings() + + console.print( + Panel( + ( + "[bold cyan]⚙️ System Configuration[/bold cyan]\n" + "[dim]Current pipeline settings and endpoints[/dim]" + ), + title="🔧 Configuration", + border_style="cyan", + ) + ) + + # Enhanced configuration table + table = Table( + title="📊 Configuration Details", + title_style="bold magenta", + border_style="blue", + header_style="bold cyan", + show_lines=True, + ) + table.add_column("🏷️ Setting", style="cyan", no_wrap=True, width=25) + table.add_column("🎯 Value", style="yellow", overflow="fold") + table.add_column("📊 Status", style="green", width=12, justify="center") + + # Add configuration rows with status indicators + def get_status_indicator(value: str | None) -> str: + return "✅ Set" if value else "❌ Missing" + + table.add_row("🤖 LLM Endpoint", str(settings.llm_endpoint), "✅ Active") + table.add_row("🔥 Firecrawl Endpoint", str(settings.firecrawl_endpoint), "✅ Active") + table.add_row( + "🗄️ Weaviate Endpoint", + str(settings.weaviate_endpoint), + get_status_indicator(str(settings.weaviate_api_key) if settings.weaviate_api_key else None), + ) + table.add_row( + "🌐 OpenWebUI Endpoint", + str(settings.openwebui_endpoint), + get_status_indicator(settings.openwebui_api_key), + ) + table.add_row("🧠 Embedding Model", settings.embedding_model, "✅ Set") + table.add_row("💾 Default Storage", settings.default_storage_backend.title(), "✅ Set") + table.add_row("📦 Default Batch Size", f"{settings.default_batch_size:,}", "✅ Set") + table.add_row("⚡ Max Concurrent Tasks", f"{settings.max_concurrent_tasks}", "✅ Set") + + console.print(table) + + # Additional helpful information + console.print( + Panel( + ( + "💡 [bold cyan]Quick Tips[/bold cyan]\n\n" + "• Use '[bold]ingest list-collections[/bold]' to view all collections\n" + "• Use '[bold]ingest search[/bold]' to search content\n" + "• Configure API keys in your [yellow].env[/yellow] file\n" + "• Default collection names are auto-generated from URLs" + ), + title="🚀 Usage Tips", + border_style="green", + ) + ) + + +@app.command() +def list_collections() -> None: + """ + 📋 List all collections across storage backends. + """ + console.print("[bold cyan]📚 Collection Overview[/bold cyan]") + asyncio.run(run_list_collections()) + + +@app.command() +def search( + query: Annotated[str, typer.Argument(help="Search query")], + collection: Annotated[ + str | None, typer.Option("--collection", "-c", help="Target collection") + ] = None, + backend: Annotated[ + StorageBackend, typer.Option("--backend", "-b", help="Storage backend") + ] = StorageBackend.WEAVIATE, + limit: Annotated[int, typer.Option("--limit", "-l", help="Result limit")] = 10, +) -> None: + """ + 🔍 Search across collections. 
+ """ + console.print(f"[bold cyan]🔍 Searching for: {query}[/bold cyan]") + asyncio.run(run_search(query, collection, backend.value, limit)) + + +@app.command(name="blocks") +def blocks_command() -> None: + """🧩 List and manage Prefect Blocks.""" + console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]") + console.print( + "Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks" + ) + console.print("Use 'prefect block ls' to list available blocks") + + +@app.command(name="variables") +def variables_command() -> None: + """📊 Manage Prefect Variables.""" + console.print("[bold cyan]📊 Prefect Variables Management[/bold cyan]") + console.print("Use 'prefect variable set VARIABLE_NAME value' to set variables") + console.print("Use 'prefect variable ls' to list variables") + + +async def run_ingestion( + url: str, + source_type: IngestionSource, + storage_backend: StorageBackend, + collection_name: str | None = None, + validate_first: bool = True, +) -> IngestionResult: + """ + Run ingestion with support for targeted collections. + """ + # Auto-generate collection name if not provided + if not collection_name: + from urllib.parse import urlparse + + parsed = urlparse(url) + domain = parsed.netloc.replace(".", "_").replace("-", "_") + collection_name = f"{domain}_{source_type.value}" + + return await create_ingestion_flow( + source_url=url, + source_type=source_type, + storage_backend=storage_backend, + collection_name=collection_name, + validate_first=validate_first, + ) + + +async def run_list_collections() -> None: + """ + List collections across storage backends. + """ + from ..config import get_settings + from ..core.models import StorageBackend + from ..storage.openwebui import OpenWebUIStorage + from ..storage.weaviate import WeaviateStorage + + settings = get_settings() + + console.print("🔍 [bold cyan]Scanning storage backends...[/bold cyan]") + + # Try to connect to Weaviate + weaviate_collections: list[tuple[str, int]] = [] + try: + weaviate_config = StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint=settings.weaviate_endpoint, + api_key=SecretStr(settings.weaviate_api_key) + if settings.weaviate_api_key is not None + else None, + collection_name="default", + ) + weaviate = WeaviateStorage(weaviate_config) + await weaviate.initialize() + + overview = await weaviate.describe_collections() + for item in overview: + name = str(item.get("name", "Unknown")) + count_val = item.get("count", 0) + count = int(count_val) if isinstance(count_val, (int, str)) else 0 + weaviate_collections.append((name, count)) + except Exception as e: + console.print(f"❌ [red]Weaviate connection failed: {e}[/red]") + + # Try to connect to OpenWebUI + openwebui_collections: list[tuple[str, int]] = [] + try: + openwebui_config = StorageConfig( + backend=StorageBackend.OPEN_WEBUI, + endpoint=settings.openwebui_endpoint, + api_key=SecretStr(settings.openwebui_api_key) + if settings.openwebui_api_key is not None + else None, + collection_name="default", + ) + openwebui = OpenWebUIStorage(openwebui_config) + await openwebui.initialize() + + overview = await openwebui.describe_collections() + for item in overview: + name = str(item.get("name", "Unknown")) + count_val = item.get("count", 0) + count = int(count_val) if isinstance(count_val, (int, str)) else 0 + openwebui_collections.append((name, count)) + except Exception as e: + console.print(f"❌ [red]OpenWebUI connection failed: {e}[/red]") + + # Display results + if weaviate_collections or openwebui_collections: 
+ # Create results table + from rich.table import Table + + table = Table( + title="📚 Collection Overview", + title_style="bold magenta", + border_style="cyan", + header_style="bold blue", + ) + table.add_column("🏷️ Collection", style="cyan", no_wrap=True) + table.add_column("📊 Backend", style="yellow") + table.add_column("📄 Documents", style="green", justify="right") + + # Add Weaviate collections + for name, count in weaviate_collections: + table.add_row(name, "🗄️ Weaviate", f"{count:,}") + + # Add OpenWebUI collections + for name, count in openwebui_collections: + table.add_row(name, "🌐 OpenWebUI", f"{count:,}") + + console.print(table) + else: + console.print("❌ [yellow]No collections found in any backend[/yellow]") + + +async def run_search(query: str, collection: str | None, backend: str, limit: int) -> None: + """ + Search across collections. + """ + from ..config import get_settings + from ..core.models import StorageBackend + from ..storage.weaviate import WeaviateStorage + + settings = get_settings() + + console.print(f"🔍 Searching for: '[bold cyan]{query}[/bold cyan]'") + if collection: + console.print(f"📚 Target collection: [yellow]{collection}[/yellow]") + console.print(f"💾 Backend: [blue]{backend}[/blue]") + + results = [] + + try: + if backend == "weaviate": + weaviate_config = StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint=settings.weaviate_endpoint, + api_key=SecretStr(settings.weaviate_api_key) + if settings.weaviate_api_key is not None + else None, + collection_name=collection or "default", + ) + weaviate = WeaviateStorage(weaviate_config) + await weaviate.initialize() + + results_generator = weaviate.search(query, limit=limit) + async for doc in results_generator: + results.append( + { + "title": getattr(doc, "title", "Untitled"), + "content": getattr(doc, "content", ""), + "score": getattr(doc, "score", 0.0), + "backend": "🗄️ Weaviate", + } + ) + + elif backend == "open_webui": + console.print("❌ [red]OpenWebUI search not yet implemented[/red]") + return + + except Exception as e: + console.print(f"❌ [red]Search failed: {e}[/red]") + return + + # Display results + if results: + from rich.table import Table + + table = Table( + title=f"🔍 Search Results for '{query}'", + title_style="bold magenta", + border_style="green", + header_style="bold blue", + ) + table.add_column("📄 Title", style="cyan", max_width=40) + table.add_column("📝 Preview", style="white", max_width=60) + table.add_column("📊 Score", style="yellow", justify="right") + + for result in results[:limit]: + title = str(result["title"]) + title_display = title[:40] + "..." if len(title) > 40 else title + + content = str(result["content"]) + content_display = content[:60] + "..." 
if len(content) > 60 else content + + score = f"{result['score']:.3f}" + + table.add_row(title_display, content_display, score) + + console.print(table) + console.print(f"\n✅ [green]Found {len(results)} results[/green]") + else: + console.print("❌ [yellow]No results found[/yellow]") + + +if __name__ == "__main__": + app() + + + +"""Configuration management utilities.""" + +from __future__ import annotations + +from contextlib import ExitStack + +from prefect.settings import Setting, temporary_settings + + +# Import Prefect settings with version compatibility - avoid static analysis issues +def _setup_prefect_settings() -> tuple[object, object, object]: + """Setup Prefect settings with proper fallbacks.""" + try: + import prefect.settings as ps + + # Try to get the settings directly + api_key = getattr(ps, "PREFECT_API_KEY", None) + api_url = getattr(ps, "PREFECT_API_URL", None) + work_pool = getattr(ps, "PREFECT_DEFAULT_WORK_POOL_NAME", None) + + if api_key is not None: + return api_key, api_url, work_pool + + # Fallback to registry-based approach + registry = getattr(ps, "PREFECT_SETTING_REGISTRY", None) + if registry is not None: + Setting = getattr(ps, "Setting", None) + if Setting is not None: + api_key = registry.get("PREFECT_API_KEY") or Setting( + "PREFECT_API_KEY", type_=str, default=None + ) + api_url = registry.get("PREFECT_API_URL") or Setting( + "PREFECT_API_URL", type_=str, default=None + ) + work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting( + "PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None + ) + return api_key, api_url, work_pool + + except ImportError: + pass + + # Ultimate fallback + return None, None, None + + +PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings() + +# Import after Prefect settings setup to avoid circular dependencies +from .settings import Settings, get_settings # noqa: E402 + +__all__ = ["Settings", "get_settings", "configure_prefect"] + +_prefect_settings_stack: ExitStack | None = None + + +def configure_prefect(settings: Settings) -> None: + """Apply Prefect settings from the application configuration.""" + global _prefect_settings_stack + + overrides: dict[Setting, str] = {} + + if ( + settings.prefect_api_url is not None + and PREFECT_API_URL is not None + and isinstance(PREFECT_API_URL, Setting) + ): + overrides[PREFECT_API_URL] = str(settings.prefect_api_url) + if ( + settings.prefect_api_key + and PREFECT_API_KEY is not None + and isinstance(PREFECT_API_KEY, Setting) + ): + overrides[PREFECT_API_KEY] = settings.prefect_api_key + if ( + settings.prefect_work_pool + and PREFECT_DEFAULT_WORK_POOL_NAME is not None + and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting) + ): + overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool + + if not overrides: + return + + filtered_overrides = { + setting: value for setting, value in overrides.items() if setting.value() != value + } + + if not filtered_overrides: + return + + new_stack = ExitStack() + new_stack.enter_context(temporary_settings(updates=filtered_overrides)) + + if _prefect_settings_stack is not None: + _prefect_settings_stack.close() + + _prefect_settings_stack = new_stack + + + +"""Base ingestor interface.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING + +from ..core.models import Document, IngestionJob + +if TYPE_CHECKING: + from ..storage.base import BaseStorage + + +class BaseIngestor(ABC): + 
"""Abstract base class for all ingestors.""" + + @abstractmethod + def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]: + """ + Ingest data from a source. + + Args: + job: The ingestion job configuration + + Yields: + Documents from the source + """ + ... # pragma: no cover + + @abstractmethod + async def validate_source(self, source_url: str) -> bool: + """ + Validate if the source is accessible. + + Args: + source_url: URL or path to the source + + Returns: + True if source is valid and accessible + """ + pass # pragma: no cover + + @abstractmethod + async def estimate_size(self, source_url: str) -> int: + """ + Estimate the number of documents in the source. + + Args: + source_url: URL or path to the source + + Returns: + Estimated number of documents + """ + pass # pragma: no cover + + async def ingest_with_dedup( + self, + job: IngestionJob, + storage_client: BaseStorage, + *, + collection_name: str | None = None, + stale_after_days: int = 30, + ) -> AsyncGenerator[Document, None]: + """ + Ingest documents with duplicate detection (optional optimization). + + Default implementation falls back to regular ingestion. + Subclasses can override to provide optimized deduplication. + + Args: + job: The ingestion job configuration + storage_client: Storage client to check for existing documents + collection_name: Collection to check for duplicates + stale_after_days: Consider documents stale after this many days + + Yields: + Documents from the source (with deduplication if implemented) + """ + # Default implementation: fall back to regular ingestion + async for document in self.ingest(job): + yield document + + + +"""Document management screen with enhanced navigation.""" + +from datetime import datetime + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import Container, Horizontal, ScrollableContainer +from textual.screen import ModalScreen, Screen +from textual.widgets import Button, Footer, Header, Label, LoadingIndicator, Markdown, Static +from typing_extensions import override + +from ....storage.base import BaseStorage +from ..models import CollectionInfo, DocumentInfo +from ..widgets import EnhancedDataTable + + +class DocumentManagementScreen(Screen[None]): + """Screen for managing documents within a collection with enhanced keyboard navigation.""" + + collection: CollectionInfo + storage: BaseStorage | None + documents: list[DocumentInfo] + selected_docs: set[str] + current_offset: int + page_size: int + + BINDINGS = [ + Binding("escape", "app.pop_screen", "Back"), + Binding("r", "refresh", "Refresh"), + Binding("v", "view_document", "View"), + Binding("delete", "delete_selected", "Delete Selected"), + Binding("a", "select_all", "Select All"), + Binding("ctrl+a", "select_all", "Select All"), + Binding("n", "select_none", "Clear Selection"), + Binding("ctrl+shift+a", "select_none", "Clear Selection"), + Binding("space", "toggle_selection", "Toggle Selection"), + Binding("ctrl+d", "delete_selected", "Delete Selected"), + Binding("pageup", "prev_page", "Previous Page"), + Binding("pagedown", "next_page", "Next Page"), + Binding("home", "first_page", "First Page"), + Binding("end", "last_page", "Last Page"), + ] + + def __init__(self, collection: CollectionInfo, storage: BaseStorage | None): + super().__init__() + self.collection = collection + self.storage = storage + self.documents: list[DocumentInfo] = [] + self.selected_docs: set[str] = set() + self.current_offset = 0 + self.page_size = 50 + + @override + def 
compose(self) -> ComposeResult: + yield Header() + yield Container( + Static(f"📄 Document Management: {self.collection['name']}", classes="title"), + Static( + f"Total Documents: {self.collection['count']:,} | Use Space to select, Delete to remove", + classes="subtitle", + ), + Label(f"Page size: {self.page_size} documents"), + EnhancedDataTable(id="documents_table", classes="enhanced-table"), + Horizontal( + Button("🔄 Refresh", id="refresh_docs_btn", variant="primary"), + Button("🗑️ Delete Selected", id="delete_selected_btn", variant="error"), + Button("✅ Select All", id="select_all_btn", variant="default"), + Button("❌ Clear Selection", id="clear_selection_btn", variant="default"), + Button("⬅️ Previous Page", id="prev_page_btn", variant="default"), + Button("➡️ Next Page", id="next_page_btn", variant="default"), + classes="button_bar", + ), + Label("", id="selection_status"), + Static("", id="page_info", classes="status-text"), + LoadingIndicator(id="loading"), + classes="main_container", + ) + yield Footer() + + async def on_mount(self) -> None: + """Initialize the screen.""" + self.query_one("#loading").display = False + + # Setup documents table with enhanced columns + table = self.query_one("#documents_table", EnhancedDataTable) + table.add_columns( + "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID" + ) + + # Set up message handling for table events + table.can_focus = True + + await self.load_documents() + + async def load_documents(self) -> None: + """Load documents from the collection.""" + loading = self.query_one("#loading") + loading.display = True + + try: + if self.storage: + # Try to load documents using the storage backend + try: + raw_docs = await self.storage.list_documents( + limit=self.page_size, + offset=self.current_offset, + collection_name=self.collection["name"], + ) + # Cast to proper type with type checking + self.documents = [ + DocumentInfo( + id=str(doc.get("id", f"doc_{i}")), + title=str(doc.get("title", "Untitled Document")), + source_url=str(doc.get("source_url", "")), + description=str(doc.get("description", "")), + content_type=str(doc.get("content_type", "text/plain")), + content_preview=str(doc.get("content_preview", "")), + word_count=( + lambda wc_val: int(wc_val) + if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() + else 0 + )(doc.get("word_count", 0)), + timestamp=str(doc.get("timestamp", "")), + ) + for i, doc in enumerate(raw_docs) + ] + except NotImplementedError: + # For storage backends that don't support document listing, show a message + self.notify( + f"Document listing not supported for {self.storage.__class__.__name__}", + severity="information", + ) + self.documents = [] + + await self.update_table() + self.update_selection_status() + self.update_page_info() + + except Exception as e: + self.notify(f"Error loading documents: {e}", severity="error", markup=False) + finally: + loading.display = False + + async def update_table(self) -> None: + """Update the documents table with enhanced metadata display.""" + table = self.query_one("#documents_table", EnhancedDataTable) + table.clear(columns=True) + + # Add enhanced columns with more metadata + table.add_columns( + "✓", "Title", "Source URL", "Description", "Type", "Words", "Timestamp", "ID" + ) + + # Add rows with enhanced metadata + for doc in self.documents: + selected = "✓" if doc["id"] in self.selected_docs else "" + + # Get additional metadata from the raw docs + description = str(doc.get("description") or "").strip()[:40] + if not description: 
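+                # No description metadata on this document; show a dimmed placeholder instead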
+ description = "[dim]No description[/dim]" + elif len(str(doc.get("description") or "")) > 40: + description += "..." + + # Format content type with appropriate icon + content_type = doc.get("content_type", "text/plain") + if "markdown" in content_type.lower(): + type_display = "📝 md" + elif "html" in content_type.lower(): + type_display = "🌐 html" + elif "text" in content_type.lower(): + type_display = "📄 txt" + else: + type_display = f"📄 {content_type.split('/')[-1][:5]}" + + # Format timestamp to be more readable + timestamp = doc.get("timestamp", "") + if timestamp: + try: + # Parse ISO format timestamp + dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + timestamp = dt.strftime("%m/%d %H:%M") + except Exception: + timestamp = str(timestamp)[:16] # Fallback + table.add_row( + selected, + doc.get("title", "Untitled")[:40], + doc.get("source_url", "")[:35], + description, + type_display, + str(doc.get("word_count", 0)), + timestamp, + doc["id"][:8] + "...", # Show truncated ID + ) + + def update_selection_status(self) -> None: + """Update the selection status label.""" + status_label = self.query_one("#selection_status", Label) + total_selected = len(self.selected_docs) + status_label.update(f"Selected: {total_selected} documents") + + def update_page_info(self) -> None: + """Update the page information.""" + page_info = self.query_one("#page_info", Static) + total_docs = self.collection["count"] + start = self.current_offset + 1 + end = min(self.current_offset + len(self.documents), total_docs) + page_num = (self.current_offset // self.page_size) + 1 + total_pages = (total_docs + self.page_size - 1) // self.page_size + + page_info.update( + f"Showing {start:,}-{end:,} of {total_docs:,} documents (Page {page_num} of {total_pages})" + ) + + def get_current_document(self) -> DocumentInfo | None: + """Get the currently selected document.""" + table = self.query_one("#documents_table", EnhancedDataTable) + try: + if 0 <= table.cursor_coordinate.row < len(self.documents): + return self.documents[table.cursor_coordinate.row] + except (AttributeError, IndexError): + pass + return None + + # Action methods + def action_refresh(self) -> None: + """Refresh the document list.""" + self.run_worker(self.load_documents()) + + def action_toggle_selection(self) -> None: + """Toggle selection of current row.""" + if doc := self.get_current_document(): + doc_id = doc["id"] + if doc_id in self.selected_docs: + self.selected_docs.remove(doc_id) + else: + self.selected_docs.add(doc_id) + + self.run_worker(self.update_table()) + self.update_selection_status() + + def action_select_all(self) -> None: + """Select all documents on current page.""" + for doc in self.documents: + self.selected_docs.add(doc["id"]) + self.run_worker(self.update_table()) + self.update_selection_status() + + def action_select_none(self) -> None: + """Clear all selections.""" + self.selected_docs.clear() + self.run_worker(self.update_table()) + self.update_selection_status() + + def action_delete_selected(self) -> None: + """Delete selected documents.""" + if self.selected_docs: + from .dialogs import ConfirmDocumentDeleteScreen + + self.app.push_screen( + ConfirmDocumentDeleteScreen(list(self.selected_docs), self.collection, self) + ) + else: + self.notify("No documents selected", severity="warning") + + def action_next_page(self) -> None: + """Go to next page.""" + if self.current_offset + self.page_size < self.collection["count"]: + self.current_offset += self.page_size + self.run_worker(self.load_documents()) + + def 
action_prev_page(self) -> None: + """Go to previous page.""" + if self.current_offset >= self.page_size: + self.current_offset -= self.page_size + self.run_worker(self.load_documents()) + + def action_first_page(self) -> None: + """Go to first page.""" + if self.current_offset > 0: + self.current_offset = 0 + self.run_worker(self.load_documents()) + + def action_last_page(self) -> None: + """Go to last page.""" + total_docs = self.collection["count"] + last_offset = ((total_docs - 1) // self.page_size) * self.page_size + if self.current_offset != last_offset: + self.current_offset = last_offset + self.run_worker(self.load_documents()) + + def on_button_pressed(self, event: Button.Pressed) -> None: + """Handle button presses.""" + if event.button.id == "refresh_docs_btn": + self.action_refresh() + elif event.button.id == "delete_selected_btn": + self.action_delete_selected() + elif event.button.id == "select_all_btn": + self.action_select_all() + elif event.button.id == "clear_selection_btn": + self.action_select_none() + elif event.button.id == "next_page_btn": + self.action_next_page() + elif event.button.id == "prev_page_btn": + self.action_prev_page() + + def on_enhanced_data_table_row_toggled(self, event: EnhancedDataTable.RowToggled) -> None: + """Handle row toggle from enhanced table.""" + if 0 <= event.row_index < len(self.documents): + doc = self.documents[event.row_index] + doc_id = doc["id"] + + if doc_id in self.selected_docs: + self.selected_docs.remove(doc_id) + else: + self.selected_docs.add(doc_id) + + self.run_worker(self.update_table()) + self.update_selection_status() + + def on_enhanced_data_table_select_all(self, event: EnhancedDataTable.SelectAll) -> None: + """Handle select all from enhanced table.""" + self.action_select_all() + + def on_enhanced_data_table_clear_selection( + self, event: EnhancedDataTable.ClearSelection + ) -> None: + """Handle clear selection from enhanced table.""" + self.action_select_none() + + def action_view_document(self) -> None: + """View the content of the currently selected document.""" + if doc := self.get_current_document(): + if self.storage: + self.app.push_screen( + DocumentContentModal(doc, self.storage, self.collection["name"]) + ) + else: + self.notify("No storage backend available", severity="error") + else: + self.notify("No document selected", severity="warning") + + +class DocumentContentModal(ModalScreen[None]): + """Modal screen for viewing document content.""" + + DEFAULT_CSS = """ + DocumentContentModal { + align: center middle; + } + + DocumentContentModal > Container { + width: 90%; + height: 85%; + background: $surface; + border: thick $primary; + } + + DocumentContentModal .modal-header { + background: $primary; + color: $text; + padding: 1; + dock: top; + height: 3; + } + + DocumentContentModal .modal-content { + padding: 1; + height: 1fr; + } + """ + + BINDINGS = [ + Binding("escape", "app.pop_screen", "Close"), + Binding("q", "app.pop_screen", "Close"), + ] + + def __init__(self, document: DocumentInfo, storage: BaseStorage, collection_name: str): + super().__init__() + self.document = document + self.storage = storage + self.collection_name = collection_name + + def compose(self) -> ComposeResult: + yield Container( + Static( + f"📄 Document: {self.document['title'][:60]}{'...' 
if len(self.document['title']) > 60 else ''}", + classes="modal-header", + ), + ScrollableContainer( + Markdown("Loading document content...", id="document_content"), + LoadingIndicator(id="content_loading"), + classes="modal-content", + ), + ) + + async def on_mount(self) -> None: + """Load and display the document content.""" + content_widget = self.query_one("#document_content", Markdown) + loading = self.query_one("#content_loading") + + try: + # Get full document content + doc_content = await self.storage.retrieve( + self.document["id"], collection_name=self.collection_name + ) + + # Format content for display + if isinstance(doc_content, str): + formatted_content = f"""# {self.document["title"]} + +**Source:** {self.document.get("source_url", "N/A")} +**Type:** {self.document.get("content_type", "text/plain")} +**Words:** {self.document.get("word_count", 0):,} +**Timestamp:** {self.document.get("timestamp", "N/A")} + +--- + +{doc_content} +""" + else: + formatted_content = f"""# {self.document["title"]} + +**Source:** {self.document.get("source_url", "N/A")} +**Type:** {self.document.get("content_type", "text/plain")} +**Words:** {self.document.get("word_count", 0):,} +**Timestamp:** {self.document.get("timestamp", "N/A")} + +--- + +*Content format not supported for display* +""" + + content_widget.update(formatted_content) + + except Exception as e: + content_widget.update( + f"# Error Loading Document\n\nFailed to load document content: {e}" + ) + finally: + loading.display = False + + + +"""TUI runner functions and initialization.""" + +from __future__ import annotations + +import asyncio +import logging +from logging import Logger +from logging.handlers import QueueHandler, RotatingFileHandler +from pathlib import Path +from queue import Queue +from typing import NamedTuple + +import platformdirs + +from ....config import configure_prefect, get_settings +from .storage_manager import StorageManager + + +class _TuiLoggingContext(NamedTuple): + """Container describing configured logging outputs for the TUI.""" + + queue: Queue[logging.LogRecord] + formatter: logging.Formatter + log_file: Path | None + + +_logging_context: _TuiLoggingContext | None = None + + +def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext: + """Configure logging so that messages do not break the TUI output.""" + + global _logging_context + if _logging_context is not None: + return _logging_context + + resolved_level = getattr(logging, log_level.upper(), logging.INFO) + log_queue: Queue[logging.LogRecord] = Queue() + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + root_logger = logging.getLogger() + root_logger.setLevel(resolved_level) + + # Remove existing stream handlers to prevent console flicker inside the TUI + for handler in list(root_logger.handlers): + root_logger.removeHandler(handler) + + queue_handler = QueueHandler(log_queue) + queue_handler.setLevel(resolved_level) + root_logger.addHandler(queue_handler) + + log_file: Path | None = None + try: + # Try current directory first for development + log_dir = Path.cwd() / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / "tui.log" + except OSError: + # Fall back to user log directory + try: + log_dir = Path(platformdirs.user_log_dir("ingest-pipeline", "ingest-pipeline")) + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / "tui.log" + except OSError as exc: + fallback = logging.getLogger(__name__) + 
fallback.warning("Failed to create log directory, file logging disabled: %s", exc) + log_file = None + + if log_file: + try: + file_handler = RotatingFileHandler( + log_file, + maxBytes=2_000_000, + backupCount=5, + encoding="utf-8", + ) + file_handler.setLevel(resolved_level) + file_handler.setFormatter(formatter) + root_logger.addHandler(file_handler) + except OSError as exc: + fallback = logging.getLogger(__name__) + fallback.warning("Failed to configure file logging for TUI: %s", exc) + log_file = None + + _logging_context = _TuiLoggingContext(log_queue, formatter, log_file) + return _logging_context + + +LOGGER: Logger = logging.getLogger(__name__) + + +async def run_textual_tui() -> None: + """Run the enhanced modern TUI with better error handling and initialization.""" + settings = get_settings() + configure_prefect(settings) + + logging_context = _configure_tui_logging(log_level=settings.log_level) + + LOGGER.info("Initializing collection management TUI") + LOGGER.info("Scanning available storage backends") + + # Create storage manager without initialization - let TUI handle it asynchronously + storage_manager = StorageManager(settings) + + LOGGER.info("Launching TUI - storage backends will initialize in background") + + # Import here to avoid circular import + from ..app import CollectionManagementApp + + app = CollectionManagementApp( + storage_manager, + None, # weaviate - will be available after initialization + None, # openwebui - will be available after initialization + None, # r2r_backend - will be available after initialization + log_queue=logging_context.queue, + log_formatter=logging_context.formatter, + log_file=logging_context.log_file, + ) + try: + await app.run_async() + finally: + LOGGER.info("Shutting down storage connections") + await storage_manager.close_all() + LOGGER.info("All storage connections closed gracefully") + + +def dashboard() -> None: + """Launch the modern collection dashboard.""" + asyncio.run(run_textual_tui()) + + """Comprehensive theming system for TUI applications with WCAG AA accessibility compliance.""" from dataclasses import dataclass from enum import Enum -from typing import Any +from typing import Protocol + +from textual.app import App + +# Type alias for Textual apps with unknown return type +TextualApp = App[object] + + +class AppProtocol(Protocol): + """Protocol for apps that support CSS and refresh.""" + + CSS: str + + def refresh(self) -> None: + """Refresh the app.""" + ... 
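+
+# Illustrative sketch (hypothetical, not part of this patch): because AppProtocol is a
+# structural Protocol, any object exposing a `CSS: str` attribute and a `refresh()` method
+# satisfies it and can be passed to `apply_theme_to_app` below; no Textual subclass is
+# required. The `_StubApp` name is made up purely to demonstrate the contract.
+#
+#     class _StubApp:
+#         CSS: str = ""
+#
+#         def refresh(self) -> None:
+#             print("refresh requested")  # stand-in for a real re-render
+#
+#     apply_theme_to_app(_StubApp(), ThemeType.DARK)  # falls back to logging.debug on failure
+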
class ThemeType(Enum): @@ -5493,8 +6748,8 @@ class ThemeManager: """Manages theme selection and CSS generation.""" def __init__(self, default_theme: ThemeType = ThemeType.DARK): - self.current_theme = default_theme - self._themes = { + self.current_theme: ThemeType = default_theme + self._themes: dict[ThemeType, ColorPalette] = { ThemeType.DARK: ThemeRegistry.get_enhanced_dark(), ThemeType.LIGHT: ThemeRegistry.get_light(), ThemeType.HIGH_CONTRAST: ThemeRegistry.get_high_contrast(), @@ -6418,30 +7673,28 @@ def get_css_for_theme(theme_type: ThemeType) -> str: return css -def apply_theme_to_app(app: object, theme_type: ThemeType) -> None: +def apply_theme_to_app(app: TextualApp | AppProtocol, theme_type: ThemeType) -> None: """Apply a theme to a Textual app instance.""" try: - css = set_theme(theme_type) - if hasattr(app, "stylesheet"): - app.stylesheet.clear() - app.stylesheet.parse(css) - elif hasattr(app, "CSS"): - setattr(app, "CSS", css) - elif hasattr(app, "refresh"): - # Fallback: try to refresh the app with new CSS + # Note: CSS class variable cannot be changed at runtime + # This function would need to be called during app initialization + # or implement a different approach for dynamic theming + _ = set_theme(theme_type) # Keep for future implementation + if hasattr(app, "refresh"): app.refresh() except Exception as e: # Graceful fallback - log but don't crash the UI import logging + logging.debug(f"Failed to apply theme to app: {e}") class ThemeSwitcher: """Helper class for managing theme switching in TUI applications.""" - def __init__(self, app: object | None = None) -> None: - self.app = app - self.theme_history = [ThemeType.DARK] + def __init__(self, app: TextualApp | AppProtocol | None = None) -> None: + self.app: TextualApp | AppProtocol | None = app + self.theme_history: list[ThemeType] = [ThemeType.DARK] def switch_theme(self, theme_type: ThemeType) -> str: """Switch to a new theme and apply it to the app if available.""" @@ -6469,7 +7722,7 @@ class ThemeSwitcher: next_theme = themes[(current_index + 1) % len(themes)] return self.switch_theme(next_theme) - def get_theme_info(self) -> dict[str, Any]: + def get_theme_info(self) -> dict[str, str | list[str] | dict[str, str]]: """Get information about the current theme.""" palette = get_theme_palette() return { @@ -6486,11 +7739,11 @@ class ThemeSwitcher: # Responsive breakpoints for dynamic layout adaptation RESPONSIVE_BREAKPOINTS = { - "xs": 40, # Extra small terminals - "sm": 60, # Small terminals - "md": 100, # Medium terminals - "lg": 140, # Large terminals - "xl": 180, # Extra large terminals + "xs": 40, # Extra small terminals + "sm": 60, # Small terminals + "md": 100, # Medium terminals + "lg": 140, # Large terminals + "xl": 180, # Extra large terminals } @@ -7027,744 +8280,772 @@ def apply_responsive_theme() -> str: return f"{custom_properties}\n{base_css}\n{responsive_css}" - -"""CLI interface for ingestion pipeline.""" - -import asyncio -from typing import Annotated - -import typer -from pydantic import SecretStr -from rich.console import Console -from rich.panel import Panel -from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn -from rich.table import Table - -from ..config import configure_prefect, get_settings -from ..core.models import ( - IngestionResult, - IngestionSource, - StorageBackend, - StorageConfig, -) -from ..flows.ingestion import create_ingestion_flow -from ..flows.scheduler import create_scheduled_deployment, serve_deployments - -app = typer.Typer( - 
name="ingest", - help="🚀 Modern Document Ingestion Pipeline - Advanced web and repository processing", - rich_markup_mode="rich", - add_completion=False, -) -console = Console() - - -@app.callback() -def main( - version: Annotated[ - bool, typer.Option("--version", "-v", help="Show version information") - ] = False, -) -> None: - """ - 🚀 Modern Document Ingestion Pipeline - - [bold cyan]Advanced document processing and management platform[/bold cyan] - - Features: - • 🌐 Web scraping and crawling with Firecrawl - • 📦 Repository ingestion with Repomix - • 🗄️ Multiple storage backends (Weaviate, OpenWebUI, R2R) - • 📊 Modern TUI for collection management - • ⚡ Async processing with Prefect orchestration - • 🎨 Rich CLI with enhanced visuals - """ - settings = get_settings() - configure_prefect(settings) - - if version: - console.print( - Panel( - ( - "[bold magenta]Ingest Pipeline v0.1.0[/bold magenta]\n" - "[dim]Modern Document Ingestion & Management System[/dim]" - ), - title="🚀 Version Info", - border_style="magenta", - ) - ) - raise typer.Exit() - - -@app.command() -def ingest( - source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")], - source_type: Annotated[ - IngestionSource, typer.Option("--type", "-t", help="Type of source") - ] = IngestionSource.WEB, - storage: Annotated[ - StorageBackend, typer.Option("--storage", "-s", help="Storage backend") - ] = StorageBackend.WEAVIATE, - collection: Annotated[ - str | None, - typer.Option( - "--collection", "-c", help="Target collection name (auto-generated if not specified)" - ), - ] = None, - validate: Annotated[ - bool, typer.Option("--validate/--no-validate", help="Validate source before ingesting") - ] = True, -) -> None: - """ - 🚀 Run a one-time ingestion job with enhanced progress tracking. - - This command processes documents from various sources and stores them in - your chosen backend with full progress visualization. 
- """ - # Enhanced startup message - console.print( - Panel( - ( - f"[bold cyan]🚀 Starting Modern Ingestion[/bold cyan]\n\n" - f"[yellow]Source:[/yellow] {source_url}\n" - f"[yellow]Type:[/yellow] {source_type.value.title()}\n" - f"[yellow]Storage:[/yellow] {storage.value.replace('_', ' ').title()}\n" - f"[yellow]Collection:[/yellow] {collection or '[dim]Auto-generated[/dim]'}" - ), - title="🔥 Ingestion Configuration", - border_style="cyan", - ) - ) - - async def run_with_progress() -> IngestionResult: - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - console=console, - ) as progress: - task = progress.add_task("🔄 Processing documents...", total=100) - - # Simulate progress updates during ingestion - progress.update(task, advance=20, description="🔗 Connecting to services...") - await asyncio.sleep(0.5) - - progress.update(task, advance=30, description="📄 Fetching documents...") - result = await run_ingestion( - url=source_url, - source_type=source_type, - storage_backend=storage, - collection_name=collection, - validate_first=validate, - ) - - progress.update(task, advance=50, description="✅ Ingestion complete!") - return result - - # Use asyncio.run() with proper event loop handling - try: - result = asyncio.run(run_with_progress()) - except RuntimeError as e: - if "asyncio.run() cannot be called from a running event loop" in str(e): - # If we're already in an event loop (e.g., in Jupyter), use nest_asyncio - try: - import nest_asyncio - nest_asyncio.apply() - result = asyncio.run(run_with_progress()) - except ImportError: - # Fallback: get the current loop and run the coroutine - loop = asyncio.get_event_loop() - result = loop.run_until_complete(run_with_progress()) - else: - raise - - # Enhanced results display - status_color = "green" if result.status.value == "completed" else "red" - - # Create results table with enhanced styling - table = Table( - title="📊 Ingestion Results", - title_style="bold magenta", - border_style="cyan", - header_style="bold blue", - ) - table.add_column("📋 Metric", style="cyan", no_wrap=True) - table.add_column("📈 Value", style=status_color, justify="right") - - # Add enhanced status icon - status_icon = "✅" if result.status.value == "completed" else "❌" - table.add_row("Status", f"{status_icon} {result.status.value.title()}") - - table.add_row("Documents Processed", f"📄 {result.documents_processed:,}") - table.add_row("Documents Failed", f"⚠️ {result.documents_failed:,}") - table.add_row("Duration", f"⏱️ {result.duration_seconds:.2f}s") - - if result.error_messages: - error_text = "\n".join(f"❌ {error}" for error in result.error_messages[:3]) - if len(result.error_messages) > 3: - error_text += f"\n... 
and {len(result.error_messages) - 3} more errors" - table.add_row("Errors", error_text) - - console.print(table) - - # Success celebration or error guidance - if result.status.value == "completed" and result.documents_processed > 0: - console.print( - Panel( - ( - f"🎉 [bold green]Success![/bold green] {result.documents_processed} documents ingested\n\n" - f"💡 [dim]Try '[bold cyan]ingest modern[/bold cyan]' to explore your collections![/dim]" - ), - title="✨ Ingestion Complete", - border_style="green", - ) - ) - elif result.error_messages: - console.print( - Panel( - ( - "❌ [bold red]Ingestion encountered errors[/bold red]\n\n" - "💡 [dim]Check your configuration and try again[/dim]" - ), - title="⚠️ Issues Detected", - border_style="red", - ) - ) - - -@app.command() -def schedule( - name: Annotated[str, typer.Argument(help="Deployment name")], - source_url: Annotated[str, typer.Argument(help="URL or path to ingest from")], - source_type: Annotated[ - IngestionSource, typer.Option("--type", "-t", help="Type of source") - ] = IngestionSource.WEB, - storage: Annotated[ - StorageBackend, typer.Option("--storage", "-s", help="Storage backend") - ] = StorageBackend.WEAVIATE, - cron: Annotated[ - str | None, typer.Option("--cron", "-c", help="Cron expression for scheduling") - ] = None, - interval: Annotated[int, typer.Option("--interval", "-i", help="Interval in minutes")] = 60, - serve_now: Annotated[ - bool, typer.Option("--serve/--no-serve", help="Start serving immediately") - ] = False, -) -> None: - """ - Create a scheduled deployment for recurring ingestion. - """ - console.print(f"[bold blue]Creating deployment: {name}[/bold blue]") - - deployment = create_scheduled_deployment( - name=name, - source_url=source_url, - source_type=source_type, - storage_backend=storage, - schedule_type="cron" if cron else "interval", - cron_expression=cron, - interval_minutes=interval, - ) - - console.print(f"[green]✓ Deployment '{name}' created[/green]") - - if serve_now: - console.print("[yellow]Starting deployment server...[/yellow]") - serve_deployments([deployment]) - - -@app.command() -def serve( - config_file: Annotated[ - str | None, typer.Option("--config", "-c", help="Path to deployments config file") - ] = None, - ui: Annotated[ - str | None, typer.Option("--ui", help="Launch user interface (options: tui, web)") - ] = None, -) -> None: - """ - 🚀 Serve configured deployments with optional UI interface. - - Launch the deployment server to run scheduled ingestion jobs, - optionally with a modern Terminal User Interface (TUI) or web interface. - """ - # Handle UI mode first - if ui == "tui": - console.print( - Panel( - ( - "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n" - "[yellow]Features:[/yellow]\n" - "• 📊 Interactive collection management\n" - "• ⌨️ Enhanced keyboard navigation\n" - "• 🎨 Modern design with focus indicators\n" - "• 📄 Document browsing and search\n" - "• 🔄 Real-time status updates" - ), - title="🎉 TUI Mode", - border_style="cyan", - ) - ) - from .tui import dashboard - - dashboard() - return - elif ui == "web": - console.print("[red]Web UI not yet implemented. 
Use --ui tui for Terminal UI.[/red]") - return - elif ui: - console.print(f"[red]Unknown UI option: {ui}[/red]") - console.print("[yellow]Available options: tui, web[/yellow]") - return - - # Normal deployment server mode - if config_file: - # Load deployments from config - console.print(f"[yellow]Loading deployments from {config_file}[/yellow]") - # Implementation would load YAML/JSON config - else: - # Create example deployments - deployments = [ - create_scheduled_deployment( - name="docs-daily", - source_url="https://docs.example.com", - source_type="documentation", - storage_backend="weaviate", - schedule_type="cron", - cron_expression="0 2 * * *", # Daily at 2 AM - ), - create_scheduled_deployment( - name="repo-hourly", - source_url="https://github.com/example/repo", - source_type="repository", - storage_backend="open_webui", - schedule_type="interval", - interval_minutes=60, - ), - ] - - console.print( - "[bold green]Starting deployment server with example deployments[/bold green]" - ) - serve_deployments(deployments) - - -@app.command() -def tui() -> None: - """ - 🚀 Launch the enhanced Terminal User Interface. - - Quick shortcut for 'serve --ui tui' with modern keyboard navigation, - interactive collection management, and real-time status updates. - """ - console.print( - Panel( - ( - "[bold cyan]🚀 Launching Enhanced TUI[/bold cyan]\n\n" - "[yellow]Features:[/yellow]\n" - "• 📊 Interactive collection management\n" - "• ⌨️ Enhanced keyboard navigation\n" - "• 🎨 Modern design with focus indicators\n" - "• 📄 Document browsing and search\n" - "• 🔄 Real-time status updates" - ), - title="🎉 TUI Mode", - border_style="cyan", - ) - ) - from .tui import dashboard - - dashboard() - - -@app.command() -def config() -> None: - """ - 📋 Display current configuration with enhanced formatting. - - Shows all configured endpoints, models, and settings in a beautiful - table format with status indicators. 
- """ - settings = get_settings() - - console.print( - Panel( - ( - "[bold cyan]⚙️ System Configuration[/bold cyan]\n" - "[dim]Current pipeline settings and endpoints[/dim]" - ), - title="🔧 Configuration", - border_style="cyan", - ) - ) - - # Enhanced configuration table - table = Table( - title="📊 Configuration Details", - title_style="bold magenta", - border_style="blue", - header_style="bold cyan", - show_lines=True, - ) - table.add_column("🏷️ Setting", style="cyan", no_wrap=True, width=25) - table.add_column("🎯 Value", style="yellow", overflow="fold") - table.add_column("📊 Status", style="green", width=12, justify="center") - - # Add configuration rows with status indicators - def get_status_indicator(value: str | None) -> str: - return "✅ Set" if value else "❌ Missing" - - table.add_row("🤖 LLM Endpoint", str(settings.llm_endpoint), "✅ Active") - table.add_row("🔥 Firecrawl Endpoint", str(settings.firecrawl_endpoint), "✅ Active") - table.add_row( - "🗄️ Weaviate Endpoint", - str(settings.weaviate_endpoint), - get_status_indicator(str(settings.weaviate_api_key) if settings.weaviate_api_key else None), - ) - table.add_row( - "🌐 OpenWebUI Endpoint", - str(settings.openwebui_endpoint), - get_status_indicator(settings.openwebui_api_key), - ) - table.add_row("🧠 Embedding Model", settings.embedding_model, "✅ Set") - table.add_row("💾 Default Storage", settings.default_storage_backend.title(), "✅ Set") - table.add_row("📦 Default Batch Size", f"{settings.default_batch_size:,}", "✅ Set") - table.add_row("⚡ Max Concurrent Tasks", f"{settings.max_concurrent_tasks}", "✅ Set") - - console.print(table) - - # Additional helpful information - console.print( - Panel( - ( - "💡 [bold cyan]Quick Tips[/bold cyan]\n\n" - "• Use '[bold]ingest list-collections[/bold]' to view all collections\n" - "• Use '[bold]ingest search[/bold]' to search content\n" - "• Configure API keys in your [yellow].env[/yellow] file\n" - "• Default collection names are auto-generated from URLs" - ), - title="🚀 Usage Tips", - border_style="green", - ) - ) - - -@app.command() -def list_collections() -> None: - """ - 📋 List all collections across storage backends. - """ - console.print("[bold cyan]📚 Collection Overview[/bold cyan]") - asyncio.run(run_list_collections()) - - -@app.command() -def search( - query: Annotated[str, typer.Argument(help="Search query")], - collection: Annotated[ - str | None, typer.Option("--collection", "-c", help="Target collection") - ] = None, - backend: Annotated[ - StorageBackend, typer.Option("--backend", "-b", help="Storage backend") - ] = StorageBackend.WEAVIATE, - limit: Annotated[int, typer.Option("--limit", "-l", help="Result limit")] = 10, -) -> None: - """ - 🔍 Search across collections. 
- """ - console.print(f"[bold cyan]🔍 Searching for: {query}[/bold cyan]") - asyncio.run(run_search(query, collection, backend.value, limit)) - - -@app.command(name="blocks") -def blocks_command() -> None: - """🧩 List and manage Prefect Blocks.""" - console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]") - console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks") - console.print("Use 'prefect block ls' to list available blocks") - - -@app.command(name="variables") -def variables_command() -> None: - """📊 Manage Prefect Variables.""" - console.print("[bold cyan]📊 Prefect Variables Management[/bold cyan]") - console.print("Use 'prefect variable set VARIABLE_NAME value' to set variables") - console.print("Use 'prefect variable ls' to list variables") - - -async def run_ingestion( - url: str, - source_type: IngestionSource, - storage_backend: StorageBackend, - collection_name: str | None = None, - validate_first: bool = True, -) -> IngestionResult: - """ - Run ingestion with support for targeted collections. - """ - # Auto-generate collection name if not provided - if not collection_name: - from urllib.parse import urlparse - - parsed = urlparse(url) - domain = parsed.netloc.replace(".", "_").replace("-", "_") - collection_name = f"{domain}_{source_type.value}" - - return await create_ingestion_flow( - source_url=url, - source_type=source_type, - storage_backend=storage_backend, - collection_name=collection_name, - validate_first=validate_first, - ) - - -async def run_list_collections() -> None: - """ - List collections across storage backends. - """ - from ..config import get_settings - from ..core.models import StorageBackend - from ..storage.openwebui import OpenWebUIStorage - from ..storage.weaviate import WeaviateStorage - - settings = get_settings() - - console.print("🔍 [bold cyan]Scanning storage backends...[/bold cyan]") - - # Try to connect to Weaviate - weaviate_collections: list[tuple[str, int]] = [] - try: - weaviate_config = StorageConfig( - backend=StorageBackend.WEAVIATE, - endpoint=settings.weaviate_endpoint, - api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None, - collection_name="default", - ) - weaviate = WeaviateStorage(weaviate_config) - await weaviate.initialize() - - overview = await weaviate.describe_collections() - for item in overview: - name = str(item.get("name", "Unknown")) - count_val = item.get("count", 0) - count = int(count_val) if isinstance(count_val, (int, str)) else 0 - weaviate_collections.append((name, count)) - except Exception as e: - console.print(f"❌ [red]Weaviate connection failed: {e}[/red]") - - # Try to connect to OpenWebUI - openwebui_collections: list[tuple[str, int]] = [] - try: - openwebui_config = StorageConfig( - backend=StorageBackend.OPEN_WEBUI, - endpoint=settings.openwebui_endpoint, - api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None, - collection_name="default", - ) - openwebui = OpenWebUIStorage(openwebui_config) - await openwebui.initialize() - - overview = await openwebui.describe_collections() - for item in overview: - name = str(item.get("name", "Unknown")) - count_val = item.get("count", 0) - count = int(count_val) if isinstance(count_val, (int, str)) else 0 - openwebui_collections.append((name, count)) - except Exception as e: - console.print(f"❌ [red]OpenWebUI connection failed: {e}[/red]") - - # Display results - if weaviate_collections or openwebui_collections: - # Create 
results table - from rich.table import Table - - table = Table( - title="📚 Collection Overview", - title_style="bold magenta", - border_style="cyan", - header_style="bold blue", - ) - table.add_column("🏷️ Collection", style="cyan", no_wrap=True) - table.add_column("📊 Backend", style="yellow") - table.add_column("📄 Documents", style="green", justify="right") - - # Add Weaviate collections - for name, count in weaviate_collections: - table.add_row(name, "🗄️ Weaviate", f"{count:,}") - - # Add OpenWebUI collections - for name, count in openwebui_collections: - table.add_row(name, "🌐 OpenWebUI", f"{count:,}") - - console.print(table) - else: - console.print("❌ [yellow]No collections found in any backend[/yellow]") - - -async def run_search(query: str, collection: str | None, backend: str, limit: int) -> None: - """ - Search across collections. - """ - from ..config import get_settings - from ..core.models import StorageBackend - from ..storage.weaviate import WeaviateStorage - - settings = get_settings() - - console.print(f"🔍 Searching for: '[bold cyan]{query}[/bold cyan]'") - if collection: - console.print(f"📚 Target collection: [yellow]{collection}[/yellow]") - console.print(f"💾 Backend: [blue]{backend}[/blue]") - - results = [] - - try: - if backend == "weaviate": - weaviate_config = StorageConfig( - backend=StorageBackend.WEAVIATE, - endpoint=settings.weaviate_endpoint, - api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None, - collection_name=collection or "default", - ) - weaviate = WeaviateStorage(weaviate_config) - await weaviate.initialize() - - results_generator = weaviate.search(query, limit=limit) - async for doc in results_generator: - results.append( - { - "title": getattr(doc, "title", "Untitled"), - "content": getattr(doc, "content", ""), - "score": getattr(doc, "score", 0.0), - "backend": "🗄️ Weaviate", - } - ) - - elif backend == "open_webui": - console.print("❌ [red]OpenWebUI search not yet implemented[/red]") - return - - except Exception as e: - console.print(f"❌ [red]Search failed: {e}[/red]") - return - - # Display results - if results: - from rich.table import Table - - table = Table( - title=f"🔍 Search Results for '{query}'", - title_style="bold magenta", - border_style="green", - header_style="bold blue", - ) - table.add_column("📄 Title", style="cyan", max_width=40) - table.add_column("📝 Preview", style="white", max_width=60) - table.add_column("📊 Score", style="yellow", justify="right") - - for result in results[:limit]: - title = str(result["title"]) - title_display = title[:40] + "..." if len(title) > 40 else title - - content = str(result["content"]) - content_display = content[:60] + "..." 
if len(content) > 60 else content - - score = f"{result['score']:.3f}" - - table.add_row(title_display, content_display, score) - - console.print(table) - console.print(f"\n✅ [green]Found {len(results)} results[/green]") - else: - console.print("❌ [yellow]No results found[/yellow]") - - -if __name__ == "__main__": - app() - - - -"""Configuration management utilities.""" + +"""Prefect flow for ingestion pipeline.""" from __future__ import annotations -from contextlib import ExitStack +from collections.abc import Callable +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Literal, TypeAlias, assert_never, cast -from prefect.settings import Setting, temporary_settings +from prefect import flow, get_run_logger, task +from prefect.blocks.core import Block +from prefect.variables import Variable +from pydantic import SecretStr + +from ..config.settings import Settings +from ..core.exceptions import IngestionError +from ..core.models import ( + Document, + FirecrawlConfig, + IngestionJob, + IngestionResult, + IngestionSource, + IngestionStatus, + RepomixConfig, + StorageBackend, + StorageConfig, +) +from ..ingestors import BaseIngestor, FirecrawlIngestor, FirecrawlPage, RepomixIngestor +from ..storage import OpenWebUIStorage, WeaviateStorage +from ..storage import R2RStorage as RuntimeR2RStorage +from ..storage.base import BaseStorage +from ..utils.metadata_tagger import MetadataTagger + +SourceTypeLiteral = Literal["web", "repository", "documentation"] +StorageBackendLiteral = Literal["weaviate", "open_webui", "r2r"] +SourceTypeLike: TypeAlias = IngestionSource | SourceTypeLiteral +StorageBackendLike: TypeAlias = StorageBackend | StorageBackendLiteral -# Import Prefect settings with version compatibility - avoid static analysis issues -def _setup_prefect_settings() -> tuple[object, object, object]: - """Setup Prefect settings with proper fallbacks.""" +def _safe_cache_key(prefix: str, params: dict[str, object], key: str) -> str: + """Create a type-safe cache key from task parameters.""" + value = params.get(key, "") + return f"{prefix}_{hash(str(value))}" + + +if TYPE_CHECKING: + from ..storage.r2r.storage import R2RStorage as R2RStorageType +else: + R2RStorageType = BaseStorage + + +@task(name="validate_source", retries=2, retry_delay_seconds=10, tags=["validation"]) +async def validate_source_task(source_url: str, source_type: IngestionSource) -> bool: + """ + Validate that a source is accessible. + + Args: + source_url: URL or path to source + source_type: Type of source + + Returns: + True if valid + """ + if source_type == IngestionSource.WEB: + ingestor = FirecrawlIngestor() + elif source_type == IngestionSource.REPOSITORY: + ingestor = RepomixIngestor() + else: + raise ValueError(f"Unsupported source type: {source_type}") + + result = await ingestor.validate_source(source_url) + return bool(result) + + +@task(name="initialize_storage", retries=3, retry_delay_seconds=5, tags=["storage"]) +async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage: + """ + Initialize storage backend. 
+ + Args: + config: Storage configuration block or block name + + Returns: + Initialized storage adapter + """ + # Load block if string provided + if isinstance(config, str): + # Use Block.aload with type slug for better type inference + loaded_block = await Block.aload(f"storage-config/{config}") + config = cast(StorageConfig, loaded_block) + + if config.backend == StorageBackend.WEAVIATE: + storage = WeaviateStorage(config) + elif config.backend == StorageBackend.OPEN_WEBUI: + storage = OpenWebUIStorage(config) + elif config.backend == StorageBackend.R2R: + if RuntimeR2RStorage is None: + raise ValueError("R2R storage not available. Check dependencies.") + storage = RuntimeR2RStorage(config) + else: + raise ValueError(f"Unsupported backend: {config.backend}") + + await storage.initialize() + return storage + + +@task( + name="map_firecrawl_site", + retries=2, + retry_delay_seconds=15, + tags=["firecrawl", "map"], + cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"), +) +async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]: + """Map a site using Firecrawl and return discovered URLs.""" + # Load block if string provided + if isinstance(config, str): + # Use Block.aload with type slug for better type inference + loaded_block = await Block.aload(f"firecrawl-config/{config}") + config = cast(FirecrawlConfig, loaded_block) + + ingestor = FirecrawlIngestor(config) + mapped = await ingestor.map_site(source_url) + return mapped or [source_url] + + +@task( + name="filter_existing_documents", + retries=1, + retry_delay_seconds=5, + tags=["dedup"], + cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls"), +) # Cache based on URL list +async def filter_existing_documents_task( + urls: list[str], + storage_client: BaseStorage, + stale_after_days: int = 30, + *, + collection_name: str | None = None, +) -> list[str]: + """Filter URLs to only those that need scraping (missing or stale in storage).""" + import asyncio + + logger = get_run_logger() + + # Use semaphore to limit concurrent existence checks + semaphore = asyncio.Semaphore(20) + + async def check_url_exists(url: str) -> tuple[str, bool]: + async with semaphore: + try: + document_id = str(FirecrawlIngestor.compute_document_id(url)) + exists = await storage_client.check_exists( + document_id, collection_name=collection_name, stale_after_days=stale_after_days + ) + return url, exists + except Exception as e: + logger.warning("Error checking existence for URL %s: %s", url, e) + # Assume doesn't exist on error to ensure we scrape it + return url, False + + # Check all URLs in parallel - use return_exceptions=True for partial failure handling + results = await asyncio.gather(*[check_url_exists(url) for url in urls], return_exceptions=True) + + # Collect URLs that need scraping, handling any exceptions + eligible = [] + for result in results: + if isinstance(result, Exception): + logger.error("Unexpected error in parallel existence check: %s", result) + continue + # Type narrowing: result is now known to be tuple[str, bool] + if isinstance(result, tuple) and len(result) == 2: + url, exists = result + if not exists: + eligible.append(url) + + skipped = len(urls) - len(eligible) + if skipped > 0: + logger.info("Skipping %s up-to-date documents in %s", skipped, storage_client.display_name) + + return eligible + + +@task( + name="scrape_firecrawl_batch", retries=2, retry_delay_seconds=20, tags=["firecrawl", "scrape"] +) +async def scrape_firecrawl_batch_task( + 
batch_urls: list[str], config: FirecrawlConfig +) -> list[FirecrawlPage]: + """Scrape a batch of URLs via Firecrawl.""" + ingestor = FirecrawlIngestor(config) + result: list[FirecrawlPage] = await ingestor.scrape_pages(batch_urls) + return result + + +@task(name="annotate_firecrawl_metadata", retries=1, retry_delay_seconds=10, tags=["metadata"]) +async def annotate_firecrawl_metadata_task( + pages: list[FirecrawlPage], job: IngestionJob +) -> list[Document]: + """Annotate scraped pages with standardized metadata.""" + if not pages: + return [] + + ingestor = FirecrawlIngestor() + documents = [ingestor.create_document(page, job) for page in pages] + try: - import prefect.settings as ps + from ..config import get_settings - # Try to get the settings directly - api_key = getattr(ps, "PREFECT_API_KEY", None) - api_url = getattr(ps, "PREFECT_API_URL", None) - work_pool = getattr(ps, "PREFECT_DEFAULT_WORK_POOL_NAME", None) - - if api_key is not None: - return api_key, api_url, work_pool - - # Fallback to registry-based approach - registry = getattr(ps, "PREFECT_SETTING_REGISTRY", None) - if registry is not None: - Setting = getattr(ps, "Setting", None) - if Setting is not None: - api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None) - api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None) - work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None) - return api_key, api_url, work_pool - - except ImportError: - pass - - # Ultimate fallback - return None, None, None - -PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings() - -# Import after Prefect settings setup to avoid circular dependencies -from .settings import Settings, get_settings # noqa: E402 - -__all__ = ["Settings", "get_settings", "configure_prefect"] - -_prefect_settings_stack: ExitStack | None = None + settings = get_settings() + async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger: + tagged_documents: list[Document] = await tagger.tag_batch(documents) + return tagged_documents + except IngestionError as exc: # pragma: no cover - logging side effect + logger = get_run_logger() + logger.warning("Metadata tagging failed: %s", exc) + return documents + except Exception as exc: # pragma: no cover - defensive + logger = get_run_logger() + logger.warning("Metadata tagging unavailable, using base metadata: %s", exc) + return documents -def configure_prefect(settings: Settings) -> None: - """Apply Prefect settings from the application configuration.""" - global _prefect_settings_stack +@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"]) +async def upsert_r2r_documents_task( + storage_client: R2RStorageType, + documents: list[Document], + collection_name: str | None, +) -> tuple[int, int]: + """Upsert documents into R2R storage.""" + if not documents: + return 0, 0 - overrides: dict[Setting, str] = {} + stored_ids: list[str] = await storage_client.store_batch( + documents, collection_name=collection_name + ) + processed = len(stored_ids) + failed = len(documents) - processed - if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting): - overrides[PREFECT_API_URL] = str(settings.prefect_api_url) - if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting): - overrides[PREFECT_API_KEY] = 
settings.prefect_api_key - if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting): - overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool + if failed: + logger = get_run_logger() + logger.warning("Failed to upsert %s documents to R2R", failed) - if not overrides: - return + return processed, failed - filtered_overrides = { - setting: value - for setting, value in overrides.items() - if setting.value() != value + +@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"]) +async def ingest_documents_task( + job: IngestionJob, + collection_name: str | None = None, + batch_size: int | None = None, + storage_client: BaseStorage | None = None, + storage_block_name: str | None = None, + ingestor_config_block_name: str | None = None, + progress_callback: Callable[[int, str], None] | None = None, +) -> tuple[int, int]: + """ + Ingest documents from source with optional pre-initialized storage client. + + Args: + job: Ingestion job configuration + collection_name: Target collection name + batch_size: Number of documents per batch (uses Variable if None) + storage_client: Optional pre-initialized storage client + storage_block_name: Optional storage block name to load + ingestor_config_block_name: Optional ingestor config block name to load + progress_callback: Optional callback for progress updates + + Returns: + Tuple of (processed_count, failed_count) + """ + if progress_callback: + progress_callback(35, "Creating ingestor and storage clients...") + + # Use Variable for batch size if not provided + if batch_size is None: + try: + batch_size_var = await Variable.aget("default_batch_size", default="50") + # Convert Variable result to int, handling various types + if isinstance(batch_size_var, int): + batch_size = batch_size_var + elif isinstance(batch_size_var, (str, float)): + batch_size = int(float(str(batch_size_var))) + else: + batch_size = 50 + except Exception: + batch_size = 50 + + ingestor = await _create_ingestor(job, ingestor_config_block_name) + storage = storage_client or await _create_storage(job, collection_name, storage_block_name) + + if progress_callback: + progress_callback(40, "Starting document processing...") + + return await _process_documents( + ingestor, storage, job, batch_size, collection_name, progress_callback + ) + + +async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor: + """Create appropriate ingestor based on job source type.""" + if job.source_type == IngestionSource.WEB: + if config_block_name: + # Use Block.aload with type slug for better type inference + loaded_block = await Block.aload(f"firecrawl-config/{config_block_name}") + config = cast(FirecrawlConfig, loaded_block) + else: + # Fallback to default configuration + config = FirecrawlConfig() + return FirecrawlIngestor(config) + elif job.source_type == IngestionSource.REPOSITORY: + if config_block_name: + # Use Block.aload with type slug for better type inference + loaded_block = await Block.aload(f"repomix-config/{config_block_name}") + config = cast(RepomixConfig, loaded_block) + else: + # Fallback to default configuration + config = RepomixConfig() + return RepomixIngestor(config) + else: + raise ValueError(f"Unsupported source: {job.source_type}") + + +async def _create_storage( + job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None +) -> BaseStorage: + """Create and initialize storage client.""" + if 
collection_name is None: + # Use variable for default collection prefix + prefix = await Variable.aget("default_collection_prefix", default="docs") + collection_name = f"{prefix}_{job.source_type.value}" + + if storage_block_name: + # Load storage config from block + loaded_block = await Block.aload(f"storage-config/{storage_block_name}") + storage_config = cast(StorageConfig, loaded_block) + # Override collection name if provided + storage_config.collection_name = collection_name + else: + # Fallback to building config from settings + from ..config import get_settings + + settings = get_settings() + storage_config = _build_storage_config(job, settings, collection_name) + + storage = _instantiate_storage(job.storage_backend, storage_config) + await storage.initialize() + return storage + + +def _build_storage_config( + job: IngestionJob, settings: Settings, collection_name: str +) -> StorageConfig: + """Build storage configuration from job and settings.""" + storage_endpoints = { + StorageBackend.WEAVIATE: settings.weaviate_endpoint, + StorageBackend.OPEN_WEBUI: settings.openwebui_endpoint, + StorageBackend.R2R: settings.get_storage_endpoint("r2r"), + } + storage_api_keys: dict[StorageBackend, str | None] = { + StorageBackend.WEAVIATE: settings.get_api_key("weaviate"), + StorageBackend.OPEN_WEBUI: settings.get_api_key("openwebui"), + StorageBackend.R2R: None, # R2R is self-hosted, no API key needed } - if not filtered_overrides: - return + api_key_raw: str | None = storage_api_keys[job.storage_backend] + api_key: SecretStr | None = SecretStr(api_key_raw) if api_key_raw is not None else None - new_stack = ExitStack() - new_stack.enter_context(temporary_settings(updates=filtered_overrides)) + return StorageConfig( + backend=job.storage_backend, + endpoint=storage_endpoints[job.storage_backend], + api_key=api_key, + collection_name=collection_name, + ) - if _prefect_settings_stack is not None: - _prefect_settings_stack.close() - _prefect_settings_stack = new_stack +def _instantiate_storage(backend: StorageBackend, config: StorageConfig) -> BaseStorage: + """Instantiate storage based on backend type.""" + if backend == StorageBackend.WEAVIATE: + return WeaviateStorage(config) + elif backend == StorageBackend.OPEN_WEBUI: + return OpenWebUIStorage(config) + elif backend == StorageBackend.R2R: + if RuntimeR2RStorage is None: + raise ValueError("R2R storage not available. 
Check dependencies.") + return RuntimeR2RStorage(config) + + assert_never(backend) + + +def _chunk_urls(urls: list[str], chunk_size: int) -> list[list[str]]: + """Group URLs into fixed-size chunks for batch processing.""" + + if chunk_size <= 0: + raise ValueError("chunk_size must be greater than zero") + + return [urls[i : i + chunk_size] for i in range(0, len(urls), chunk_size)] + + +def _deduplicate_urls(urls: list[str]) -> list[str]: + """Return the URLs with order preserved and duplicates removed.""" + + seen: set[str] = set() + unique: list[str] = [] + for url in urls: + if url not in seen: + seen.add(url) + unique.append(url) + return unique + + +async def _process_documents( + ingestor: BaseIngestor, + storage: BaseStorage, + job: IngestionJob, + batch_size: int, + collection_name: str | None, + progress_callback: Callable[[int, str], None] | None = None, +) -> tuple[int, int]: + """Process documents in batches.""" + processed = 0 + failed = 0 + batch: list[Document] = [] + total_documents = 0 + batch_count = 0 + + if progress_callback: + progress_callback(45, "Ingesting documents from source...") + + # Use smart ingestion with deduplication if storage supports it + if hasattr(storage, "check_exists"): + try: + # Try to use the smart ingestion method + document_generator = ingestor.ingest_with_dedup( + job, storage, collection_name=collection_name + ) + except Exception: + # Fall back to regular ingestion if smart method fails + document_generator = ingestor.ingest(job) + else: + document_generator = ingestor.ingest(job) + + async for document in document_generator: + batch.append(document) + total_documents += 1 + + if len(batch) >= batch_size: + batch_count += 1 + if progress_callback: + progress_callback( + 45 + min(35, (batch_count * 10)), + f"Processing batch {batch_count} ({total_documents} documents so far)...", + ) + + batch_processed, batch_failed = await _store_batch(storage, batch, collection_name) + processed += batch_processed + failed += batch_failed + batch = [] + + # Process remaining batch + if batch: + batch_count += 1 + if progress_callback: + progress_callback(80, f"Processing final batch ({total_documents} total documents)...") + + batch_processed, batch_failed = await _store_batch(storage, batch, collection_name) + processed += batch_processed + failed += batch_failed + + if progress_callback: + progress_callback(85, f"Completed processing {total_documents} documents") + + return processed, failed + + +async def _store_batch( + storage: BaseStorage, + batch: list[Document], + collection_name: str | None, +) -> tuple[int, int]: + """Store a batch of documents and return processed/failed counts.""" + try: + # Apply metadata tagging for backends that benefit from it + processed_batch = batch + if hasattr(storage, "config") and storage.config.backend in ( + StorageBackend.R2R, + StorageBackend.WEAVIATE, + ): + try: + from ..config import get_settings + + settings = get_settings() + async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger: + processed_batch = await tagger.tag_batch(batch) + except Exception as exc: + print(f"Metadata tagging failed, using original documents: {exc}") + processed_batch = batch + + stored_ids = await storage.store_batch(processed_batch, collection_name=collection_name) + processed_count = len(stored_ids) + failed_count = len(processed_batch) - processed_count + + batch_type = ( + "final" if len(processed_batch) < 50 else "" + ) # Assume standard batch size is 50 + print(f"Successfully stored {processed_count} 
documents in {batch_type} batch".strip()) + + return processed_count, failed_count + except Exception as e: + batch_type = "Final" if len(batch) < 50 else "Batch" + print(f"{batch_type} storage failed: {e}") + return 0, len(batch) + + +@flow( + name="firecrawl_to_r2r", + description="Ingest Firecrawl pages into R2R with metadata annotation", + persist_result=False, + log_prints=True, +) +async def firecrawl_to_r2r_flow( + job: IngestionJob, + collection_name: str | None = None, + progress_callback: Callable[[int, str], None] | None = None, +) -> tuple[int, int]: + """Specialized flow for Firecrawl ingestion into R2R.""" + logger = get_run_logger() + from ..config import get_settings + + if progress_callback: + progress_callback(35, "Initializing Firecrawl and R2R storage...") + + settings = get_settings() + firecrawl_config = FirecrawlConfig() + resolved_collection = collection_name or f"docs_{job.source_type.value}" + + storage_config = _build_storage_config(job, settings, resolved_collection) + storage_client = await initialize_storage_task(storage_config) + + if RuntimeR2RStorage is None or not isinstance(storage_client, RuntimeR2RStorage): + raise IngestionError("Firecrawl to R2R flow requires an R2R storage backend") + + r2r_storage = cast("R2RStorageType", storage_client) + + if progress_callback: + progress_callback(45, "Checking for existing content before mapping...") + + # Smart mapping: try single URL first to avoid expensive map operation + base_url = str(job.source_url) + single_url_id = str(FirecrawlIngestor.compute_document_id(base_url)) + base_exists = await r2r_storage.check_exists( + single_url_id, collection_name=resolved_collection, stale_after_days=30 + ) + + if base_exists: + # Check if this is a recent single-page update + logger.info("Base URL %s exists and is fresh, skipping expensive mapping", base_url) + if progress_callback: + progress_callback(100, "Content is up to date, no processing needed") + return 0, 0 + + if progress_callback: + progress_callback(50, "Discovering pages with Firecrawl...") + + discovered_urls = await map_firecrawl_site_task(base_url, firecrawl_config) + unique_urls = _deduplicate_urls(discovered_urls) + logger.info("Discovered %s unique URLs from Firecrawl map", len(unique_urls)) + + if progress_callback: + progress_callback(60, f"Found {len(unique_urls)} pages, filtering existing content...") + + eligible_urls = await filter_existing_documents_task( + unique_urls, r2r_storage, collection_name=resolved_collection + ) + + if not eligible_urls: + logger.info("All Firecrawl pages are up to date for %s", job.source_url) + if progress_callback: + progress_callback(100, "All pages are up to date, no processing needed") + return 0, 0 + + if progress_callback: + progress_callback(70, f"Scraping {len(eligible_urls)} new/updated pages...") + + batch_size = min(settings.default_batch_size, firecrawl_config.limit) + url_batches = _chunk_urls(eligible_urls, batch_size) + logger.info("Scraping %s batches of Firecrawl pages", len(url_batches)) + + # Use asyncio.gather for concurrent scraping + import asyncio + + scrape_tasks = [scrape_firecrawl_batch_task(batch, firecrawl_config) for batch in url_batches] + batch_results = await asyncio.gather(*scrape_tasks) + + scraped_pages: list[FirecrawlPage] = [] + for batch_pages in batch_results: + scraped_pages.extend(batch_pages) + + if progress_callback: + progress_callback(80, f"Processing {len(scraped_pages)} scraped pages...") + + documents = await annotate_firecrawl_metadata_task(scraped_pages, job) + + if 
not documents: + logger.warning("No documents produced after scraping for %s", job.source_url) + return 0, len(eligible_urls) + + if progress_callback: + progress_callback(90, f"Storing {len(documents)} documents in R2R...") + + processed, failed = await upsert_r2r_documents_task(r2r_storage, documents, resolved_collection) + + logger.info("Upserted %s documents into R2R (%s failed)", processed, failed) + + return processed, failed + + +@task(name="update_job_status", tags=["tracking"]) +async def update_job_status_task( + job: IngestionJob, + status: IngestionStatus, + processed: int = 0, + _failed: int = 0, + error: str | None = None, +) -> IngestionJob: + """ + Update job status. + + Args: + job: Ingestion job + status: New status + processed: Documents processed + _failed: Documents failed (currently unused) + error: Error message if any + + Returns: + Updated job + """ + job.status = status + job.updated_at = datetime.now(UTC) + job.document_count = processed + + if status == IngestionStatus.COMPLETED: + job.completed_at = datetime.now(UTC) + + if error: + job.error_message = error + + return job + + +@flow( + name="ingestion_pipeline", + description="Main ingestion pipeline for documents", + retries=1, + retry_delay_seconds=60, + persist_result=True, + log_prints=True, +) +async def create_ingestion_flow( + source_url: str, + source_type: SourceTypeLike, + storage_backend: StorageBackendLike = StorageBackend.WEAVIATE, + collection_name: str | None = None, + validate_first: bool = True, + progress_callback: Callable[[int, str], None] | None = None, +) -> IngestionResult: + """ + Main ingestion flow. + + Args: + source_url: URL or path to source + source_type: Type of source + storage_backend: Storage backend to use + validate_first: Whether to validate source first + progress_callback: Optional callback for progress updates + + Returns: + Ingestion result + """ + print(f"Starting ingestion from {source_url}") + + source_enum = IngestionSource(source_type) + backend_enum = StorageBackend(storage_backend) + + # Create job + job = IngestionJob( + source_url=source_url, + source_type=source_enum, + storage_backend=backend_enum, + status=IngestionStatus.PENDING, + ) + + start_time = datetime.now(UTC) + error_messages: list[str] = [] + processed = 0 + failed = 0 + + try: + # Validate source if requested + if validate_first: + if progress_callback: + progress_callback(10, "Validating source...") + print("Validating source...") + is_valid = await validate_source_task(source_url, job.source_type) + + if not is_valid: + raise IngestionError(f"Source validation failed: {source_url}") + + # Update status to in progress + if progress_callback: + progress_callback(20, "Initializing storage...") + job = await update_job_status_task(job, IngestionStatus.IN_PROGRESS) + + # Run ingestion + if progress_callback: + progress_callback(30, "Starting document ingestion...") + print("Ingesting documents...") + if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R: + processed, failed = await firecrawl_to_r2r_flow( + job, collection_name, progress_callback=progress_callback + ) + else: + processed, failed = await ingest_documents_task( + job, collection_name, progress_callback=progress_callback + ) + + if progress_callback: + progress_callback(90, "Finalizing ingestion...") + + # Update final status + if failed > 0: + error_messages.append(f"{failed} documents failed to process") + + # Set status based on results + if processed == 0 and failed > 0: + final_status = 
IngestionStatus.FAILED + elif failed > 0: + final_status = IngestionStatus.PARTIAL + else: + final_status = IngestionStatus.COMPLETED + + job = await update_job_status_task(job, final_status, processed=processed, _failed=failed) + + print(f"Ingestion completed: {processed} processed, {failed} failed") + + except Exception as e: + print(f"Ingestion failed: {e}") + error_messages.append(str(e)) + + # Don't reset counts - keep whatever was processed before the error + job = await update_job_status_task( + job, IngestionStatus.FAILED, processed=processed, _failed=failed, error=str(e) + ) + + # Calculate duration + duration = (datetime.now(UTC) - start_time).total_seconds() + + return IngestionResult( + job_id=job.id, + status=job.status, + documents_processed=processed, + documents_failed=failed, + duration_seconds=duration, + error_messages=error_messages, + ) @@ -7773,8 +9054,8 @@ def configure_prefect(settings: Settings) -> None: from datetime import timedelta from typing import Literal, Protocol, cast -from prefect import serve from prefect.deployments.runner import RunnerDeployment +from prefect.flows import serve as prefect_serve from prefect.schedules import Cron, Interval from prefect.variables import Variable @@ -7852,7 +9133,7 @@ def create_scheduled_deployment( tags = [source_enum.value, backend_enum.value] # Create deployment parameters with block support - parameters = { + parameters: dict[str, str | bool] = { "source_url": source_url, "source_type": source_enum.value, "storage_backend": backend_enum.value, @@ -7867,8 +9148,8 @@ def create_scheduled_deployment( # Create deployment # The flow decorator adds the to_deployment method at runtime - to_deployment = create_ingestion_flow.to_deployment - deployment = to_deployment( + flow_with_deployment = cast(FlowWithDeployment, create_ingestion_flow) + return flow_with_deployment.to_deployment( name=name, schedule=schedule, parameters=parameters, @@ -7876,8 +9157,6 @@ def create_scheduled_deployment( description=f"Scheduled ingestion from {source_url}", ) - return cast("RunnerDeployment", deployment) - def serve_deployments(deployments: list[RunnerDeployment]) -> None: """ @@ -7886,92 +9165,7 @@ def serve_deployments(deployments: list[RunnerDeployment]) -> None: Args: deployments: List of deployment configurations """ - serve(*deployments, limit=10) - - - -"""Base ingestor interface.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING - -from ..core.models import Document, IngestionJob - -if TYPE_CHECKING: - from ..storage.base import BaseStorage - - -class BaseIngestor(ABC): - """Abstract base class for all ingestors.""" - - @abstractmethod - def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]: - """ - Ingest data from a source. - - Args: - job: The ingestion job configuration - - Yields: - Documents from the source - """ - ... # pragma: no cover - - @abstractmethod - async def validate_source(self, source_url: str) -> bool: - """ - Validate if the source is accessible. - - Args: - source_url: URL or path to the source - - Returns: - True if source is valid and accessible - """ - pass # pragma: no cover - - @abstractmethod - async def estimate_size(self, source_url: str) -> int: - """ - Estimate the number of documents in the source. 
- - Args: - source_url: URL or path to the source - - Returns: - Estimated number of documents - """ - pass # pragma: no cover - - async def ingest_with_dedup( - self, - job: IngestionJob, - storage_client: BaseStorage, - *, - collection_name: str | None = None, - stale_after_days: int = 30, - ) -> AsyncGenerator[Document, None]: - """ - Ingest documents with duplicate detection (optional optimization). - - Default implementation falls back to regular ingestion. - Subclasses can override to provide optimized deduplication. - - Args: - job: The ingestion job configuration - storage_client: Storage client to check for existing documents - collection_name: Collection to check for duplicates - stale_after_days: Consider documents stale after this many days - - Yields: - Documents from the source (with deduplication if implemented) - """ - # Default implementation: fall back to regular ingestion - async for document in self.ingest(job): - yield document + prefect_serve(*deployments, limit=10) @@ -7983,7 +9177,7 @@ import re from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass from datetime import UTC, datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, cast from urllib.parse import urlparse from uuid import NAMESPACE_URL, UUID, uuid5 @@ -8005,9 +9199,70 @@ if TYPE_CHECKING: from ..storage.base import BaseStorage +class FirecrawlMetadata(Protocol): + """Protocol for Firecrawl metadata objects.""" + + title: str | None + description: str | None + author: str | None + language: str | None + sitemap_last_modified: str | None + sourceURL: str | None + keywords: str | list[str] | None + robots: str | None + ogTitle: str | None + ogDescription: str | None + ogUrl: str | None + ogImage: str | None + twitterCard: str | None + twitterSite: str | None + twitterCreator: str | None + favicon: str | None + statusCode: int | None + + +class FirecrawlResult(Protocol): + """Protocol for Firecrawl scrape result objects.""" + + metadata: FirecrawlMetadata | None + markdown: str | None + + +class FirecrawlMapLink(Protocol): + """Protocol for Firecrawl map link objects.""" + + url: str + + +class FirecrawlMapResult(Protocol): + """Protocol for Firecrawl map result objects.""" + + links: list[FirecrawlMapLink] | None + + +class AsyncFirecrawlSession(Protocol): + """Protocol for AsyncFirecrawl session objects.""" + + async def close(self) -> None: ... + + +class AsyncFirecrawlClient(Protocol): + """Protocol for AsyncFirecrawl client objects.""" + + _session: AsyncFirecrawlSession | None + + async def close(self) -> None: ... + + async def scrape(self, url: str, formats: list[str]) -> FirecrawlResult: ... + + async def map(self, url: str, limit: int | None = None) -> "FirecrawlMapResult": ... + + class FirecrawlError(IngestionError): """Base exception for Firecrawl-related errors.""" + status_code: int | None + def __init__(self, message: str, status_code: int | None = None) -> None: super().__init__(message) self.status_code = status_code @@ -8041,7 +9296,7 @@ async def retry_with_backoff( except Exception as e: if attempt == max_retries - 1: raise e - delay = 1.0 * (2**attempt) + delay: float = 1.0 * (2**attempt) logging.warning( f"Firecrawl operation failed (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {delay:.1f}s..." 
) @@ -8081,7 +9336,7 @@ class FirecrawlIngestor(BaseIngestor): """Ingestor for web and documentation sites using Firecrawl.""" config: FirecrawlConfig - client: AsyncFirecrawl + client: AsyncFirecrawlClient def __init__(self, config: FirecrawlConfig | None = None): """ @@ -8107,15 +9362,16 @@ class FirecrawlIngestor(BaseIngestor): "http://localhost" ): # Self-hosted instance - try with api_url if supported - self.client = AsyncFirecrawl( - api_key=api_key, api_url=str(settings.firecrawl_endpoint) + self.client = cast( + AsyncFirecrawlClient, + AsyncFirecrawl(api_key=api_key, api_url=str(settings.firecrawl_endpoint)), ) else: # Cloud instance - use standard initialization - self.client = AsyncFirecrawl(api_key=api_key) + self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key)) except Exception: # Fallback to standard initialization - self.client = AsyncFirecrawl(api_key=api_key) + self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key)) @override async def ingest(self, job: IngestionJob) -> AsyncGenerator[Document, None]: @@ -8172,9 +9428,7 @@ class FirecrawlIngestor(BaseIngestor): for check_url in site_map: document_id = str(self.compute_document_id(check_url)) exists = await storage_client.check_exists( - document_id, - collection_name=collection_name, - stale_after_days=stale_after_days + document_id, collection_name=collection_name, stale_after_days=stale_after_days ) if not exists: eligible_urls.append(check_url) @@ -8254,11 +9508,11 @@ class FirecrawlIngestor(BaseIngestor): """ try: # Use SDK v2 map endpoint following official pattern - result = await self.client.map(url=url, limit=self.config.limit) + result: FirecrawlMapResult = await self.client.map(url=url, limit=self.config.limit) - if result and getattr(result, "links", None): + if result and result.links: # Extract URLs from the result following official pattern - return [getattr(link, "url", str(link)) for link in result.links] + return [link.url for link in result.links] return [] except Exception as e: # If map fails (might not be available in all versions), fall back to single URL @@ -8301,43 +9555,57 @@ class FirecrawlIngestor(BaseIngestor): try: # Use SDK v2 scrape endpoint following official pattern with retry async def scrape_operation() -> FirecrawlPage | None: - result = await self.client.scrape(url, formats=self.config.formats) + result: FirecrawlResult = await self.client.scrape(url, formats=self.config.formats) # Extract data from the result following official response handling if result: # The SDK returns a ScrapeData object with typed metadata - metadata = getattr(result, "metadata", None) + metadata: FirecrawlMetadata | None = getattr(result, "metadata", None) # Extract basic metadata - title = getattr(metadata, "title", None) if metadata else None - description = getattr(metadata, "description", None) if metadata else None + title: str | None = getattr(metadata, "title", None) if metadata else None + description: str | None = ( + getattr(metadata, "description", None) if metadata else None + ) # Extract enhanced metadata if available - author = getattr(metadata, "author", None) if metadata else None - language = getattr(metadata, "language", None) if metadata else None - sitemap_last_modified = ( + author: str | None = getattr(metadata, "author", None) if metadata else None + language: str | None = getattr(metadata, "language", None) if metadata else None + sitemap_last_modified: str | None = ( getattr(metadata, "sitemap_last_modified", None) if metadata else None ) - 
source_url = getattr(metadata, "sourceURL", None) if metadata else None - keywords = getattr(metadata, "keywords", None) if metadata else None - robots = getattr(metadata, "robots", None) if metadata else None + source_url: str | None = ( + getattr(metadata, "sourceURL", None) if metadata else None + ) + keywords: str | list[str] | None = ( + getattr(metadata, "keywords", None) if metadata else None + ) + robots: str | None = getattr(metadata, "robots", None) if metadata else None # Open Graph metadata - og_title = getattr(metadata, "ogTitle", None) if metadata else None - og_description = getattr(metadata, "ogDescription", None) if metadata else None - og_url = getattr(metadata, "ogUrl", None) if metadata else None - og_image = getattr(metadata, "ogImage", None) if metadata else None + og_title: str | None = getattr(metadata, "ogTitle", None) if metadata else None + og_description: str | None = ( + getattr(metadata, "ogDescription", None) if metadata else None + ) + og_url: str | None = getattr(metadata, "ogUrl", None) if metadata else None + og_image: str | None = getattr(metadata, "ogImage", None) if metadata else None # Twitter metadata - twitter_card = getattr(metadata, "twitterCard", None) if metadata else None - twitter_site = getattr(metadata, "twitterSite", None) if metadata else None - twitter_creator = ( + twitter_card: str | None = ( + getattr(metadata, "twitterCard", None) if metadata else None + ) + twitter_site: str | None = ( + getattr(metadata, "twitterSite", None) if metadata else None + ) + twitter_creator: str | None = ( getattr(metadata, "twitterCreator", None) if metadata else None ) # Additional metadata - favicon = getattr(metadata, "favicon", None) if metadata else None - status_code = getattr(metadata, "statusCode", None) if metadata else None + favicon: str | None = getattr(metadata, "favicon", None) if metadata else None + status_code: int | None = ( + getattr(metadata, "statusCode", None) if metadata else None + ) return FirecrawlPage( url=url, @@ -8350,7 +9618,7 @@ class FirecrawlIngestor(BaseIngestor): source_url=source_url, keywords=keywords.split(",") if keywords and isinstance(keywords, str) - else keywords, + else (keywords if isinstance(keywords, list) else None), robots=robots, og_title=og_title, og_description=og_description, @@ -8376,11 +9644,11 @@ class FirecrawlIngestor(BaseIngestor): return uuid5(NAMESPACE_URL, source_url) @staticmethod - def _analyze_content_structure(content: str) -> dict[str, object]: + def _analyze_content_structure(content: str) -> dict[str, str | int | bool | list[str]]: """Analyze markdown content to extract structural information.""" # Extract heading hierarchy heading_pattern = r"^(#{1,6})\s+(.+)$" - headings = [] + headings: list[str] = [] for match in re.finditer(heading_pattern, content, re.MULTILINE): level = len(match.group(1)) text = match.group(2).strip() @@ -8395,7 +9663,8 @@ class FirecrawlIngestor(BaseIngestor): max_depth = 0 if headings: for heading in headings: - depth = (len(heading) - len(heading.lstrip())) // 2 + 1 + heading_str: str = str(heading) + depth = (len(heading_str) - len(heading_str.lstrip())) // 2 + 1 max_depth = max(max_depth, depth) return { @@ -8509,10 +9778,16 @@ class FirecrawlIngestor(BaseIngestor): "site_name": domain_info["site_name"], # Document structure "heading_hierarchy": ( - list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else [] + list(hierarchy_val) + if (hierarchy_val := 
structure_info.get("heading_hierarchy")) + and isinstance(hierarchy_val, (list, tuple)) + else [] ), "section_depth": ( - int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0 + int(depth_val) + if (depth_val := structure_info.get("section_depth")) + and isinstance(depth_val, (int, str)) + else 0 ), "has_code_blocks": bool(structure_info.get("has_code_blocks", False)), "has_images": bool(structure_info.get("has_images", False)), @@ -8547,7 +9822,11 @@ class FirecrawlIngestor(BaseIngestor): await self.client.close() except Exception as e: logging.debug(f"Error closing Firecrawl client: {e}") - elif hasattr(self.client, "_session") and hasattr(self.client._session, "close"): + elif ( + hasattr(self.client, "_session") + and self.client._session + and hasattr(self.client._session, "close") + ): try: await self.client._session.close() except Exception as e: @@ -8572,13 +9851,17 @@ class FirecrawlIngestor(BaseIngestor): import json from datetime import UTC, datetime -from typing import Protocol, TypedDict, cast +from typing import Final, Protocol, TypedDict, cast import httpx +from ..config import get_settings from ..core.exceptions import IngestionError from ..core.models import Document +JSON_CONTENT_TYPE: Final[str] = "application/json" +AUTHORIZATION_HEADER: Final[str] = "Authorization" + class HttpResponse(Protocol): """Protocol for HTTP response.""" @@ -8590,28 +9873,35 @@ class HttpResponse(Protocol): class AsyncHttpClient(Protocol): """Protocol for async HTTP client.""" - async def post( - self, - url: str, - *, - json: dict[str, object] | None = None - ) -> HttpResponse: ... + async def post(self, url: str, *, json: dict[str, object] | None = None) -> HttpResponse: ... async def aclose(self) -> None: ... + async def __aenter__(self) -> "AsyncHttpClient": ... + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object | None, + ) -> None: ... + class LlmResponse(TypedDict): """Type for LLM API response structure.""" + choices: list[dict[str, object]] class LlmChoice(TypedDict): """Type for individual choice in LLM response.""" + message: dict[str, object] class LlmMessage(TypedDict): """Type for message in LLM choice.""" + content: str @@ -8636,8 +9926,11 @@ class MetadataTagger: def __init__( self, - llm_endpoint: str = "http://llm.lab", - model: str = "fireworks/glm-4p5-air", + llm_endpoint: str | None = None, + model: str | None = None, + api_key: str | None = None, + *, + timeout: float | None = None, ): """ Initialize metadata tagger. 
@@ -8645,30 +9938,26 @@ class MetadataTagger: Args: llm_endpoint: LLM API endpoint model: Model to use for tagging + api_key: Explicit API key override + timeout: Optional request timeout override in seconds """ - self.endpoint = llm_endpoint.rstrip('/') - self.model = model + settings = get_settings() + endpoint_value = llm_endpoint or str(settings.llm_endpoint) + self.endpoint = endpoint_value.rstrip("/") + self.model = model or settings.metadata_model - # Get API key from environment - import os - from pathlib import Path + resolved_timeout = timeout if timeout is not None else float(settings.request_timeout) + resolved_api_key = api_key or settings.get_llm_api_key() or "" - from dotenv import load_dotenv - - # Load .env from the project root - env_path = Path(__file__).parent.parent.parent / ".env" - _ = load_dotenv(env_path) - - api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or "" - - headers = {"Content-Type": "application/json"} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" + headers: dict[str, str] = {"Content-Type": JSON_CONTENT_TYPE} + if resolved_api_key: + headers[AUTHORIZATION_HEADER] = f"Bearer {resolved_api_key}" # Create client with proper typing - httpx.AsyncClient implements AsyncHttpClient protocol - AsyncClientClass = getattr(httpx, "AsyncClient") - raw_client = AsyncClientClass(timeout=60.0, headers=headers) - self.client = cast(AsyncHttpClient, raw_client) + self.client = cast( + AsyncHttpClient, + httpx.AsyncClient(timeout=resolved_timeout, headers=headers), + ) async def tag_document( self, document: Document, custom_instructions: str | None = None @@ -8719,12 +10008,16 @@ class MetadataTagger: # Build a proper DocumentMetadata instance with only valid keys new_metadata: CoreDocumentMetadata = { "source_url": str(updated_metadata.get("source_url", "")), - "timestamp": ( - lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC) - )(updated_metadata.get("timestamp", datetime.now(UTC))), + "timestamp": (lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC))( + updated_metadata.get("timestamp", datetime.now(UTC)) + ), "content_type": str(updated_metadata.get("content_type", "text/plain")), - "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)), - "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)), + "word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)( + updated_metadata.get("word_count", 0) + ), + "char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)( + updated_metadata.get("char_count", 0) + ), } # Add optional fields if they exist @@ -8926,16 +10219,79 @@ Return a JSON object with the following structure: """Vectorizer utility for generating embeddings.""" +import asyncio from types import TracebackType -from typing import Self, cast +from typing import Final, NotRequired, Self, TypedDict import httpx -from typings import EmbeddingResponse - +from ..config import get_settings from ..core.exceptions import VectorizationError from ..core.models import StorageConfig, VectorConfig +JSON_CONTENT_TYPE: Final[str] = "application/json" +AUTHORIZATION_HEADER: Final[str] = "Authorization" + + +class EmbeddingData(TypedDict): + """Structure for embedding data from providers.""" + + embedding: list[float] + index: NotRequired[int] + object: NotRequired[str] + + +class EmbeddingResponse(TypedDict): + """Embedding response format for multiple providers.""" + + data: 
list[EmbeddingData] + model: NotRequired[str] + object: NotRequired[str] + usage: NotRequired[dict[str, int]] + # Alternative formats + embedding: NotRequired[list[float]] + vector: NotRequired[list[float]] + embeddings: NotRequired[list[list[float]]] + + +def _extract_embedding_from_response(response_data: dict[str, object]) -> list[float]: + """Extract embedding vector from provider response.""" + # OpenAI/Ollama format: {"data": [{"embedding": [...]}]} + if "data" in response_data: + data_list = response_data["data"] + if isinstance(data_list, list) and data_list: + first_item = data_list[0] + if isinstance(first_item, dict) and "embedding" in first_item: + embedding = first_item["embedding"] + if isinstance(embedding, list) and all( + isinstance(x, (int, float)) for x in embedding + ): + return [float(x) for x in embedding] + + # Direct embedding format: {"embedding": [...]} + if "embedding" in response_data: + embedding = response_data["embedding"] + if isinstance(embedding, list) and all(isinstance(x, (int, float)) for x in embedding): + return [float(x) for x in embedding] + + # Vector format: {"vector": [...]} + if "vector" in response_data: + vector = response_data["vector"] + if isinstance(vector, list) and all(isinstance(x, (int, float)) for x in vector): + return [float(x) for x in vector] + + # Embeddings array format: {"embeddings": [[...]]} + if "embeddings" in response_data: + embeddings = response_data["embeddings"] + if isinstance(embeddings, list) and embeddings: + first_embedding = embeddings[0] + if isinstance(first_embedding, list) and all( + isinstance(x, (int, float)) for x in first_embedding + ): + return [float(x) for x in first_embedding] + + raise VectorizationError("Unrecognized embedding response format") + class Vectorizer: """Handles text vectorization using LLM endpoints.""" @@ -8951,33 +10307,24 @@ class Vectorizer: Args: config: Configuration with embedding details """ + settings = get_settings() if isinstance(config, StorageConfig): - # Extract vector config from storage config - self.endpoint = "http://llm.lab" - self.model = "ollama/bge-m3" - self.dimension = 1024 + # Extract vector config from global settings when storage config is provided + self.endpoint = str(settings.llm_endpoint).rstrip("/") + self.model = settings.embedding_model + self.dimension = settings.embedding_dimension else: - self.endpoint = str(config.embedding_endpoint) + self.endpoint = str(config.embedding_endpoint).rstrip("/") self.model = config.model self.dimension = config.dimension - # Get API key from environment - import os - from pathlib import Path + resolved_api_key = settings.get_llm_api_key() or "" + headers: dict[str, str] = {"Content-Type": JSON_CONTENT_TYPE} + if resolved_api_key: + headers[AUTHORIZATION_HEADER] = f"Bearer {resolved_api_key}" - from dotenv import load_dotenv - - # Load .env from the project root - env_path = Path(__file__).parent.parent.parent / ".env" - _ = load_dotenv(env_path) - - api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or "" - - headers = {"Content-Type": "application/json"} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - self.client: httpx.AsyncClient = httpx.AsyncClient(timeout=60.0, headers=headers) + timeout_seconds = float(settings.request_timeout) + self.client = httpx.AsyncClient(timeout=timeout_seconds, headers=headers) async def vectorize(self, text: str) -> list[float]: """ @@ -9003,21 +10350,34 @@ class Vectorizer: async def vectorize_batch(self, texts: list[str]) -> list[list[float]]: """ - 
Generate embeddings for multiple texts. + Generate embeddings for multiple texts in parallel. Args: texts: List of texts to vectorize Returns: List of embedding vectors + + Raises: + VectorizationError: If any vectorization fails """ - vectors: list[list[float]] = [] - for text in texts: - vector = await self.vectorize(text) - vectors.append(vector) + if not texts: + return [] - return vectors + # Use semaphore to limit concurrent requests and prevent overwhelming the endpoint + semaphore = asyncio.Semaphore(20) + + async def vectorize_with_semaphore(text: str) -> list[float]: + async with semaphore: + return await self.vectorize(text) + + try: + # Execute all vectorization requests concurrently + vectors = await asyncio.gather(*[vectorize_with_semaphore(text) for text in texts]) + return list(vectors) + except Exception as e: + raise VectorizationError(f"Batch vectorization failed: {e}") from e async def _ollama_embed(self, text: str) -> list[float]: """ @@ -9043,23 +10403,12 @@ class Vectorizer: _ = response.raise_for_status() response_json = response.json() - # Response is expected to be dict[str, object] from our type stub + if not isinstance(response_json, dict): + raise VectorizationError("Invalid JSON response format") - response_data = cast(EmbeddingResponse, cast(object, response_json)) + # Extract embedding using type-safe helper + embedding = _extract_embedding_from_response(response_json) - # Parse OpenAI-compatible response format - embeddings_list = response_data.get("data", []) - if not embeddings_list: - raise VectorizationError("No embeddings returned") - - first_embedding = embeddings_list[0] - embedding_raw = first_embedding.get("embedding") - if not embedding_raw: - raise VectorizationError("Invalid embedding format") - - # Convert to float list and validate - embedding: list[float] = [] - embedding.extend(float(item) for item in embedding_raw) # Ensure correct dimension if len(embedding) != self.dimension: raise VectorizationError( @@ -9088,22 +10437,12 @@ class Vectorizer: _ = response.raise_for_status() response_json = response.json() - # Response is expected to be dict[str, object] from our type stub + if not isinstance(response_json, dict): + raise VectorizationError("Invalid JSON response format") - response_data = cast(EmbeddingResponse, cast(object, response_json)) + # Extract embedding using type-safe helper + embedding = _extract_embedding_from_response(response_json) - embeddings_list = response_data.get("data", []) - if not embeddings_list: - raise VectorizationError("No embeddings returned") - - first_embedding = embeddings_list[0] - embedding_raw = first_embedding.get("embedding") - if not embedding_raw: - raise VectorizationError("Invalid embedding format") - - # Convert to float list and validate - embedding: list[float] = [] - embedding.extend(float(item) for item in embedding_raw) # Ensure correct dimension if len(embedding) != self.dimension: raise VectorizationError( @@ -9116,6 +10455,14 @@ class Vectorizer: """Async context manager entry.""" return self + async def close(self) -> None: + """Close the HTTP client connection.""" + try: + await self.client.aclose() + except Exception: + # Already closed or connection lost + pass + async def __aexit__( self, exc_type: type[BaseException] | None, @@ -9123,14 +10470,13 @@ class Vectorizer: exc_tb: TracebackType | None, ) -> None: """Async context manager exit.""" - await self.client.aclose() + await self.close() """Main dashboard screen with collections overview.""" import logging -from datetime import 
datetime from typing import TYPE_CHECKING, Final from textual import work @@ -9335,7 +10681,11 @@ class CollectionOverviewScreen(Screen[None]): """Calculate basic metrics from collections.""" self.total_collections = len(self.collections) self.total_documents = sum(col["count"] for col in self.collections) - self.active_backends = sum([bool(self.weaviate), bool(self.openwebui), bool(self.r2r)]) + # Calculate active backends from storage manager if individual storages are None + if self.weaviate is None and self.openwebui is None and self.r2r is None: + self.active_backends = len(self.storage_manager.get_available_backends()) + else: + self.active_backends = sum([bool(self.weaviate), bool(self.openwebui), bool(self.r2r)]) def _update_metrics_cards(self) -> None: """Update the metrics cards display.""" @@ -9482,76 +10832,6 @@ class CollectionOverviewScreen(Screen[None]): self.is_loading = False loading_indicator.display = False - async def list_weaviate_collections(self) -> list[CollectionInfo]: - """List Weaviate collections with enhanced metadata.""" - if not self.weaviate: - return [] - - try: - overview = await self.weaviate.describe_collections() - collections: list[CollectionInfo] = [] - - for item in overview: - count_raw = item.get("count", 0) - count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0 - size_mb_raw = item.get("size_mb", 0.0) - size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0 - collections.append( - CollectionInfo( - name=str(item.get("name", "Unknown")), - type="weaviate", - count=count_val, - backend="🗄️ Weaviate", - status="✓ Active", - last_updated=datetime.now().strftime("%Y-%m-%d %H:%M"), - size_mb=size_mb_val, - ) - ) - - return collections - except Exception as e: - self.notify(f"Error listing Weaviate collections: {e}", severity="error", markup=False) - return [] - - async def list_openwebui_collections(self) -> list[CollectionInfo]: - """List OpenWebUI collections with enhanced metadata.""" - # Try to get OpenWebUI backend from storage manager if direct instance not available - openwebui_backend = self.openwebui - if not openwebui_backend: - backend = self.storage_manager.get_backend(StorageBackend.OPEN_WEBUI) - if not isinstance(backend, OpenWebUIStorage): - return [] - openwebui_backend = backend - if not openwebui_backend: - return [] - - try: - overview = await openwebui_backend.describe_collections() - collections: list[CollectionInfo] = [] - - for item in overview: - count_raw = item.get("count", 0) - count_val = int(count_raw) if isinstance(count_raw, (int, str)) else 0 - size_mb_raw = item.get("size_mb", 0.0) - size_mb_val = float(size_mb_raw) if isinstance(size_mb_raw, (int, float, str)) else 0.0 - collection_name = str(item.get("name", "Unknown")) - collections.append( - CollectionInfo( - name=collection_name, - type="openwebui", - count=count_val, - backend="🌐 OpenWebUI", - status="✓ Active", - last_updated=datetime.now().strftime("%Y-%m-%d %H:%M"), - size_mb=size_mb_val, - ) - ) - - return collections - except Exception as e: - self.notify(f"Error listing OpenWebUI collections: {e}", severity="error", markup=False) - return [] - async def update_collections_table(self) -> None: """Update the collections table with enhanced formatting.""" table = self.query_one("#collections_table", EnhancedDataTable) @@ -9825,7 +11105,7 @@ Enjoy the enhanced interface! 
🎉 from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING from textual.app import ComposeResult from textual.binding import Binding @@ -9837,6 +11117,7 @@ from typing_extensions import override from ..models import CollectionInfo if TYPE_CHECKING: + from ..app import CollectionManagementApp from .dashboard import CollectionOverviewScreen from .documents import DocumentManagementScreen @@ -9847,7 +11128,12 @@ class ConfirmDeleteScreen(Screen[None]): collection: CollectionInfo parent_screen: CollectionOverviewScreen - BINDINGS: list[Binding] = [ + @property + def app(self) -> CollectionManagementApp: # type: ignore[override] + """Return the typed app instance.""" + return super().app # type: ignore[return-value] + + BINDINGS = [ Binding("escape", "app.pop_screen", "Cancel"), Binding("y", "confirm_delete", "Yes"), Binding("n", "app.pop_screen", "No"), @@ -9898,7 +11184,10 @@ class ConfirmDeleteScreen(Screen[None]): try: if self.collection["type"] == "weaviate" and self.parent_screen.weaviate: # Delete Weaviate collection - if self.parent_screen.weaviate.client and self.parent_screen.weaviate.client.collections: + if ( + self.parent_screen.weaviate.client + and self.parent_screen.weaviate.client.collections + ): self.parent_screen.weaviate.client.collections.delete(self.collection["name"]) self.notify( f"Deleted Weaviate collection: {self.collection['name']}", @@ -9916,7 +11205,7 @@ class ConfirmDeleteScreen(Screen[None]): return # Check if the storage backend supports collection deletion - if not hasattr(storage_backend, 'delete_collection'): + if not hasattr(storage_backend, "delete_collection"): self.notify( f"❌ Collection deletion not supported for {self.collection['type']} backend", severity="error", @@ -9929,10 +11218,13 @@ class ConfirmDeleteScreen(Screen[None]): collection_name = str(self.collection["name"]) collection_type = str(self.collection["type"]) - self.notify(f"Deleting {collection_type} collection: {collection_name}...", severity="information") + self.notify( + f"Deleting {collection_type} collection: {collection_name}...", + severity="information", + ) # Use the standard delete_collection method for all backends - if hasattr(storage_backend, 'delete_collection'): + if hasattr(storage_backend, "delete_collection"): success = await storage_backend.delete_collection(collection_name) else: self.notify("❌ Backend does not support collection deletion", severity="error") @@ -9954,12 +11246,15 @@ class ConfirmDeleteScreen(Screen[None]): return # Refresh parent screen after a short delay to ensure deletion is processed - self.call_later(lambda _: self.parent_screen.refresh_collections(), 0.5) # 500ms delay + self.call_later(self._refresh_parent_collections, 0.5) # 500ms delay self.app.pop_screen() except Exception as e: self.notify(f"Failed to delete collection: {e}", severity="error", markup=False) + def _refresh_parent_collections(self) -> None: + """Helper method to refresh parent collections.""" + self.parent_screen.refresh_collections() class ConfirmDocumentDeleteScreen(Screen[None]): @@ -9967,9 +11262,14 @@ class ConfirmDocumentDeleteScreen(Screen[None]): doc_ids: list[str] collection: CollectionInfo - parent_screen: "DocumentManagementScreen" + parent_screen: DocumentManagementScreen - BINDINGS: list[Binding] = [ + @property + def app(self) -> CollectionManagementApp: # type: ignore[override] + """Return the typed app instance.""" + return super().app # type: ignore[return-value] + + BINDINGS = [ 
Binding("escape", "app.pop_screen", "Cancel"), Binding("y", "confirm_delete", "Yes"), Binding("n", "app.pop_screen", "No"), @@ -9980,7 +11280,7 @@ class ConfirmDocumentDeleteScreen(Screen[None]): self, doc_ids: list[str], collection: CollectionInfo, - parent_screen: "DocumentManagementScreen", + parent_screen: DocumentManagementScreen, ): super().__init__() self.doc_ids = doc_ids @@ -10030,11 +11330,11 @@ class ConfirmDocumentDeleteScreen(Screen[None]): try: results: dict[str, bool] = {} - if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage: + if hasattr(self.parent_screen, "storage") and self.parent_screen.storage: # Delete documents via storage # The storage should have delete_documents method for weaviate storage = self.parent_screen.storage - if hasattr(storage, 'delete_documents'): + if hasattr(storage, "delete_documents"): results = await storage.delete_documents( self.doc_ids, collection_name=self.collection["name"], @@ -10066,7 +11366,12 @@ class LogViewerScreen(ModalScreen[None]): _log_widget: RichLog | None _log_file: Path | None - BINDINGS: list[Binding] = [ + @property + def app(self) -> CollectionManagementApp: # type: ignore[override] + """Return the typed app instance.""" + return super().app # type: ignore[return-value] + + BINDINGS = [ Binding("escape", "close", "Close"), Binding("ctrl+l", "close", "Close"), Binding("s", "show_path", "Log File"), @@ -10082,7 +11387,9 @@ class LogViewerScreen(ModalScreen[None]): yield Header(show_clock=True) yield Container( Static("📜 Live Application Logs", classes="title"), - Static("Logs update in real time. Press S to reveal the log file path.", classes="subtitle"), + Static( + "Logs update in real time. Press S to reveal the log file path.", classes="subtitle" + ), RichLog(id="log_stream", classes="log-stream", wrap=True, highlight=False), Static("", id="log_file_path", classes="subtitle"), classes="main_container log-viewer-container", @@ -10093,14 +11400,14 @@ class LogViewerScreen(ModalScreen[None]): """Attach this viewer to the parent application once mounted.""" self._log_widget = self.query_one(RichLog) - if hasattr(self.app, 'attach_log_viewer'): - self.app.attach_log_viewer(self) + if hasattr(self.app, "attach_log_viewer"): + self.app.attach_log_viewer(self) # type: ignore[arg-type] def on_unmount(self) -> None: """Detach from the parent application when closed.""" - if hasattr(self.app, 'detach_log_viewer'): - self.app.detach_log_viewer(self) + if hasattr(self.app, "detach_log_viewer"): + self.app.detach_log_viewer(self) # type: ignore[arg-type] def _get_log_widget(self) -> RichLog: if self._log_widget is None: @@ -10142,14 +11449,16 @@ class LogViewerScreen(ModalScreen[None]): if self._log_file is None: self.notify("File logging is disabled for this session.", severity="warning") else: - self.notify(f"Log file available at: {self._log_file}", severity="information", markup=False) + self.notify( + f"Log file available at: {self._log_file}", severity="information", markup=False + ) """Application settings and configuration.""" from functools import lru_cache -from typing import Annotated, ClassVar, Literal +from typing import Annotated, ClassVar, Final, Literal from prefect.variables import Variable from pydantic import Field, HttpUrl, model_validator @@ -10168,6 +11477,8 @@ class Settings(BaseSettings): # API Keys firecrawl_api_key: str | None = None + llm_api_key: str | None = None + openai_api_key: str | None = None openwebui_api_key: str | None = None weaviate_api_key: str | None = None r2r_api_key: str 
| None = None @@ -10181,6 +11492,7 @@ class Settings(BaseSettings): # Model Configuration embedding_model: str = "ollama/bge-m3:latest" + metadata_model: str = "fireworks/glm-4p5-air" embedding_dimension: int = 1024 # Ingestion Settings @@ -10248,14 +11560,20 @@ class Settings(BaseSettings): Returns: API key or None """ - service_map = { + service_map: Final[dict[str, str | None]] = { "firecrawl": self.firecrawl_api_key, "openwebui": self.openwebui_api_key, "weaviate": self.weaviate_api_key, "r2r": self.r2r_api_key, + "llm": self.get_llm_api_key(), + "openai": self.openai_api_key, } return service_map.get(service) + def get_llm_api_key(self) -> str | None: + """Get API key for LLM services with OpenAI fallback.""" + return self.llm_api_key or (self.openai_api_key or None) + @model_validator(mode="after") def validate_backend_configuration(self) -> "Settings": """Validate that required configuration is present for the default backend.""" @@ -10278,10 +11596,11 @@ class Settings(BaseSettings): key_name, key_value = required_keys[backend] if not key_value: import warnings + warnings.warn( f"{key_name} not set - authentication may fail for {backend} backend", UserWarning, - stacklevel=2 + stacklevel=2, ) return self @@ -10304,16 +11623,24 @@ class PrefectVariableConfig: def __init__(self) -> None: self._settings: Settings = get_settings() self._variable_names: list[str] = [ - "default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages", - "default_storage_backend", "default_collection_prefix", "max_concurrent_tasks", - "request_timeout", "default_schedule_interval" + "default_batch_size", + "max_file_size", + "max_crawl_depth", + "max_crawl_pages", + "default_storage_backend", + "default_collection_prefix", + "max_concurrent_tasks", + "request_timeout", + "default_schedule_interval", ] def _get_fallback_value(self, name: str, default_value: object = None) -> object: """Get fallback value from settings or default.""" return default_value or getattr(self._settings, name, default_value) - def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None: + def get_with_fallback( + self, name: str, default_value: str | int | float | None = None + ) -> str | int | float | None: """Get variable value with fallback synchronously.""" fallback = self._get_fallback_value(name, default_value) # Ensure fallback is a type that Variable expects @@ -10334,7 +11661,9 @@ class PrefectVariableConfig: return fallback return str(fallback) if fallback is not None else None - async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None: + async def get_with_fallback_async( + self, name: str, default_value: str | int | float | None = None + ) -> str | int | float | None: """Get variable value with fallback asynchronously.""" fallback = self._get_fallback_value(name, default_value) variable_fallback = str(fallback) if fallback is not None else None @@ -10383,6 +11712,41 @@ from uuid import UUID, uuid4 from prefect.blocks.core import Block from pydantic import BaseModel, Field, HttpUrl, SecretStr +from ..config import get_settings + + +def _default_embedding_model() -> str: + return str(get_settings().embedding_model) + + +def _default_embedding_endpoint() -> HttpUrl: + endpoint = get_settings().llm_endpoint + return endpoint if isinstance(endpoint, HttpUrl) else HttpUrl(str(endpoint)) + + +def _default_embedding_dimension() -> int: + return int(get_settings().embedding_dimension) + + +def 
_default_batch_size() -> int: + return int(get_settings().default_batch_size) + + +def _default_collection_name() -> str: + return str(get_settings().default_collection_prefix) + + +def _default_max_crawl_depth() -> int: + return int(get_settings().max_crawl_depth) + + +def _default_max_crawl_pages() -> int: + return int(get_settings().max_crawl_pages) + + +def _default_max_file_size() -> int: + return int(get_settings().max_file_size) + class IngestionStatus(str, Enum): """Status of an ingestion job.""" @@ -10414,36 +11778,39 @@ class IngestionSource(str, Enum): class VectorConfig(BaseModel): """Configuration for vectorization.""" - model: str = Field(default="ollama/bge-m3:latest") - embedding_endpoint: HttpUrl = Field(default=HttpUrl("http://llm.lab")) - dimension: int = Field(default=1024) - batch_size: Annotated[int, Field(gt=0, le=1000)] = 100 + model: str = Field(default_factory=_default_embedding_model) + embedding_endpoint: HttpUrl = Field(default_factory=_default_embedding_endpoint) + dimension: int = Field(default_factory=_default_embedding_dimension) + batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size) class StorageConfig(Block): """Configuration for storage backend.""" - _block_type_name: ClassVar[str] = "Storage Configuration" - _block_type_slug: ClassVar[str] = "storage-config" - _description: ClassVar[str] = "Configures storage backend connections and settings for document ingestion" + _block_type_name: ClassVar[str | None] = "Storage Configuration" + _block_type_slug: ClassVar[str | None] = "storage-config" + _description: ClassVar[str | None] = ( + "Configures storage backend connections and settings for document ingestion" + ) backend: StorageBackend endpoint: HttpUrl api_key: SecretStr | None = Field(default=None) - collection_name: str = Field(default="documents") - batch_size: Annotated[int, Field(gt=0, le=1000)] = 100 + collection_name: str = Field(default_factory=_default_collection_name) + batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size) + grpc_port: int | None = Field(default=None, description="gRPC port for Weaviate connections") class FirecrawlConfig(Block): """Configuration for Firecrawl ingestion (operational parameters only).""" - _block_type_name: ClassVar[str] = "Firecrawl Configuration" - _block_type_slug: ClassVar[str] = "firecrawl-config" - _description: ClassVar[str] = "Configures Firecrawl web scraping and crawling parameters" + _block_type_name: ClassVar[str | None] = "Firecrawl Configuration" + _block_type_slug: ClassVar[str | None] = "firecrawl-config" + _description: ClassVar[str | None] = "Configures Firecrawl web scraping and crawling parameters" formats: list[str] = Field(default_factory=lambda: ["markdown", "html"]) - max_depth: Annotated[int, Field(ge=1, le=20)] = 5 - limit: Annotated[int, Field(ge=1, le=1000)] = 100 + max_depth: Annotated[int, Field(ge=1, le=20)] = Field(default_factory=_default_max_crawl_depth) + limit: Annotated[int, Field(ge=1, le=1000)] = Field(default_factory=_default_max_crawl_pages) only_main_content: bool = Field(default=True) include_subdomains: bool = Field(default=False) @@ -10451,9 +11818,11 @@ class FirecrawlConfig(Block): class RepomixConfig(Block): """Configuration for Repomix ingestion.""" - _block_type_name: ClassVar[str] = "Repomix Configuration" - _block_type_slug: ClassVar[str] = "repomix-config" - _description: ClassVar[str] = "Configures repository ingestion patterns and file processing settings" + _block_type_name: 
ClassVar[str | None] = "Repomix Configuration" + _block_type_slug: ClassVar[str | None] = "repomix-config" + _description: ClassVar[str | None] = ( + "Configures repository ingestion patterns and file processing settings" + ) include_patterns: list[str] = Field( default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"] @@ -10461,16 +11830,18 @@ class RepomixConfig(Block): exclude_patterns: list[str] = Field( default_factory=lambda: ["**/node_modules/**", "**/__pycache__/**", "**/.git/**"] ) - max_file_size: int = Field(default=1_000_000) # 1MB + max_file_size: int = Field(default_factory=_default_max_file_size) # 1MB respect_gitignore: bool = Field(default=True) class R2RConfig(Block): """Configuration for R2R ingestion.""" - _block_type_name: ClassVar[str] = "R2R Configuration" - _block_type_slug: ClassVar[str] = "r2r-config" - _description: ClassVar[str] = "Configures R2R-specific ingestion settings including chunking and graph enrichment" + _block_type_name: ClassVar[str | None] = "R2R Configuration" + _block_type_slug: ClassVar[str | None] = "r2r-config" + _description: ClassVar[str | None] = ( + "Configures R2R-specific ingestion settings including chunking and graph enrichment" + ) chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000 chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200 @@ -10480,6 +11851,7 @@ class R2RConfig(Block): class DocumentMetadataRequired(TypedDict): """Required metadata fields for a document.""" + source_url: str timestamp: datetime content_type: str @@ -10543,7 +11915,7 @@ class Document(BaseModel): vector: list[float] | None = Field(default=None) score: float | None = Field(default=None) source: IngestionSource - collection: str = Field(default="documents") + collection: str = Field(default_factory=_default_collection_name) class IngestionJob(BaseModel): @@ -10572,734 +11944,6 @@ class IngestionResult(BaseModel): error_messages: list[str] = Field(default_factory=list) - -"""Prefect flow for ingestion pipeline.""" - -from __future__ import annotations - -from collections.abc import Callable -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Literal, TypeAlias, assert_never, cast - -from prefect import flow, get_run_logger, task -from prefect.blocks.core import Block -from prefect.variables import Variable -from pydantic import SecretStr - -from ..config.settings import Settings -from ..core.exceptions import IngestionError -from ..core.models import ( - Document, - FirecrawlConfig, - IngestionJob, - IngestionResult, - IngestionSource, - IngestionStatus, - RepomixConfig, - StorageBackend, - StorageConfig, -) -from ..ingestors import BaseIngestor, FirecrawlIngestor, FirecrawlPage, RepomixIngestor -from ..storage import OpenWebUIStorage, WeaviateStorage -from ..storage import R2RStorage as RuntimeR2RStorage -from ..storage.base import BaseStorage -from ..utils.metadata_tagger import MetadataTagger - -SourceTypeLiteral = Literal["web", "repository", "documentation"] -StorageBackendLiteral = Literal["weaviate", "open_webui", "r2r"] -SourceTypeLike: TypeAlias = IngestionSource | SourceTypeLiteral -StorageBackendLike: TypeAlias = StorageBackend | StorageBackendLiteral - - -def _safe_cache_key(prefix: str, params: dict[str, object], key: str) -> str: - """Create a type-safe cache key from task parameters.""" - value = params.get(key, "") - return f"{prefix}_{hash(str(value))}" - - -if TYPE_CHECKING: - from ..storage.r2r.storage import R2RStorage as R2RStorageType -else: - R2RStorageType = BaseStorage - - 
-@task(name="validate_source", retries=2, retry_delay_seconds=10, tags=["validation"]) -async def validate_source_task(source_url: str, source_type: IngestionSource) -> bool: - """ - Validate that a source is accessible. - - Args: - source_url: URL or path to source - source_type: Type of source - - Returns: - True if valid - """ - if source_type == IngestionSource.WEB: - ingestor = FirecrawlIngestor() - elif source_type == IngestionSource.REPOSITORY: - ingestor = RepomixIngestor() - else: - raise ValueError(f"Unsupported source type: {source_type}") - - result = await ingestor.validate_source(source_url) - return bool(result) - - -@task(name="initialize_storage", retries=3, retry_delay_seconds=5, tags=["storage"]) -async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage: - """ - Initialize storage backend. - - Args: - config: Storage configuration block or block name - - Returns: - Initialized storage adapter - """ - # Load block if string provided - if isinstance(config, str): - # Use Block.aload with type slug for better type inference - loaded_block = await Block.aload(f"storage-config/{config}") - config = cast(StorageConfig, loaded_block) - - if config.backend == StorageBackend.WEAVIATE: - storage = WeaviateStorage(config) - elif config.backend == StorageBackend.OPEN_WEBUI: - storage = OpenWebUIStorage(config) - elif config.backend == StorageBackend.R2R: - if RuntimeR2RStorage is None: - raise ValueError("R2R storage not available. Check dependencies.") - storage = RuntimeR2RStorage(config) - else: - raise ValueError(f"Unsupported backend: {config.backend}") - - await storage.initialize() - return storage - - -@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"], - cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url")) -async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]: - """Map a site using Firecrawl and return discovered URLs.""" - # Load block if string provided - if isinstance(config, str): - # Use Block.aload with type slug for better type inference - loaded_block = await Block.aload(f"firecrawl-config/{config}") - config = cast(FirecrawlConfig, loaded_block) - - ingestor = FirecrawlIngestor(config) - mapped = await ingestor.map_site(source_url) - return mapped or [source_url] - - -@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"], - cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list -async def filter_existing_documents_task( - urls: list[str], - storage_client: BaseStorage, - stale_after_days: int = 30, - *, - collection_name: str | None = None, -) -> list[str]: - """Filter URLs to only those that need scraping (missing or stale in storage).""" - logger = get_run_logger() - eligible: list[str] = [] - - for url in urls: - document_id = str(FirecrawlIngestor.compute_document_id(url)) - exists = await storage_client.check_exists( - document_id, - collection_name=collection_name, - stale_after_days=stale_after_days - ) - - if not exists: - eligible.append(url) - - skipped = len(urls) - len(eligible) - if skipped > 0: - logger.info("Skipping %s up-to-date documents in %s", skipped, storage_client.display_name) - - return eligible - - -@task( - name="scrape_firecrawl_batch", retries=2, retry_delay_seconds=20, tags=["firecrawl", "scrape"] -) -async def scrape_firecrawl_batch_task( - batch_urls: list[str], config: FirecrawlConfig -) -> list[FirecrawlPage]: - """Scrape a 
batch of URLs via Firecrawl.""" - ingestor = FirecrawlIngestor(config) - result: list[FirecrawlPage] = await ingestor.scrape_pages(batch_urls) - return result - - -@task(name="annotate_firecrawl_metadata", retries=1, retry_delay_seconds=10, tags=["metadata"]) -async def annotate_firecrawl_metadata_task( - pages: list[FirecrawlPage], job: IngestionJob -) -> list[Document]: - """Annotate scraped pages with standardized metadata.""" - if not pages: - return [] - - ingestor = FirecrawlIngestor() - documents = [ingestor.create_document(page, job) for page in pages] - - try: - from ..config import get_settings - - settings = get_settings() - async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger: - tagged_documents: list[Document] = await tagger.tag_batch(documents) - return tagged_documents - except IngestionError as exc: # pragma: no cover - logging side effect - logger = get_run_logger() - logger.warning("Metadata tagging failed: %s", exc) - return documents - except Exception as exc: # pragma: no cover - defensive - logger = get_run_logger() - logger.warning("Metadata tagging unavailable, using base metadata: %s", exc) - return documents - - -@task(name="upsert_r2r_documents", retries=2, retry_delay_seconds=20, tags=["storage", "r2r"]) -async def upsert_r2r_documents_task( - storage_client: R2RStorageType, - documents: list[Document], - collection_name: str | None, -) -> tuple[int, int]: - """Upsert documents into R2R storage.""" - if not documents: - return 0, 0 - - stored_ids: list[str] = await storage_client.store_batch( - documents, collection_name=collection_name - ) - processed = len(stored_ids) - failed = len(documents) - processed - - if failed: - logger = get_run_logger() - logger.warning("Failed to upsert %s documents to R2R", failed) - - return processed, failed - - -@task(name="ingest_documents", retries=2, retry_delay_seconds=30, tags=["ingestion"]) -async def ingest_documents_task( - job: IngestionJob, - collection_name: str | None = None, - batch_size: int | None = None, - storage_client: BaseStorage | None = None, - storage_block_name: str | None = None, - ingestor_config_block_name: str | None = None, - progress_callback: Callable[[int, str], None] | None = None, -) -> tuple[int, int]: - """ - Ingest documents from source with optional pre-initialized storage client. 
- - Args: - job: Ingestion job configuration - collection_name: Target collection name - batch_size: Number of documents per batch (uses Variable if None) - storage_client: Optional pre-initialized storage client - storage_block_name: Optional storage block name to load - ingestor_config_block_name: Optional ingestor config block name to load - progress_callback: Optional callback for progress updates - - Returns: - Tuple of (processed_count, failed_count) - """ - if progress_callback: - progress_callback(35, "Creating ingestor and storage clients...") - - # Use Variable for batch size if not provided - if batch_size is None: - try: - batch_size_var = await Variable.aget("default_batch_size", default="50") - # Convert Variable result to int, handling various types - if isinstance(batch_size_var, int): - batch_size = batch_size_var - elif isinstance(batch_size_var, (str, float)): - batch_size = int(float(str(batch_size_var))) - else: - batch_size = 50 - except Exception: - batch_size = 50 - - ingestor = await _create_ingestor(job, ingestor_config_block_name) - storage = storage_client or await _create_storage(job, collection_name, storage_block_name) - - if progress_callback: - progress_callback(40, "Starting document processing...") - - return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback) - - -async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor: - """Create appropriate ingestor based on job source type.""" - if job.source_type == IngestionSource.WEB: - if config_block_name: - # Use Block.aload with type slug for better type inference - loaded_block = await Block.aload(f"firecrawl-config/{config_block_name}") - config = cast(FirecrawlConfig, loaded_block) - else: - # Fallback to default configuration - config = FirecrawlConfig() - return FirecrawlIngestor(config) - elif job.source_type == IngestionSource.REPOSITORY: - if config_block_name: - # Use Block.aload with type slug for better type inference - loaded_block = await Block.aload(f"repomix-config/{config_block_name}") - config = cast(RepomixConfig, loaded_block) - else: - # Fallback to default configuration - config = RepomixConfig() - return RepomixIngestor(config) - else: - raise ValueError(f"Unsupported source: {job.source_type}") - - -async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage: - """Create and initialize storage client.""" - if collection_name is None: - # Use variable for default collection prefix - prefix = await Variable.aget("default_collection_prefix", default="docs") - collection_name = f"{prefix}_{job.source_type.value}" - - if storage_block_name: - # Load storage config from block - loaded_block = await Block.aload(f"storage-config/{storage_block_name}") - storage_config = cast(StorageConfig, loaded_block) - # Override collection name if provided - storage_config.collection_name = collection_name - else: - # Fallback to building config from settings - from ..config import get_settings - settings = get_settings() - storage_config = _build_storage_config(job, settings, collection_name) - - storage = _instantiate_storage(job.storage_backend, storage_config) - await storage.initialize() - return storage - - -def _build_storage_config( - job: IngestionJob, settings: Settings, collection_name: str -) -> StorageConfig: - """Build storage configuration from job and settings.""" - storage_endpoints = { - StorageBackend.WEAVIATE: 
settings.weaviate_endpoint, - StorageBackend.OPEN_WEBUI: settings.openwebui_endpoint, - StorageBackend.R2R: settings.get_storage_endpoint("r2r"), - } - storage_api_keys: dict[StorageBackend, str | None] = { - StorageBackend.WEAVIATE: settings.get_api_key("weaviate"), - StorageBackend.OPEN_WEBUI: settings.get_api_key("openwebui"), - StorageBackend.R2R: None, # R2R is self-hosted, no API key needed - } - - api_key_raw: str | None = storage_api_keys[job.storage_backend] - api_key: SecretStr | None = SecretStr(api_key_raw) if api_key_raw is not None else None - - return StorageConfig( - backend=job.storage_backend, - endpoint=storage_endpoints[job.storage_backend], - api_key=api_key, - collection_name=collection_name, - ) - - -def _instantiate_storage(backend: StorageBackend, config: StorageConfig) -> BaseStorage: - """Instantiate storage based on backend type.""" - if backend == StorageBackend.WEAVIATE: - return WeaviateStorage(config) - elif backend == StorageBackend.OPEN_WEBUI: - return OpenWebUIStorage(config) - elif backend == StorageBackend.R2R: - if RuntimeR2RStorage is None: - raise ValueError("R2R storage not available. Check dependencies.") - return RuntimeR2RStorage(config) - - assert_never(backend) - - -def _chunk_urls(urls: list[str], chunk_size: int) -> list[list[str]]: - """Group URLs into fixed-size chunks for batch processing.""" - - if chunk_size <= 0: - raise ValueError("chunk_size must be greater than zero") - - return [urls[i : i + chunk_size] for i in range(0, len(urls), chunk_size)] - - -def _deduplicate_urls(urls: list[str]) -> list[str]: - """Return the URLs with order preserved and duplicates removed.""" - - seen: set[str] = set() - unique: list[str] = [] - for url in urls: - if url not in seen: - seen.add(url) - unique.append(url) - return unique - - -async def _process_documents( - ingestor: BaseIngestor, - storage: BaseStorage, - job: IngestionJob, - batch_size: int, - collection_name: str | None, - progress_callback: Callable[[int, str], None] | None = None, -) -> tuple[int, int]: - """Process documents in batches.""" - processed = 0 - failed = 0 - batch: list[Document] = [] - total_documents = 0 - batch_count = 0 - - if progress_callback: - progress_callback(45, "Ingesting documents from source...") - - # Use smart ingestion with deduplication if storage supports it - if hasattr(storage, 'check_exists'): - try: - # Try to use the smart ingestion method - document_generator = ingestor.ingest_with_dedup( - job, storage, collection_name=collection_name - ) - except Exception: - # Fall back to regular ingestion if smart method fails - document_generator = ingestor.ingest(job) - else: - document_generator = ingestor.ingest(job) - - async for document in document_generator: - batch.append(document) - total_documents += 1 - - if len(batch) >= batch_size: - batch_count += 1 - if progress_callback: - progress_callback( - 45 + min(35, (batch_count * 10)), - f"Processing batch {batch_count} ({total_documents} documents so far)..." 
- ) - - batch_processed, batch_failed = await _store_batch(storage, batch, collection_name) - processed += batch_processed - failed += batch_failed - batch = [] - - # Process remaining batch - if batch: - batch_count += 1 - if progress_callback: - progress_callback(80, f"Processing final batch ({total_documents} total documents)...") - - batch_processed, batch_failed = await _store_batch(storage, batch, collection_name) - processed += batch_processed - failed += batch_failed - - if progress_callback: - progress_callback(85, f"Completed processing {total_documents} documents") - - return processed, failed - - -async def _store_batch( - storage: BaseStorage, - batch: list[Document], - collection_name: str | None, -) -> tuple[int, int]: - """Store a batch of documents and return processed/failed counts.""" - try: - # Apply metadata tagging for backends that benefit from it - processed_batch = batch - if hasattr(storage, "config") and storage.config.backend in ( - StorageBackend.R2R, - StorageBackend.WEAVIATE, - ): - try: - from ..config import get_settings - - settings = get_settings() - async with MetadataTagger(llm_endpoint=str(settings.llm_endpoint)) as tagger: - processed_batch = await tagger.tag_batch(batch) - except Exception as exc: - print(f"Metadata tagging failed, using original documents: {exc}") - processed_batch = batch - - stored_ids = await storage.store_batch(processed_batch, collection_name=collection_name) - processed_count = len(stored_ids) - failed_count = len(processed_batch) - processed_count - - batch_type = ( - "final" if len(processed_batch) < 50 else "" - ) # Assume standard batch size is 50 - print(f"Successfully stored {processed_count} documents in {batch_type} batch".strip()) - - return processed_count, failed_count - except Exception as e: - batch_type = "Final" if len(batch) < 50 else "Batch" - print(f"{batch_type} storage failed: {e}") - return 0, len(batch) - - -@flow( - name="firecrawl_to_r2r", - description="Ingest Firecrawl pages into R2R with metadata annotation", - persist_result=False, - log_prints=True, -) -async def firecrawl_to_r2r_flow( - job: IngestionJob, collection_name: str | None = None, progress_callback: Callable[[int, str], None] | None = None -) -> tuple[int, int]: - """Specialized flow for Firecrawl ingestion into R2R.""" - logger = get_run_logger() - from ..config import get_settings - - if progress_callback: - progress_callback(35, "Initializing Firecrawl and R2R storage...") - - settings = get_settings() - firecrawl_config = FirecrawlConfig() - resolved_collection = collection_name or f"docs_{job.source_type.value}" - - storage_config = _build_storage_config(job, settings, resolved_collection) - storage_client = await initialize_storage_task(storage_config) - - if RuntimeR2RStorage is None or not isinstance(storage_client, RuntimeR2RStorage): - raise IngestionError("Firecrawl to R2R flow requires an R2R storage backend") - - r2r_storage = cast("R2RStorageType", storage_client) - - if progress_callback: - progress_callback(45, "Checking for existing content before mapping...") - - # Smart mapping: try single URL first to avoid expensive map operation - base_url = str(job.source_url) - single_url_id = str(FirecrawlIngestor.compute_document_id(base_url)) - base_exists = await r2r_storage.check_exists( - single_url_id, collection_name=resolved_collection, stale_after_days=30 - ) - - if base_exists: - # Check if this is a recent single-page update - logger.info("Base URL %s exists and is fresh, skipping expensive mapping", base_url) - if 
progress_callback: - progress_callback(100, "Content is up to date, no processing needed") - return 0, 0 - - if progress_callback: - progress_callback(50, "Discovering pages with Firecrawl...") - - discovered_urls = await map_firecrawl_site_task(base_url, firecrawl_config) - unique_urls = _deduplicate_urls(discovered_urls) - logger.info("Discovered %s unique URLs from Firecrawl map", len(unique_urls)) - - if progress_callback: - progress_callback(60, f"Found {len(unique_urls)} pages, filtering existing content...") - - eligible_urls = await filter_existing_documents_task( - unique_urls, r2r_storage, collection_name=resolved_collection - ) - - if not eligible_urls: - logger.info("All Firecrawl pages are up to date for %s", job.source_url) - if progress_callback: - progress_callback(100, "All pages are up to date, no processing needed") - return 0, 0 - - if progress_callback: - progress_callback(70, f"Scraping {len(eligible_urls)} new/updated pages...") - - batch_size = min(settings.default_batch_size, firecrawl_config.limit) - url_batches = _chunk_urls(eligible_urls, batch_size) - logger.info("Scraping %s batches of Firecrawl pages", len(url_batches)) - - # Use asyncio.gather for concurrent scraping - import asyncio - scrape_tasks = [ - scrape_firecrawl_batch_task(batch, firecrawl_config) - for batch in url_batches - ] - batch_results = await asyncio.gather(*scrape_tasks) - - scraped_pages: list[FirecrawlPage] = [] - for batch_pages in batch_results: - scraped_pages.extend(batch_pages) - - if progress_callback: - progress_callback(80, f"Processing {len(scraped_pages)} scraped pages...") - - documents = await annotate_firecrawl_metadata_task(scraped_pages, job) - - if not documents: - logger.warning("No documents produced after scraping for %s", job.source_url) - return 0, len(eligible_urls) - - if progress_callback: - progress_callback(90, f"Storing {len(documents)} documents in R2R...") - - processed, failed = await upsert_r2r_documents_task(r2r_storage, documents, resolved_collection) - - logger.info("Upserted %s documents into R2R (%s failed)", processed, failed) - - return processed, failed - - -@task(name="update_job_status", tags=["tracking"]) -async def update_job_status_task( - job: IngestionJob, - status: IngestionStatus, - processed: int = 0, - _failed: int = 0, - error: str | None = None, -) -> IngestionJob: - """ - Update job status. - - Args: - job: Ingestion job - status: New status - processed: Documents processed - _failed: Documents failed (currently unused) - error: Error message if any - - Returns: - Updated job - """ - job.status = status - job.updated_at = datetime.now(UTC) - job.document_count = processed - - if status == IngestionStatus.COMPLETED: - job.completed_at = datetime.now(UTC) - - if error: - job.error_message = error - - return job - - -@flow( - name="ingestion_pipeline", - description="Main ingestion pipeline for documents", - retries=1, - retry_delay_seconds=60, - persist_result=True, - log_prints=True, -) -async def create_ingestion_flow( - source_url: str, - source_type: SourceTypeLike, - storage_backend: StorageBackendLike = StorageBackend.WEAVIATE, - collection_name: str | None = None, - validate_first: bool = True, - progress_callback: Callable[[int, str], None] | None = None, -) -> IngestionResult: - """ - Main ingestion flow. 
- - Args: - source_url: URL or path to source - source_type: Type of source - storage_backend: Storage backend to use - validate_first: Whether to validate source first - progress_callback: Optional callback for progress updates - - Returns: - Ingestion result - """ - print(f"Starting ingestion from {source_url}") - - source_enum = IngestionSource(source_type) - backend_enum = StorageBackend(storage_backend) - - # Create job - job = IngestionJob( - source_url=source_url, - source_type=source_enum, - storage_backend=backend_enum, - status=IngestionStatus.PENDING, - ) - - start_time = datetime.now(UTC) - error_messages: list[str] = [] - processed = 0 - failed = 0 - - try: - # Validate source if requested - if validate_first: - if progress_callback: - progress_callback(10, "Validating source...") - print("Validating source...") - is_valid = await validate_source_task(source_url, job.source_type) - - if not is_valid: - raise IngestionError(f"Source validation failed: {source_url}") - - # Update status to in progress - if progress_callback: - progress_callback(20, "Initializing storage...") - job = await update_job_status_task(job, IngestionStatus.IN_PROGRESS) - - # Run ingestion - if progress_callback: - progress_callback(30, "Starting document ingestion...") - print("Ingesting documents...") - if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R: - processed, failed = await firecrawl_to_r2r_flow(job, collection_name, progress_callback=progress_callback) - else: - processed, failed = await ingest_documents_task(job, collection_name, progress_callback=progress_callback) - - if progress_callback: - progress_callback(90, "Finalizing ingestion...") - - # Update final status - if failed > 0: - error_messages.append(f"{failed} documents failed to process") - - # Set status based on results - if processed == 0 and failed > 0: - final_status = IngestionStatus.FAILED - elif failed > 0: - final_status = IngestionStatus.PARTIAL - else: - final_status = IngestionStatus.COMPLETED - - job = await update_job_status_task(job, final_status, processed=processed, _failed=failed) - - print(f"Ingestion completed: {processed} processed, {failed} failed") - - except Exception as e: - print(f"Ingestion failed: {e}") - error_messages.append(str(e)) - - # Don't reset counts - keep whatever was processed before the error - job = await update_job_status_task( - job, IngestionStatus.FAILED, processed=processed, _failed=failed, error=str(e) - ) - - # Calculate duration - duration = (datetime.now(UTC) - start_time).total_seconds() - - return IngestionResult( - job_id=job.id, - status=job.status, - documents_processed=processed, - documents_failed=failed, - duration_seconds=duration, - error_messages=error_messages, - ) - - """R2R storage implementation using the official R2R SDK.""" @@ -11307,21 +11951,23 @@ from __future__ import annotations import asyncio import contextlib +import logging from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from datetime import UTC, datetime -from typing import Self, TypeVar, cast +from typing import Final, Self, TypeVar, cast from uuid import UUID, uuid4 # Direct imports for runtime and type checking -# Note: Some type checkers (basedpyright/Pyrefly) may report import issues -# but these work correctly at runtime and with mypy -from httpx import AsyncClient, HTTPStatusError -from r2r import R2RAsyncClient, R2RException +from httpx import AsyncClient, HTTPStatusError # type: ignore +from r2r import R2RAsyncClient, R2RException # type: 
ignore from typing_extensions import override from ...core.exceptions import StorageError from ...core.models import Document, DocumentMetadata, IngestionSource, StorageConfig from ..base import BaseStorage +from ..types import DocumentInfo + +LOGGER: Final[logging.Logger] = logging.getLogger(__name__) T = TypeVar("T") @@ -11383,6 +12029,24 @@ class R2RStorage(BaseStorage): self.client: R2RAsyncClient = R2RAsyncClient(self.endpoint) self.default_collection_id: str | None = None + def _get_http_client_headers(self) -> dict[str, str]: + """Get consistent HTTP headers for direct API calls.""" + headers = {"Content-Type": "application/json"} + + # Add authentication headers if available + # Note: R2R SDK may handle auth internally, so we extract it if possible + if hasattr(self.client, "_get_headers"): + with contextlib.suppress(Exception): + sdk_headers = self.client._get_headers() # type: ignore[attr-defined] + if isinstance(sdk_headers, dict): + headers |= sdk_headers + return headers + + def _create_http_client(self) -> AsyncClient: + """Create a properly configured HTTP client for direct API calls.""" + headers = self._get_http_client_headers() + return AsyncClient(headers=headers, timeout=30.0) + @override async def initialize(self) -> None: """Initialize R2R connection and ensure default collection exists.""" @@ -11399,7 +12063,7 @@ class R2RStorage(BaseStorage): # Test connection using direct HTTP call to v3 API endpoint = self.endpoint - client = AsyncClient() + client = self._create_http_client() try: response = await client.get(f"{endpoint}/v3/collections") response.raise_for_status() @@ -11412,7 +12076,7 @@ class R2RStorage(BaseStorage): async def _ensure_collection(self, collection_name: str) -> str: """Get or create collection by name.""" endpoint = self.endpoint - client = AsyncClient() + client = self._create_http_client() try: # List collections and find by name response = await client.get(f"{endpoint}/v3/collections") @@ -11455,6 +12119,9 @@ class R2RStorage(BaseStorage): finally: await client.aclose() + # This should never be reached, but satisfies static analyzer + raise StorageError(f"Unexpected code path in _ensure_collection for '{collection_name}'") + @override async def store(self, document: Document, *, collection_name: str | None = None) -> str: """Store a single document.""" @@ -11464,20 +12131,46 @@ class R2RStorage(BaseStorage): async def store_batch( self, documents: list[Document], *, collection_name: str | None = None ) -> list[str]: - """Store multiple documents.""" + """Store multiple documents efficiently with connection reuse.""" collection_id = await self._resolve_collection_id(collection_name) - print( - f"Using collection ID: {collection_id} for collection: {collection_name or self.config.collection_name}" + LOGGER.info( + "Using collection ID: %s for collection: %s", + collection_id, + collection_name or self.config.collection_name, ) - stored_ids: list[str] = [] - for document in documents: - if not self._is_document_valid(document): - continue + # Filter valid documents upfront + valid_documents = [doc for doc in documents if self._is_document_valid(doc)] + if not valid_documents: + return [] - stored_id = await self._store_single_document(document, collection_id) - if stored_id: - stored_ids.append(stored_id) + stored_ids: list[str] = [] + + # Use a single HTTP client for all requests + http_client = AsyncClient() + async with http_client: # type: ignore + # Process documents with controlled concurrency + import asyncio + + semaphore = 
asyncio.Semaphore(5) # Limit concurrent uploads + + async def store_single_with_client(document: Document) -> str | None: + async with semaphore: + return await self._store_single_document_with_client( + document, collection_id, http_client + ) + + # Execute all uploads concurrently + results = await asyncio.gather( + *[store_single_with_client(doc) for doc in valid_documents], return_exceptions=True + ) + + # Collect successful IDs + for result in results: + if isinstance(result, str): + stored_ids.append(result) + elif isinstance(result, Exception): + LOGGER.error("Document upload failed: %s", result) return stored_ids @@ -11498,12 +12191,14 @@ class R2RStorage(BaseStorage): requested_id = str(document.id) if not document.content or not document.content.strip(): - print(f"Skipping document {requested_id}: empty content") + LOGGER.warning("Skipping document %s: empty content", requested_id) return False if len(document.content) > 1_000_000: # 1MB limit - print( - f"Skipping document {requested_id}: content too large ({len(document.content)} chars)" + LOGGER.warning( + "Skipping document %s: content too large (%d chars)", + requested_id, + len(document.content), ) return False @@ -11511,23 +12206,41 @@ class R2RStorage(BaseStorage): async def _store_single_document(self, document: Document, collection_id: str) -> str | None: """Store a single document with retry logic.""" + http_client = AsyncClient() + async with http_client: # type: ignore + return await self._store_single_document_with_client( + document, collection_id, http_client + ) + + async def _store_single_document_with_client( + self, document: Document, collection_id: str, http_client: AsyncClient + ) -> str | None: + """Store a single document with retry logic using provided HTTP client.""" requested_id = str(document.id) - print(f"Creating document with ID: {requested_id}") + LOGGER.debug("Creating document with ID: %s", requested_id) max_retries = 3 retry_delay = 1.0 for attempt in range(max_retries): try: - doc_response = await self._attempt_document_creation(document, collection_id) + doc_response = await self._attempt_document_creation_with_client( + document, collection_id, http_client + ) if doc_response: - return self._process_document_response(doc_response, requested_id, collection_id) + return self._process_document_response( + doc_response, requested_id, collection_id + ) except (TimeoutError, OSError) as e: - if not await self._should_retry_timeout(e, attempt, max_retries, requested_id, retry_delay): + if not await self._should_retry_timeout( + e, attempt, max_retries, requested_id, retry_delay + ): break retry_delay *= 2 except HTTPStatusError as e: - if not await self._should_retry_http_error(e, attempt, max_retries, requested_id, retry_delay): + if not await self._should_retry_http_error( + e, attempt, max_retries, requested_id, retry_delay + ): break retry_delay *= 2 except Exception as exc: @@ -11536,13 +12249,25 @@ class R2RStorage(BaseStorage): return None - async def _attempt_document_creation(self, document: Document, collection_id: str) -> dict[str, object] | None: + async def _attempt_document_creation( + self, document: Document, collection_id: str + ) -> dict[str, object] | None: """Attempt to create a document via HTTP API.""" + http_client = AsyncClient() + async with http_client: # type: ignore + return await self._attempt_document_creation_with_client( + document, collection_id, http_client + ) + + async def _attempt_document_creation_with_client( + self, document: Document, collection_id: str, 
http_client: AsyncClient + ) -> dict[str, object] | None: + """Attempt to create a document via HTTP API using provided client.""" import json requested_id = str(document.id) metadata = self._build_metadata(document) - print(f"Built metadata for document {requested_id}: {metadata}") + LOGGER.debug("Built metadata for document %s: %s", requested_id, metadata) files = { "raw_text": (None, document.content), @@ -11553,86 +12278,121 @@ class R2RStorage(BaseStorage): if collection_id: files["collection_ids"] = (None, json.dumps([collection_id])) - print(f"Creating document {requested_id} with collection_ids: [{collection_id}]") + LOGGER.debug( + "Creating document %s with collection_ids: [%s]", requested_id, collection_id + ) - print(f"Sending to R2R - files keys: {list(files.keys())}") - print(f"Metadata JSON: {files['metadata'][1]}") + LOGGER.debug("Sending to R2R - files keys: %s", list(files.keys())) + LOGGER.debug("Metadata JSON: %s", files["metadata"][1]) - async with AsyncClient() as http_client: - response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) + response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) # type: ignore[call-arg] - if response.status_code == 422: - self._handle_validation_error(response, requested_id, metadata) - return None + if response.status_code == 422: + self._handle_validation_error(response, requested_id, metadata) + return None - response.raise_for_status() - return response.json() + response.raise_for_status() + return response.json() - def _handle_validation_error(self, response: object, requested_id: str, metadata: dict[str, object]) -> None: + def _handle_validation_error( + self, response: object, requested_id: str, metadata: dict[str, object] + ) -> None: """Handle validation errors from R2R API.""" try: - error_detail = getattr(response, 'json', lambda: {})() if hasattr(response, 'json') else {} - print(f"R2R validation error for document {requested_id}: {error_detail}") - print(f"Document metadata sent: {metadata}") - print(f"Response status: {getattr(response, 'status_code', 'unknown')}") - print(f"Response headers: {dict(getattr(response, 'headers', {}))}") + error_detail = ( + getattr(response, "json", lambda: {})() if hasattr(response, "json") else {} + ) + LOGGER.error("R2R validation error for document %s: %s", requested_id, error_detail) + LOGGER.error("Document metadata sent: %s", metadata) + LOGGER.error("Response status: %s", getattr(response, "status_code", "unknown")) + LOGGER.error("Response headers: %s", dict(getattr(response, "headers", {}))) except Exception: - print(f"R2R validation error for document {requested_id}: {getattr(response, 'text', 'unknown error')}") - print(f"Document metadata sent: {metadata}") + LOGGER.error( + "R2R validation error for document %s: %s", + requested_id, + getattr(response, "text", "unknown error"), + ) + LOGGER.error("Document metadata sent: %s", metadata) - def _process_document_response(self, doc_response: dict[str, object], requested_id: str, collection_id: str) -> str: + def _process_document_response( + self, doc_response: dict[str, object], requested_id: str, collection_id: str + ) -> str: """Process successful document creation response.""" response_payload = doc_response.get("results", doc_response) doc_id = _extract_id(response_payload, requested_id) - print(f"R2R returned document ID: {doc_id}") + LOGGER.info("R2R returned document ID: %s", doc_id) if doc_id != requested_id: - print(f"Warning: Requested ID {requested_id} but got {doc_id}") + 
LOGGER.warning("Requested ID %s but got %s", requested_id, doc_id) if collection_id: - print(f"Document {doc_id} should be assigned to collection {collection_id} via creation API") + LOGGER.info( + "Document %s should be assigned to collection %s via creation API", + doc_id, + collection_id, + ) return doc_id - async def _should_retry_timeout(self, error: Exception, attempt: int, max_retries: int, requested_id: str, retry_delay: float) -> bool: + async def _should_retry_timeout( + self, + error: Exception, + attempt: int, + max_retries: int, + requested_id: str, + retry_delay: float, + ) -> bool: """Determine if timeout error should be retried.""" if attempt >= max_retries - 1: return False - print(f"Timeout for document {requested_id}, retrying in {retry_delay}s...") + LOGGER.warning("Timeout for document %s, retrying in %ss...", requested_id, retry_delay) await asyncio.sleep(retry_delay) return True - async def _should_retry_http_error(self, error: HTTPStatusError, attempt: int, max_retries: int, requested_id: str, retry_delay: float) -> bool: + async def _should_retry_http_error( + self, + error: HTTPStatusError, + attempt: int, + max_retries: int, + requested_id: str, + retry_delay: float, + ) -> bool: """Determine if HTTP error should be retried.""" - if error.response.status_code < 500 or attempt >= max_retries - 1: + status_code = error.response.status_code + if status_code < 500 or attempt >= max_retries - 1: return False - print(f"Server error {error.response.status_code} for document {requested_id}, retrying in {retry_delay}s...") + LOGGER.warning( + "Server error %s for document %s, retrying in %ss...", + status_code, + requested_id, + retry_delay, + ) await asyncio.sleep(retry_delay) return True def _log_document_error(self, document_id: object, exc: Exception) -> None: """Log document storage errors with specific categorization.""" - print(f"Failed to store document {document_id}: {exc}") + LOGGER.error("Failed to store document %s: %s", document_id, exc) exc_str = str(exc) if "422" in exc_str: - print(" → Data validation issue - check document content and metadata format") + LOGGER.error(" → Data validation issue - check document content and metadata format") elif "timeout" in exc_str.lower(): - print(" → Network timeout - R2R may be overloaded") + LOGGER.error(" → Network timeout - R2R may be overloaded") elif "500" in exc_str: - print(" → Server error - R2R internal issue") + LOGGER.error(" → Server error - R2R internal issue") else: import traceback + traceback.print_exc() def _build_metadata(self, document: Document) -> dict[str, object]: """Convert document metadata to enriched R2R format.""" metadata = document.metadata - # Core required fields result: dict[str, object] = { "source_url": metadata["source_url"], @@ -11768,7 +12528,9 @@ class R2RStorage(BaseStorage): except ValueError: return uuid4() - def _build_core_metadata(self, metadata_map: dict[str, object], timestamp: datetime) -> DocumentMetadata: + def _build_core_metadata( + self, metadata_map: dict[str, object], timestamp: datetime + ) -> DocumentMetadata: """Build core required metadata fields.""" return { "source_url": str(metadata_map.get("source_url", "")), @@ -11778,7 +12540,12 @@ class R2RStorage(BaseStorage): "char_count": _as_int(metadata_map.get("char_count")), } - def _add_optional_metadata_fields(self, metadata: DocumentMetadata, doc_map: dict[str, object], metadata_map: dict[str, object]) -> None: + def _add_optional_metadata_fields( + self, + metadata: DocumentMetadata, + doc_map: dict[str, 
object], + metadata_map: dict[str, object], + ) -> None: """Add optional metadata fields if present.""" self._add_title_and_description(metadata, doc_map, metadata_map) self._add_content_categorization(metadata, metadata_map) @@ -11787,7 +12554,12 @@ class R2RStorage(BaseStorage): self._add_processing_fields(metadata, metadata_map) self._add_quality_scores(metadata, metadata_map) - def _add_title_and_description(self, metadata: DocumentMetadata, doc_map: dict[str, object], metadata_map: dict[str, object]) -> None: + def _add_title_and_description( + self, + metadata: DocumentMetadata, + doc_map: dict[str, object], + metadata_map: dict[str, object], + ) -> None: """Add title and description fields.""" if title := (doc_map.get("title") or metadata_map.get("title")): metadata["title"] = cast(str | None, title) @@ -11797,7 +12569,9 @@ class R2RStorage(BaseStorage): elif description := metadata_map.get("description"): metadata["description"] = cast(str | None, description) - def _add_content_categorization(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None: + def _add_content_categorization( + self, metadata: DocumentMetadata, metadata_map: dict[str, object] + ) -> None: """Add content categorization fields.""" if tags := metadata_map.get("tags"): metadata["tags"] = [str(tag) for tag in tags] if isinstance(tags, list) else [] @@ -11808,7 +12582,9 @@ class R2RStorage(BaseStorage): if language := metadata_map.get("language"): metadata["language"] = str(language) - def _add_authorship_fields(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None: + def _add_authorship_fields( + self, metadata: DocumentMetadata, metadata_map: dict[str, object] + ) -> None: """Add authorship and source information fields.""" if author := metadata_map.get("author"): metadata["author"] = str(author) @@ -11817,7 +12593,9 @@ class R2RStorage(BaseStorage): if site_name := metadata_map.get("site_name"): metadata["site_name"] = str(site_name) - def _add_structure_fields(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None: + def _add_structure_fields( + self, metadata: DocumentMetadata, metadata_map: dict[str, object] + ) -> None: """Add document structure fields.""" if heading_hierarchy := metadata_map.get("heading_hierarchy"): metadata["heading_hierarchy"] = ( @@ -11832,7 +12610,9 @@ class R2RStorage(BaseStorage): if has_links := metadata_map.get("has_links"): metadata["has_links"] = bool(has_links) - def _add_processing_fields(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None: + def _add_processing_fields( + self, metadata: DocumentMetadata, metadata_map: dict[str, object] + ) -> None: """Add processing-related metadata fields.""" if extraction_method := metadata_map.get("extraction_method"): metadata["extraction_method"] = str(extraction_method) @@ -11841,7 +12621,9 @@ class R2RStorage(BaseStorage): if last_modified := metadata_map.get("last_modified"): metadata["last_modified"] = _as_datetime(last_modified) - def _add_quality_scores(self, metadata: DocumentMetadata, metadata_map: dict[str, object]) -> None: + def _add_quality_scores( + self, metadata: DocumentMetadata, metadata_map: dict[str, object] + ) -> None: """Add quality score fields with safe float conversion.""" if readability_score := metadata_map.get("readability_score"): try: @@ -11944,7 +12726,7 @@ class R2RStorage(BaseStorage): async def count(self, *, collection_name: str | None = None) -> int: """Get document count in collection.""" endpoint = self.endpoint - 
client = AsyncClient() + client = self._create_http_client() try: # Get collections and find the count for the specific collection response = await client.get(f"{endpoint}/v3/collections") @@ -11965,6 +12747,9 @@ class R2RStorage(BaseStorage): finally: await client.aclose() + # This should never be reached, but satisfies static analyzer + return 0 + @override async def close(self) -> None: """Close R2R client.""" @@ -12012,7 +12797,7 @@ class R2RStorage(BaseStorage): async def list_collections(self) -> list[str]: """List all available collections.""" endpoint = self.endpoint - client = AsyncClient() + client = self._create_http_client() try: response = await client.get(f"{endpoint}/v3/collections") response.raise_for_status() @@ -12029,6 +12814,9 @@ class R2RStorage(BaseStorage): finally: await client.aclose() + # This should never be reached, but satisfies static analyzer + return [] + async def list_collections_detailed(self) -> list[dict[str, object]]: """List all available collections with detailed information.""" try: @@ -12092,7 +12880,7 @@ class R2RStorage(BaseStorage): offset: int = 0, *, collection_name: str | None = None, - ) -> list[dict[str, object]]: + ) -> list[DocumentInfo]: """ List documents in R2R with pagination. @@ -12105,14 +12893,14 @@ class R2RStorage(BaseStorage): List of document dictionaries with metadata """ try: - documents: list[dict[str, object]] = [] + documents: list[DocumentInfo] = [] if collection_name: # Get collection ID first collection_id = await self._ensure_collection(collection_name) # Use the collections API to list documents in a specific collection endpoint = self.endpoint - client = AsyncClient() + client = self._create_http_client() try: params = {"offset": offset, "limit": limit} response = await client.get( @@ -12145,20 +12933,19 @@ class R2RStorage(BaseStorage): title = str(doc_map.get("title", "Untitled")) metadata = _as_mapping(doc_map.get("metadata", {})) - documents.append( - { - "id": doc_id, - "title": title, - "source_url": str(metadata.get("source_url", "")), - "description": str(metadata.get("description", "")), - "content_type": str(metadata.get("content_type", "text/plain")), - "content_preview": str(doc_map.get("content", ""))[:200] + "..." - if doc_map.get("content") - else "", - "word_count": _as_int(metadata.get("word_count", 0)), - "timestamp": str(doc_map.get("created_at", "")), - } - ) + document_info: DocumentInfo = { + "id": doc_id, + "title": title, + "source_url": str(metadata.get("source_url", "")), + "description": str(metadata.get("description", "")), + "content_type": str(metadata.get("content_type", "text/plain")), + "content_preview": str(doc_map.get("content", ""))[:200] + "..." + if doc_map.get("content") + else "", + "word_count": _as_int(metadata.get("word_count", 0)), + "timestamp": str(doc_map.get("created_at", "")), + } + documents.append(document_info) return documents @@ -12169,10 +12956,191 @@ class R2RStorage(BaseStorage): """Base storage interface.""" +import asyncio +import logging +import random from abc import ABC, abstractmethod from collections.abc import AsyncGenerator +from types import TracebackType +from typing import Final +import httpx +from pydantic import SecretStr + +from ..core.exceptions import StorageError from ..core.models import Document, StorageConfig +from .types import CollectionSummary, DocumentInfo + +LOGGER: Final[logging.Logger] = logging.getLogger(__name__) + + +class TypedHttpClient: + """ + A properly typed HTTP client wrapper for HTTPX. 
+ + Provides consistent exception handling and type annotations + for storage adapters that use HTTP APIs. + + Note: Some type checkers (Pylance) may report warnings about HTTPX types + due to library compatibility issues. The code functions correctly at runtime. + """ + + client: httpx.AsyncClient + _base_url: str + + def __init__( + self, + base_url: str, + *, + api_key: SecretStr | None = None, + timeout: float = 30.0, + headers: dict[str, str] | None = None, + max_connections: int = 100, + max_keepalive_connections: int = 20, + ): + """ + Initialize the typed HTTP client. + + Args: + base_url: Base URL for all requests + api_key: Optional API key for authentication + timeout: Request timeout in seconds + headers: Additional headers to include with requests + max_connections: Maximum total connections in pool + max_keepalive_connections: Maximum keepalive connections + """ + self._base_url = base_url + + # Build headers with optional authentication + client_headers: dict[str, str] = headers or {} + if api_key: + client_headers["Authorization"] = f"Bearer {api_key.get_secret_value()}" + + # Create typed client configuration with connection pooling + limits = httpx.Limits( + max_connections=max_connections, max_keepalive_connections=max_keepalive_connections + ) + timeout_config = httpx.Timeout(connect=5.0, read=timeout, write=30.0, pool=10.0) + self.client = httpx.AsyncClient( + base_url=base_url, headers=client_headers, timeout=timeout_config, limits=limits + ) + + async def request( + self, + method: str, + path: str, + *, + allow_404: bool = False, + json: dict[str, object] | None = None, + data: dict[str, object] | None = None, + files: dict[str, tuple[str, bytes, str]] | None = None, + params: dict[str, str | bool] | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + ) -> httpx.Response | None: + """ + Perform an HTTP request with consistent error handling and retries. + + Args: + method: HTTP method (GET, POST, DELETE, etc.) 
+ path: URL path relative to base_url + allow_404: If True, return None for 404 responses instead of raising + json: JSON data to send + data: Form data to send + files: Files to upload + params: Query parameters + max_retries: Maximum number of retry attempts + retry_delay: Base delay between retries in seconds + + Returns: + HTTP response object, or None if allow_404=True and status is 404 + + Raises: + StorageError: If request fails after retries + """ + last_exception: Exception | None = None + + for attempt in range(max_retries + 1): + try: + response = await self.client.request( + method, path, json=json, data=data, files=files, params=params + ) + response.raise_for_status() + return response + except httpx.HTTPStatusError as e: + # Handle 404 as special case if requested + if allow_404 and e.response.status_code == 404: + LOGGER.debug("Resource not found (404): %s %s", method, path) + return None + + # Don't retry client errors (4xx except for specific cases) + if 400 <= e.response.status_code < 500 and e.response.status_code not in [429, 408]: + raise StorageError( + f"HTTP {e.response.status_code} error from {self._base_url}: {e}" + ) from e + + last_exception = e + if attempt < max_retries: + # Exponential backoff with jitter for retryable errors + delay = retry_delay * (2**attempt) + random.uniform(0, 1) + LOGGER.warning( + "HTTP %d error on attempt %d/%d, retrying in %.2fs: %s", + e.response.status_code, + attempt + 1, + max_retries + 1, + delay, + e, + ) + await asyncio.sleep(delay) + + except httpx.HTTPError as e: + last_exception = e + if attempt < max_retries: + # Retry transport errors with backoff + delay = retry_delay * (2**attempt) + random.uniform(0, 1) + LOGGER.warning( + "HTTP transport error on attempt %d/%d, retrying in %.2fs: %s", + attempt + 1, + max_retries + 1, + delay, + e, + ) + await asyncio.sleep(delay) + + # All retries exhausted - last_exception should always be set if we reach here + if last_exception is None: + raise StorageError( + f"Request to {self._base_url} failed after {max_retries + 1} attempts with unknown error" + ) + + if isinstance(last_exception, httpx.HTTPStatusError): + raise StorageError( + f"HTTP {last_exception.response.status_code} error from {self._base_url} after {max_retries + 1} attempts: {last_exception}" + ) from last_exception + else: + raise StorageError( + f"HTTP transport error to {self._base_url} after {max_retries + 1} attempts: {last_exception}" + ) from last_exception + + async def close(self) -> None: + """Close the HTTP client and cleanup resources.""" + try: + await self.client.aclose() + except Exception as e: + LOGGER.warning("Error closing HTTP client: %s", e) + + async def __aenter__(self) -> "TypedHttpClient": + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Async context manager exit.""" + await self.close() class BaseStorage(ABC): @@ -12266,11 +13234,30 @@ class BaseStorage(ABC): # Check staleness if timestamp is available if "timestamp" in document.metadata: from datetime import UTC, datetime, timedelta + timestamp_obj = document.metadata["timestamp"] + + # Handle both datetime objects and ISO strings if isinstance(timestamp_obj, datetime): timestamp = timestamp_obj - cutoff = datetime.now(UTC) - timedelta(days=stale_after_days) - return timestamp >= cutoff + # Ensure timezone awareness + if timestamp.tzinfo is None: + timestamp = 
timestamp.replace(tzinfo=UTC) + elif isinstance(timestamp_obj, str): + try: + timestamp = datetime.fromisoformat(timestamp_obj) + # Ensure timezone awareness + if timestamp.tzinfo is None: + timestamp = timestamp.replace(tzinfo=UTC) + except ValueError: + # If parsing fails, assume document is stale + return False + else: + # Unknown timestamp format, assume stale + return False + + cutoff = datetime.now(UTC) - timedelta(days=stale_after_days) + return timestamp >= cutoff # If no timestamp, assume it exists and is valid return True @@ -12333,12 +13320,12 @@ class BaseStorage(ABC): """ return [] - async def describe_collections(self) -> list[dict[str, object]]: + async def describe_collections(self) -> list[CollectionSummary]: """ Describe available collections with metadata (if supported by backend). Returns: - List of collection metadata dictionaries, empty list if not supported + List of collection metadata, empty list if not supported """ return [] @@ -12375,7 +13362,7 @@ class BaseStorage(ABC): offset: int = 0, *, collection_name: str | None = None, - ) -> list[dict[str, object]]: + ) -> list[DocumentInfo]: """ List documents in the storage backend (if supported). @@ -12385,7 +13372,7 @@ class BaseStorage(ABC): collection_name: Collection to list documents from Returns: - List of document dictionaries with metadata + List of document information with metadata Raises: NotImplementedError: If backend doesn't support document listing @@ -12406,34 +13393,58 @@ class BaseStorage(ABC): """Open WebUI storage adapter.""" import asyncio +import contextlib import logging -from typing import TYPE_CHECKING, Final, TypedDict, cast +import time +from typing import Final, NamedTuple, TypedDict -import httpx from typing_extensions import override -if TYPE_CHECKING: - # Type checking imports - these will be ignored at runtime - from httpx import AsyncClient, ConnectError, HTTPStatusError, RequestError -else: - # Runtime imports that work properly - AsyncClient = httpx.AsyncClient - ConnectError = httpx.ConnectError - HTTPStatusError = httpx.HTTPStatusError - RequestError = httpx.RequestError - from ..core.exceptions import StorageError from ..core.models import Document, StorageConfig -from .base import BaseStorage +from .base import BaseStorage, TypedHttpClient +from .types import CollectionSummary, DocumentInfo LOGGER: Final[logging.Logger] = logging.getLogger(__name__) +class OpenWebUIFileResponse(TypedDict, total=False): + """OpenWebUI API file response structure.""" + + id: str + filename: str + name: str + content_type: str + size: int + created_at: str + meta: dict[str, str | int] + + +class OpenWebUIKnowledgeBase(TypedDict, total=False): + """OpenWebUI knowledge base response structure.""" + + id: str + name: str + description: str + files: list[OpenWebUIFileResponse] + data: dict[str, str] + created_at: str + updated_at: str + + +class CacheEntry(NamedTuple): + """Cache entry with value and expiration time.""" + + value: str + expires_at: float + + class OpenWebUIStorage(BaseStorage): """Storage adapter for Open WebUI knowledge endpoints.""" - client: AsyncClient - _knowledge_cache: dict[str, str] + http_client: TypedHttpClient + _knowledge_cache: dict[str, CacheEntry] + _cache_ttl: float def __init__(self, config: StorageConfig): """ @@ -12444,16 +13455,13 @@ class OpenWebUIStorage(BaseStorage): """ super().__init__(config) - headers: dict[str, str] = {} - if config.api_key: - headers["Authorization"] = f"Bearer {config.api_key}" - - self.client = AsyncClient( + self.http_client = 
TypedHttpClient( base_url=str(config.endpoint), - headers=headers, + api_key=config.api_key, timeout=30.0, ) self._knowledge_cache = {} + self._cache_ttl = 300.0 # 5 minutes TTL @override async def initialize(self) -> None: @@ -12464,60 +13472,61 @@ class OpenWebUIStorage(BaseStorage): self.config.collection_name, create=True, ) - - except ConnectError as e: - raise StorageError(f"Connection to OpenWebUI failed: {e}") from e - except HTTPStatusError as e: - raise StorageError(f"OpenWebUI returned error {e.response.status_code}: {e}") from e - except RequestError as e: - raise StorageError(f"Request to OpenWebUI failed: {e}") from e except Exception as e: raise StorageError(f"Failed to initialize Open WebUI: {e}") from e async def _create_collection(self, name: str) -> str: """Create knowledge base in Open WebUI.""" - try: - response = await self.client.post( - "/api/v1/knowledge/create", - json={ - "name": name, - "description": "Documents ingested from various sources", - "data": {}, - "access_control": None, - }, - ) - response.raise_for_status() - result = response.json() - knowledge_id = result.get("id") + response = await self.http_client.request( + "POST", + "/api/v1/knowledge/create", + json={ + "name": name, + "description": "Documents ingested from various sources", + "data": {}, + "access_control": None, + }, + ) + if response is None: + raise StorageError("Unexpected None response from knowledge base creation") + result = response.json() + knowledge_id = result.get("id") - if not knowledge_id or not isinstance(knowledge_id, str): - raise StorageError("Knowledge base creation failed: no ID returned") + if not knowledge_id or not isinstance(knowledge_id, str): + raise StorageError("Knowledge base creation failed: no ID returned") - return str(knowledge_id) + return str(knowledge_id) - except ConnectError as e: - raise StorageError(f"Connection to OpenWebUI failed during creation: {e}") from e - except HTTPStatusError as e: - raise StorageError( - f"OpenWebUI returned error {e.response.status_code} during creation: {e}" - ) from e - except RequestError as e: - raise StorageError(f"Request to OpenWebUI failed during creation: {e}") from e - except Exception as e: - raise StorageError(f"Failed to create knowledge base: {e}") from e - - async def _fetch_knowledge_bases(self) -> list[dict[str, object]]: + async def _fetch_knowledge_bases(self) -> list[OpenWebUIKnowledgeBase]: """Return the list of knowledge bases from the API.""" - response = await self.client.get("/api/v1/knowledge/list") - response.raise_for_status() + response = await self.http_client.request("GET", "/api/v1/knowledge/list") + if response is None: + return [] data = response.json() if not isinstance(data, list): return [] - normalized: list[dict[str, object]] = [] + normalized: list[OpenWebUIKnowledgeBase] = [] for item in data: - if isinstance(item, dict): - item_dict: dict[str, object] = item - normalized.append({str(k): v for k, v in item_dict.items()}) + if ( + isinstance(item, dict) + and "id" in item + and "name" in item + and isinstance(item["id"], str) + and isinstance(item["name"], str) + ): + # Create a new dict with known structure + kb_item: OpenWebUIKnowledgeBase = { + "id": item["id"], + "name": item["name"], + "description": item.get("description", ""), + "created_at": item.get("created_at", ""), + "updated_at": item.get("updated_at", ""), + } + if "files" in item and isinstance(item["files"], list): + kb_item["files"] = item["files"] + if "data" in item and isinstance(item["data"], dict): + 
kb_item["data"] = item["data"] + normalized.append(kb_item) return normalized async def _get_knowledge_id( @@ -12532,22 +13541,29 @@ class OpenWebUIStorage(BaseStorage): if not target: raise StorageError("Knowledge base name is required") - if cached := self._knowledge_cache.get(target): - return cached + # Check cache with TTL + if cached_entry := self._knowledge_cache.get(target): + if time.time() < cached_entry.expires_at: + return cached_entry.value + else: + # Entry expired, remove it + del self._knowledge_cache[target] knowledge_bases = await self._fetch_knowledge_bases() for kb in knowledge_bases: if kb.get("name") == target: kb_id = kb.get("id") if isinstance(kb_id, str): - self._knowledge_cache[target] = kb_id + expires_at = time.time() + self._cache_ttl + self._knowledge_cache[target] = CacheEntry(kb_id, expires_at) return kb_id if not create: return None knowledge_id = await self._create_collection(target) - self._knowledge_cache[target] = knowledge_id + expires_at = time.time() + self._cache_ttl + self._knowledge_cache[target] = CacheEntry(knowledge_id, expires_at) return knowledge_id @override @@ -12573,15 +13589,17 @@ class OpenWebUIStorage(BaseStorage): # Use document title from metadata if available, otherwise fall back to ID filename = document.metadata.get("title") or f"doc_{document.id}" # Ensure filename has proper extension - if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')): + if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")): filename = f"{filename}.txt" files = {"file": (filename, document.content.encode(), "text/plain")} - response = await self.client.post( + response = await self.http_client.request( + "POST", "/api/v1/files/", files=files, params={"process": True, "process_in_background": False}, ) - response.raise_for_status() + if response is None: + raise StorageError("Unexpected None response from file upload") file_data = response.json() file_id = file_data.get("id") @@ -12590,19 +13608,12 @@ class OpenWebUIStorage(BaseStorage): raise StorageError("File upload failed: no file ID returned") # Step 2: Add file to knowledge base - response = await self.client.post( - f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id} + response = await self.http_client.request( + "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id} ) - response.raise_for_status() return str(file_id) - except ConnectError as e: - raise StorageError(f"Connection to OpenWebUI failed: {e}") from e - except HTTPStatusError as e: - raise StorageError(f"OpenWebUI returned error {e.response.status_code}: {e}") from e - except RequestError as e: - raise StorageError(f"Request to OpenWebUI failed: {e}") from e except Exception as e: raise StorageError(f"Failed to store document: {e}") from e @@ -12631,15 +13642,19 @@ class OpenWebUIStorage(BaseStorage): # Use document title from metadata if available, otherwise fall back to ID filename = doc.metadata.get("title") or f"doc_{doc.id}" # Ensure filename has proper extension - if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')): + if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")): filename = f"{filename}.txt" files = {"file": (filename, doc.content.encode(), "text/plain")} - upload_response = await self.client.post( + upload_response = await self.http_client.request( + "POST", "/api/v1/files/", files=files, params={"process": True, "process_in_background": False}, ) - upload_response.raise_for_status() + if upload_response is None: + raise 
StorageError( + f"Unexpected None response from file upload for document {doc.id}" + ) file_data = upload_response.json() file_id = file_data.get("id") @@ -12649,10 +13664,9 @@ class OpenWebUIStorage(BaseStorage): f"File upload failed for document {doc.id}: no file ID returned" ) - attach_response = await self.client.post( - f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id} + await self.http_client.request( + "POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id} ) - attach_response.raise_for_status() return str(file_id) @@ -12667,7 +13681,8 @@ class OpenWebUIStorage(BaseStorage): if isinstance(result, Exception): failures.append(f"{doc.id}: {result}") else: - file_ids.append(cast(str, result)) + if isinstance(result, str): + file_ids.append(result) if failures: LOGGER.warning( @@ -12678,14 +13693,6 @@ class OpenWebUIStorage(BaseStorage): return file_ids - except ConnectError as e: - raise StorageError(f"Connection to OpenWebUI failed during batch: {e}") from e - except HTTPStatusError as e: - raise StorageError( - f"OpenWebUI returned error {e.response.status_code} during batch: {e}" - ) from e - except RequestError as e: - raise StorageError(f"Request to OpenWebUI failed during batch: {e}") from e except Exception as e: raise StorageError(f"Failed to store batch: {e}") from e @@ -12703,11 +13710,88 @@ class OpenWebUIStorage(BaseStorage): Returns: Always None - retrieval not supported """ + _ = document_id, collection_name # Mark as used # OpenWebUI uses file-based storage without direct document retrieval - # This will cause the base check_exists method to return False, - # which means documents will always be re-scraped for OpenWebUI raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID") + @override + async def check_exists( + self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30 + ) -> bool: + """ + Check if a document exists in OpenWebUI knowledge base by searching files. 
+ + Args: + document_id: Document ID to check (usually based on source URL) + collection_name: Knowledge base name + stale_after_days: Consider document stale after this many days + + Returns: + True if document exists and is not stale, False otherwise + """ + try: + from datetime import UTC, datetime, timedelta + + # Get knowledge base + knowledge_id = await self._get_knowledge_id(collection_name, create=False) + if not knowledge_id: + return False + + # Get detailed knowledge base info to access files + response = await self.http_client.request("GET", f"/api/v1/knowledge/{knowledge_id}") + if response is None: + return False + + kb_data = response.json() + files = kb_data.get("files", []) + + # Look for file with matching document ID or source URL in metadata + cutoff = datetime.now(UTC) - timedelta(days=stale_after_days) + + def _parse_openwebui_timestamp(timestamp_str: str) -> datetime | None: + """Parse OpenWebUI timestamp with proper timezone handling.""" + try: + # Handle both 'Z' suffix and explicit timezone + normalized = timestamp_str.replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + # Ensure timezone awareness + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=UTC) + return parsed + except (ValueError, AttributeError): + return None + + def _check_file_freshness(file_info: dict[str, object]) -> bool: + """Check if file is fresh enough based on creation date.""" + created_at = file_info.get("created_at") + if not isinstance(created_at, str): + # No date info available, consider stale to be safe + return False + + file_date = _parse_openwebui_timestamp(created_at) + return file_date is not None and file_date >= cutoff + + for file_info in files: + if not isinstance(file_info, dict): + continue + + file_id = file_info.get("id") + if str(file_id) == document_id: + return _check_file_freshness(file_info) + + # Also check meta.source_url if available for URL-based document IDs + meta = file_info.get("meta", {}) + if isinstance(meta, dict): + source_url = meta.get("source_url") + if source_url and document_id in str(source_url): + return _check_file_freshness(file_info) + + return False + + except Exception as e: + LOGGER.debug("Error checking document existence in OpenWebUI: %s", e) + return False + @override async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: """ @@ -12728,35 +13812,16 @@ class OpenWebUIStorage(BaseStorage): return False # Remove file from knowledge base - response = await self.client.post( - f"/api/v1/knowledge/{knowledge_id}/file/remove", json={"file_id": document_id} + await self.http_client.request( + "POST", + f"/api/v1/knowledge/{knowledge_id}/file/remove", + json={"file_id": document_id}, ) - response.raise_for_status() - delete_response = await self.client.delete(f"/api/v1/files/{document_id}") - if delete_response.status_code == 404: - return True - delete_response.raise_for_status() + await self.http_client.request("DELETE", f"/api/v1/files/{document_id}", allow_404=True) return True - - except ConnectError as exc: - LOGGER.error( - "Failed to reach OpenWebUI when deleting file %s", document_id, exc_info=exc - ) - return False - except HTTPStatusError as exc: - LOGGER.error( - "OpenWebUI returned status error %s when deleting file %s", - exc.response.status_code if exc.response else "unknown", - document_id, - exc_info=exc, - ) - return False - except RequestError as exc: - LOGGER.error("Request error deleting file %s from OpenWebUI", document_id, exc_info=exc) - return False except 
Exception as exc: - LOGGER.error("Unexpected error deleting file %s", document_id, exc_info=exc) + LOGGER.error("Error deleting file %s from OpenWebUI", document_id, exc_info=exc) return False async def list_collections(self) -> list[str]: @@ -12775,12 +13840,6 @@ class OpenWebUIStorage(BaseStorage): for kb in knowledge_bases ] - except ConnectError as e: - raise StorageError(f"Connection to OpenWebUI failed: {e}") from e - except HTTPStatusError as e: - raise StorageError(f"OpenWebUI returned error {e.response.status_code}: {e}") from e - except RequestError as e: - raise StorageError(f"Request to OpenWebUI failed: {e}") from e except Exception as e: raise StorageError(f"Failed to list knowledge bases: {e}") from e @@ -12801,8 +13860,9 @@ class OpenWebUIStorage(BaseStorage): return True # Delete the knowledge base using the OpenWebUI API - response = await self.client.delete(f"/api/v1/knowledge/{knowledge_id}/delete") - response.raise_for_status() + await self.http_client.request( + "DELETE", f"/api/v1/knowledge/{knowledge_id}/delete", allow_404=True + ) # Remove from cache if it exists if collection_name in self._knowledge_cache: @@ -12811,45 +13871,26 @@ class OpenWebUIStorage(BaseStorage): LOGGER.info("Successfully deleted knowledge base: %s", collection_name) return True - except HTTPStatusError as e: - # Handle 404 as success (already deleted) - if e.response.status_code == 404: - LOGGER.info("Knowledge base %s was already deleted or not found", collection_name) - return True - LOGGER.error( - "OpenWebUI returned error %s when deleting knowledge base %s", - e.response.status_code, - collection_name, - exc_info=e, - ) - return False - except ConnectError as e: - LOGGER.error( - "Failed to reach OpenWebUI when deleting knowledge base %s", - collection_name, - exc_info=e, - ) - return False - except RequestError as e: - LOGGER.error( - "Request error deleting knowledge base %s from OpenWebUI", - collection_name, - exc_info=e, - ) - return False except Exception as e: - LOGGER.error("Unexpected error deleting knowledge base %s", collection_name, exc_info=e) + if hasattr(e, "response"): + response_attr = getattr(e, "response", None) + if response_attr is not None and hasattr(response_attr, "status_code"): + with contextlib.suppress(Exception): + status_code = response_attr.status_code + if status_code == 404: + LOGGER.info( + "Knowledge base %s was already deleted or not found", + collection_name, + ) + return True + LOGGER.error( + "Error deleting knowledge base %s from OpenWebUI", + collection_name, + exc_info=e, + ) return False - class CollectionSummary(TypedDict): - """Structure describing a knowledge base summary.""" - - name: str - count: int - size_mb: float - - - async def _get_knowledge_base_count(self, kb: dict[str, object]) -> int: + async def _get_knowledge_base_count(self, kb: OpenWebUIKnowledgeBase) -> int: """Get the file count for a knowledge base.""" kb_id = kb.get("id") name = kb.get("name", "Unknown") @@ -12859,17 +13900,21 @@ class OpenWebUIStorage(BaseStorage): return await self._count_files_from_detailed_info(str(kb_id), str(name), kb) - def _count_files_from_basic_info(self, kb: dict[str, object]) -> int: + def _count_files_from_basic_info(self, kb: OpenWebUIKnowledgeBase) -> int: """Count files from basic knowledge base info.""" files = kb.get("files", []) return len(files) if isinstance(files, list) and files is not None else 0 - async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: dict[str, object]) -> int: + async def 
_count_files_from_detailed_info( + self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase + ) -> int: """Count files by fetching detailed knowledge base info.""" try: LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}") - detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}") - detail_response.raise_for_status() + detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}") + if detail_response is None: + LOGGER.warning(f"Knowledge base '{name}' (ID: {kb_id}) not found") + return self._count_files_from_basic_info(kb) detailed_kb = detail_response.json() files = detailed_kb.get("files", []) @@ -12882,21 +13927,18 @@ class OpenWebUIStorage(BaseStorage): LOGGER.warning(f"Failed to get detailed info for KB '{name}' (ID: {kb_id}): {e}") return self._count_files_from_basic_info(kb) - async def describe_collections(self) -> list[dict[str, object]]: + async def describe_collections(self) -> list[CollectionSummary]: """Return metadata about each knowledge base.""" try: knowledge_bases = await self._fetch_knowledge_bases() - collections: list[dict[str, object]] = [] + collections: list[CollectionSummary] = [] for kb in knowledge_bases: - if not isinstance(kb, dict): - continue - count = await self._get_knowledge_base_count(kb) name = kb.get("name", "Unknown") size_mb = count * 0.5 # rough heuristic - summary: dict[str, object] = { + summary: CollectionSummary = { "name": str(name), "count": count, "size_mb": float(size_mb), @@ -12940,8 +13982,10 @@ class OpenWebUIStorage(BaseStorage): return 0 # Get detailed knowledge base information to get accurate file count - detail_response = await self.client.get(f"/api/v1/knowledge/{kb_id}") - detail_response.raise_for_status() + detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}") + if detail_response is None: + LOGGER.warning(f"Knowledge base '{collection_name}' (ID: {kb_id}) not found") + return self._count_files_from_basic_info(kb) detailed_kb = detail_response.json() files = detailed_kb.get("files", []) @@ -12954,7 +13998,7 @@ class OpenWebUIStorage(BaseStorage): LOGGER.warning(f"Failed to get count for collection '{collection_name}': {e}") return 0 - async def get_knowledge_by_name(self, name: str) -> dict[str, object] | None: + async def get_knowledge_by_name(self, name: str) -> OpenWebUIKnowledgeBase | None: """ Get knowledge base details by name. 
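# --- Illustrative usage sketch (commentary, not part of the patch above) ---
# A minimal example of how the TypedHttpClient wrapper introduced earlier in this
# diff might be used by an adapter such as OpenWebUIStorage. The endpoint path,
# the SecretStr api_key argument, and the Response-or-None return shape are taken
# from the surrounding code; the helper name `list_knowledge_names` is a
# hypothetical placeholder written only for this sketch.

from pydantic import SecretStr


async def list_knowledge_names(base_url: str, api_key: str) -> list[str]:
    """Fetch knowledge base names via the typed client wrapper (sketch)."""
    async with TypedHttpClient(base_url, api_key=SecretStr(api_key)) as client:
        # request() retries transport/5xx errors with backoff and raises
        # StorageError when retries are exhausted; None is only expected
        # for allow_404=True, so the check below is purely defensive.
        response = await client.request("GET", "/api/v1/knowledge/list")
        if response is None:
            return []
        data = response.json()
        return [kb["name"] for kb in data if isinstance(kb, dict) and "name" in kb]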
@@ -12965,18 +14009,33 @@ class OpenWebUIStorage(BaseStorage): Knowledge base details or None if not found """ try: - response = await self.client.get("/api/v1/knowledge/list") - response.raise_for_status() + response = await self.http_client.request("GET", "/api/v1/knowledge/list") + if response is None: + return None knowledge_bases = response.json() - return next( - ( - {str(k): v for k, v in kb.items()} - for kb in knowledge_bases - if isinstance(kb, dict) and kb.get("name") == name - ), - None, - ) + # Find and properly type the knowledge base + for kb in knowledge_bases: + if ( + isinstance(kb, dict) + and kb.get("name") == name + and "id" in kb + and isinstance(kb["id"], str) + ): + # Create properly typed response + result: OpenWebUIKnowledgeBase = { + "id": kb["id"], + "name": str(kb["name"]), + "description": kb.get("description", ""), + "created_at": kb.get("created_at", ""), + "updated_at": kb.get("updated_at", ""), + } + if "files" in kb and isinstance(kb["files"], list): + result["files"] = kb["files"] + if "data" in kb and isinstance(kb["data"], dict): + result["data"] = kb["data"] + return result + return None except Exception as e: raise StorageError(f"Failed to get knowledge base by name: {e}") from e @@ -12992,6 +14051,7 @@ class OpenWebUIStorage(BaseStorage): exc_tb: object | None, ) -> None: """Async context manager exit.""" + _ = exc_type, exc_val, exc_tb # Mark as used await self.close() async def list_documents( @@ -13000,7 +14060,7 @@ class OpenWebUIStorage(BaseStorage): offset: int = 0, *, collection_name: str | None = None, - ) -> list[dict[str, object]]: + ) -> list[DocumentInfo]: """ List documents (files) in a knowledge base. @@ -13050,11 +14110,8 @@ class OpenWebUIStorage(BaseStorage): paginated_files = files[offset : offset + limit] # Convert to document format with safe field access - documents: list[dict[str, object]] = [] + documents: list[DocumentInfo] = [] for i, file_info in enumerate(paginated_files): - if not isinstance(file_info, dict): - continue - # Safely extract fields with fallbacks doc_id = str(file_info.get("id", f"file_{i}")) @@ -13068,7 +14125,11 @@ class OpenWebUIStorage(BaseStorage): filename = file_info["name"] # Check meta.name (from FileModelResponse schema) elif isinstance(file_info.get("meta"), dict): - filename = file_info["meta"].get("name") + meta = file_info.get("meta") + if isinstance(meta, dict): + filename_value = meta.get("name") + if isinstance(filename_value, str): + filename = filename_value # Final fallback if not filename: @@ -13078,28 +14139,28 @@ class OpenWebUIStorage(BaseStorage): # Extract size from meta if available size = 0 - if isinstance(file_info.get("meta"), dict): - size = file_info["meta"].get("size", 0) + meta = file_info.get("meta") + if isinstance(meta, dict): + size_value = meta.get("size", 0) + size = int(size_value) if isinstance(size_value, (int, float)) else 0 else: - size = file_info.get("size", 0) + size_value = file_info.get("size", 0) + size = int(size_value) if isinstance(size_value, (int, float)) else 0 # Estimate word count from file size (very rough approximation) word_count = max(1, int(size / 6)) if isinstance(size, (int, float)) else 0 - documents.append( - { - "id": doc_id, - "title": filename, - "source_url": "", # OpenWebUI files don't typically have source URLs - "description": f"File: {filename}", - "content_type": str(file_info.get("content_type", "text/plain")), - "content_preview": f"File uploaded to OpenWebUI: {filename}", - "word_count": word_count, - "timestamp": str( - 
file_info.get("created_at") or file_info.get("timestamp", "") - ), - } - ) + doc_info: DocumentInfo = { + "id": doc_id, + "title": filename, + "source_url": "", # OpenWebUI files don't typically have source URLs + "description": f"File: {filename}", + "content_type": str(file_info.get("content_type", "text/plain")), + "content_preview": f"File uploaded to OpenWebUI: {filename}", + "word_count": word_count, + "timestamp": str(file_info.get("created_at") or file_info.get("timestamp", "")), + } + documents.append(doc_info) return documents @@ -13126,21 +14187,17 @@ class OpenWebUIStorage(BaseStorage): async def close(self) -> None: """Close client connection.""" - if hasattr(self, "client") and self.client: - try: - await self.client.aclose() - except Exception as e: - import logging - - logging.warning(f"Error closing OpenWebUI client: {e}") + if hasattr(self, "http_client"): + await self.http_client.close() """Weaviate storage adapter.""" -from collections.abc import AsyncGenerator, Mapping, Sequence +import asyncio +from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from datetime import UTC, datetime -from typing import Literal, Self, TypeAlias, cast, overload +from typing import Literal, Self, TypeAlias, TypeVar, cast, overload from uuid import UUID import weaviate @@ -13159,8 +14216,10 @@ from ..core.exceptions import StorageError from ..core.models import Document, DocumentMetadata, IngestionSource, StorageConfig from ..utils.vectorizer import Vectorizer from .base import BaseStorage +from .types import CollectionSummary, DocumentInfo VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None +T = TypeVar("T") class WeaviateStorage(BaseStorage): @@ -13182,6 +14241,28 @@ class WeaviateStorage(BaseStorage): self.vectorizer = Vectorizer(config) self._default_collection = self._normalize_collection_name(config.collection_name) + async def _run_sync(self, func: Callable[..., T], *args: object, **kwargs: object) -> T: + """ + Run synchronous Weaviate operations in thread pool to avoid blocking event loop. 
+ + Args: + func: Synchronous function to run + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Result of the function call + + Raises: + StorageError: If the operation fails + """ + try: + return await asyncio.to_thread(func, *args, **kwargs) + except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e: + raise StorageError(f"Weaviate operation failed: {e}") from e + except Exception as e: + raise StorageError(f"Unexpected error in Weaviate operation: {e}") from e + @override async def initialize(self) -> None: """Initialize Weaviate client and create collection if needed.""" @@ -13190,7 +14271,7 @@ class WeaviateStorage(BaseStorage): self.client = weaviate.WeaviateClient( connection_params=weaviate.connect.ConnectionParams.from_url( url=str(self.config.endpoint), - grpc_port=50051, # Default gRPC port + grpc_port=self.config.grpc_port or 50051, ), additional_config=weaviate.classes.init.AdditionalConfig( timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120), @@ -13198,7 +14279,7 @@ class WeaviateStorage(BaseStorage): ) # Connect to the client - self.client.connect() + await self._run_sync(self.client.connect) # Ensure the default collection exists await self._ensure_collection(self._default_collection) @@ -13213,8 +14294,8 @@ class WeaviateStorage(BaseStorage): if not self.client: raise StorageError("Weaviate client not initialized") try: - client = cast(weaviate.WeaviateClient, self.client) - client.collections.create( + await self._run_sync( + self.client.collections.create, name=collection_name, properties=[ Property( @@ -13243,7 +14324,7 @@ class WeaviateStorage(BaseStorage): ], vectorizer_config=Configure.Vectorizer.none(), ) - except Exception as e: + except (WeaviateConnectionError, WeaviateBatchError) as e: raise StorageError(f"Failed to create collection: {e}") from e @staticmethod @@ -13251,13 +14332,9 @@ class WeaviateStorage(BaseStorage): """Normalize vector payloads returned by Weaviate into a float list.""" if isinstance(vector_raw, Mapping): default_vector = vector_raw.get("default") - return WeaviateStorage._extract_vector( - cast(VectorContainer, default_vector) - ) + return WeaviateStorage._extract_vector(cast(VectorContainer, default_vector)) - if not isinstance(vector_raw, Sequence) or isinstance( - vector_raw, (str, bytes, bytearray) - ): + if not isinstance(vector_raw, Sequence) or isinstance(vector_raw, (str, bytes, bytearray)): return None items = list(vector_raw) @@ -13272,9 +14349,7 @@ class WeaviateStorage(BaseStorage): except (TypeError, ValueError): return None - if isinstance(first_item, Sequence) and not isinstance( - first_item, (str, bytes, bytearray) - ): + if isinstance(first_item, Sequence) and not isinstance(first_item, (str, bytes, bytearray)): inner_items = list(first_item) if all(isinstance(item, (int, float)) for item in inner_items): try: @@ -13305,8 +14380,7 @@ class WeaviateStorage(BaseStorage): properties: object, *, context: str, - ) -> Mapping[str, object]: - ... + ) -> Mapping[str, object]: ... @staticmethod @overload @@ -13315,8 +14389,7 @@ class WeaviateStorage(BaseStorage): *, context: str, allow_missing: Literal[False], - ) -> Mapping[str, object]: - ... + ) -> Mapping[str, object]: ... @staticmethod @overload @@ -13325,8 +14398,7 @@ class WeaviateStorage(BaseStorage): *, context: str, allow_missing: Literal[True], - ) -> Mapping[str, object] | None: - ... + ) -> Mapping[str, object] | None: ... 
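Editorial note on the pattern introduced by `_run_sync` above: it offloads synchronous Weaviate client calls to a worker thread via `asyncio.to_thread` so the event loop stays responsive. A minimal standalone sketch of the same pattern, assuming nothing from this codebase (the `fetch_collections_sync` function and its sleep are hypothetical stand-ins for a blocking client call):

import asyncio
import time


def fetch_collections_sync(prefix: str, count: int) -> list[str]:
    """Hypothetical stand-in for a blocking Weaviate client call."""
    time.sleep(0.2)  # simulates network latency on the synchronous client
    return [f"{prefix}-{i}" for i in range(count)]


async def main() -> None:
    # Offloading the blocking call keeps the event loop (and any Textual UI
    # driven by it) responsive while the worker thread waits on I/O.
    names = await asyncio.to_thread(fetch_collections_sync, "IngestDocs", 3)
    print(names)


asyncio.run(main())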
@staticmethod def _coerce_properties( @@ -13348,6 +14420,29 @@ class WeaviateStorage(BaseStorage): return cast(Mapping[str, object], properties) + @staticmethod + def _build_document_properties(doc: Document) -> dict[str, object]: + """ + Build Weaviate properties dict from document. + + Args: + doc: Document to build properties for + + Returns: + Properties dict suitable for Weaviate + """ + return { + "content": doc.content, + "source_url": doc.metadata["source_url"], + "title": doc.metadata.get("title", ""), + "description": doc.metadata.get("description", ""), + "timestamp": doc.metadata["timestamp"].isoformat(), + "content_type": doc.metadata["content_type"], + "word_count": doc.metadata["word_count"], + "char_count": doc.metadata["char_count"], + "source": doc.source.value, + } + def _normalize_collection_name(self, collection_name: str | None) -> str: """Return a canonicalized collection name, defaulting to configured value.""" candidate = collection_name or self.config.collection_name @@ -13364,7 +14459,7 @@ class WeaviateStorage(BaseStorage): if not self.client: raise StorageError("Weaviate client not initialized") - client = cast(weaviate.WeaviateClient, self.client) + client = self.client existing = client.collections.list_all() if collection_name not in existing: await self._create_collection(collection_name) @@ -13384,7 +14479,7 @@ class WeaviateStorage(BaseStorage): if ensure_exists: await self._ensure_collection(normalized) - client = cast(weaviate.WeaviateClient, self.client) + client = self.client return client.collections.get(normalized), normalized @override @@ -13408,26 +14503,19 @@ class WeaviateStorage(BaseStorage): ) # Prepare properties - properties = { - "content": document.content, - "source_url": document.metadata["source_url"], - "title": document.metadata.get("title", ""), - "description": document.metadata.get("description", ""), - "timestamp": document.metadata["timestamp"].isoformat(), - "content_type": document.metadata["content_type"], - "word_count": document.metadata["word_count"], - "char_count": document.metadata["char_count"], - "source": document.source.value, - } + properties = self._build_document_properties(document) # Insert with vector - result = collection.data.insert( - properties=properties, vector=document.vector, uuid=str(document.id) + result = await self._run_sync( + collection.data.insert, + properties=properties, + vector=document.vector, + uuid=str(document.id), ) return str(result) - except Exception as e: + except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e: raise StorageError(f"Failed to store document: {e}") from e @override @@ -13448,32 +14536,24 @@ class WeaviateStorage(BaseStorage): collection_name, ensure_exists=True ) - # Vectorize documents without vectors - for doc in documents: - if doc.vector is None: - doc.vector = await self.vectorizer.vectorize(doc.content) + # Vectorize documents without vectors using batch processing + to_vectorize = [(i, doc) for i, doc in enumerate(documents) if doc.vector is None] + if to_vectorize: + contents = [doc.content for _, doc in to_vectorize] + vectors = await self.vectorizer.vectorize_batch(contents) + for (idx, _), vector in zip(to_vectorize, vectors, strict=False): + documents[idx].vector = vector # Prepare batch data for insert_many batch_objects = [] for doc in documents: - properties = { - "content": doc.content, - "source_url": doc.metadata["source_url"], - "title": doc.metadata.get("title", ""), - "description": doc.metadata.get("description", ""), - 
"timestamp": doc.metadata["timestamp"].isoformat(), - "content_type": doc.metadata["content_type"], - "word_count": doc.metadata["word_count"], - "char_count": doc.metadata["char_count"], - "source": doc.source.value, - } - + properties = self._build_document_properties(doc) batch_objects.append( DataObject(properties=properties, vector=doc.vector, uuid=str(doc.id)) ) # Insert batch using insert_many - response = collection.data.insert_many(batch_objects) + response = await self._run_sync(collection.data.insert_many, batch_objects) successful_ids: list[str] = [] error_indices = set(response.errors.keys()) if response else set() @@ -13498,11 +14578,7 @@ class WeaviateStorage(BaseStorage): return successful_ids - except WeaviateBatchError as e: - raise StorageError(f"Batch operation failed: {e}") from e - except WeaviateConnectionError as e: - raise StorageError(f"Connection to Weaviate failed: {e}") from e - except Exception as e: + except (WeaviateBatchError, WeaviateConnectionError, WeaviateQueryError) as e: raise StorageError(f"Failed to store batch: {e}") from e @override @@ -13522,7 +14598,7 @@ class WeaviateStorage(BaseStorage): collection, resolved_name = await self._prepare_collection( collection_name, ensure_exists=False ) - result = collection.query.fetch_object_by_id(document_id) + result = await self._run_sync(collection.query.fetch_object_by_id, document_id) if not result: return None @@ -13532,13 +14608,30 @@ class WeaviateStorage(BaseStorage): result.properties, context="fetch_object_by_id", ) + # Parse timestamp to datetime for consistent metadata format + from datetime import UTC, datetime + + timestamp_raw = props.get("timestamp") + timestamp_parsed: datetime + try: + if isinstance(timestamp_raw, str): + timestamp_parsed = datetime.fromisoformat(timestamp_raw) + if timestamp_parsed.tzinfo is None: + timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC) + elif isinstance(timestamp_raw, datetime): + timestamp_parsed = timestamp_raw + if timestamp_parsed.tzinfo is None: + timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC) + else: + timestamp_parsed = datetime.now(UTC) + except (ValueError, TypeError): + timestamp_parsed = datetime.now(UTC) + metadata_dict = { "source_url": str(props["source_url"]), "title": str(props.get("title")) if props.get("title") else None, - "description": str(props.get("description")) - if props.get("description") - else None, - "timestamp": str(props["timestamp"]), + "description": str(props.get("description")) if props.get("description") else None, + "timestamp": timestamp_parsed, "content_type": str(props["content_type"]), "word_count": int(str(props["word_count"])), "char_count": int(str(props["char_count"])), @@ -13561,11 +14654,13 @@ class WeaviateStorage(BaseStorage): except WeaviateConnectionError as e: # Connection issues should be logged and return None import logging + logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}") return None except Exception as e: # Log unexpected errors for debugging import logging + logging.warning(f"Unexpected error retrieving document {document_id}: {e}") return None @@ -13574,9 +14669,7 @@ class WeaviateStorage(BaseStorage): metadata_dict = { "source_url": str(props["source_url"]), "title": str(props.get("title")) if props.get("title") else None, - "description": str(props.get("description")) - if props.get("description") - else None, + "description": str(props.get("description")) if props.get("description") else None, "timestamp": str(props["timestamp"]), "content_type": 
str(props["content_type"]), "word_count": int(str(props["word_count"])), @@ -13599,6 +14692,7 @@ class WeaviateStorage(BaseStorage): return max(0.0, 1.0 - distance_value) except (TypeError, ValueError) as e: import logging + logging.debug(f"Invalid distance value {raw_distance}: {e}") return None @@ -13643,37 +14737,39 @@ class WeaviateStorage(BaseStorage): collection_name: str | None = None, ) -> AsyncGenerator[Document, None]: """ - Search for documents in Weaviate. + Search for documents in Weaviate using hybrid search. Args: query: Search query limit: Maximum results - threshold: Similarity threshold + threshold: Similarity threshold (not used in hybrid search) Yields: Matching documents """ try: - query_vector = await self.vectorizer.vectorize(query) + if not self.client: + raise StorageError("Weaviate client not initialized") + collection, resolved_name = await self._prepare_collection( collection_name, ensure_exists=False ) - results = collection.query.near_vector( - near_vector=query_vector, - limit=limit, - distance=1 - threshold, - return_metadata=["distance"], - ) + # Try hybrid search first, fall back to BM25 keyword search + try: + response = await self._run_sync( + collection.query.hybrid, query=query, limit=limit, return_metadata=["score"] + ) + except (WeaviateQueryError, StorageError): + # Fall back to BM25 if hybrid search is not supported or fails + response = await self._run_sync( + collection.query.bm25, query=query, limit=limit, return_metadata=["score"] + ) - for result in results.objects: - yield self._build_search_document(result, resolved_name) + for obj in response.objects: + yield self._build_document_from_search(obj, resolved_name) - except WeaviateQueryError as e: - raise StorageError(f"Search query failed: {e}") from e - except WeaviateConnectionError as e: - raise StorageError(f"Connection to Weaviate failed during search: {e}") from e - except Exception as e: + except (WeaviateQueryError, WeaviateConnectionError) as e: raise StorageError(f"Search failed: {e}") from e @override @@ -13689,7 +14785,7 @@ class WeaviateStorage(BaseStorage): """ try: collection, _ = await self._prepare_collection(collection_name, ensure_exists=False) - collection.data.delete_by_id(document_id) + await self._run_sync(collection.data.delete_by_id, document_id) return True except WeaviateQueryError as e: raise StorageError(f"Delete operation failed: {e}") from e @@ -13726,20 +14822,20 @@ class WeaviateStorage(BaseStorage): if not self.client: raise StorageError("Weaviate client not initialized") - client = cast(weaviate.WeaviateClient, self.client) + client = self.client return list(client.collections.list_all()) - except Exception as e: + except WeaviateConnectionError as e: raise StorageError(f"Failed to list collections: {e}") from e - async def describe_collections(self) -> list[dict[str, object]]: + async def describe_collections(self) -> list[CollectionSummary]: """Return metadata for each Weaviate collection.""" if not self.client: raise StorageError("Weaviate client not initialized") try: - client = cast(weaviate.WeaviateClient, self.client) - collections: list[dict[str, object]] = [] + client = self.client + collections: list[CollectionSummary] = [] for name in client.collections.list_all(): collection_obj = client.collections.get(name) if not collection_obj: @@ -13747,13 +14843,12 @@ class WeaviateStorage(BaseStorage): count = collection_obj.aggregate.over_all(total_count=True).total_count or 0 size_mb = count * 0.01 - collections.append( - { - "name": name, - "count": 
count, - "size_mb": size_mb, - } - ) + collection_summary: CollectionSummary = { + "name": name, + "count": count, + "size_mb": size_mb, + } + collections.append(collection_summary) return collections except Exception as e: @@ -13777,7 +14872,7 @@ class WeaviateStorage(BaseStorage): ) # Query for sample documents - response = collection.query.fetch_objects(limit=limit) + response = await self._run_sync(collection.query.fetch_objects, limit=limit) documents = [] for obj in response.objects: @@ -13850,9 +14945,7 @@ class WeaviateStorage(BaseStorage): return { "source_url": str(props.get("source_url", "")), "title": str(props.get("title", "")) if props.get("title") else None, - "description": str(props.get("description", "")) - if props.get("description") - else None, + "description": str(props.get("description", "")) if props.get("description") else None, "timestamp": datetime.fromisoformat( str(props.get("timestamp", datetime.now(UTC).isoformat())) ), @@ -13875,6 +14968,7 @@ class WeaviateStorage(BaseStorage): return float(raw_score) except (TypeError, ValueError) as e: import logging + logging.debug(f"Invalid score value {raw_score}: {e}") return None @@ -13918,31 +15012,11 @@ class WeaviateStorage(BaseStorage): Returns: List of matching documents """ - try: - if not self.client: - raise StorageError("Weaviate client not initialized") - - collection, resolved_name = await self._prepare_collection( - collection_name, ensure_exists=False - ) - - # Try hybrid search first, fall back to BM25 keyword search - try: - response = collection.query.hybrid( - query=query, limit=limit, return_metadata=["score"] - ) - except Exception: - response = collection.query.bm25( - query=query, limit=limit, return_metadata=["score"] - ) - - return [ - self._build_document_from_search(obj, resolved_name) - for obj in response.objects - ] - - except Exception as e: - raise StorageError(f"Failed to search documents: {e}") from e + # Delegate to the unified search method + results = [] + async for document in self.search(query, limit=limit, collection_name=collection_name): + results.append(document) + return results async def list_documents( self, @@ -13950,7 +15024,7 @@ class WeaviateStorage(BaseStorage): offset: int = 0, *, collection_name: str | None = None, - ) -> list[dict[str, object]]: + ) -> list[DocumentInfo]: """ List documents in the collection with pagination. 
@@ -13968,11 +15042,14 @@ class WeaviateStorage(BaseStorage): collection, _ = await self._prepare_collection(collection_name, ensure_exists=False) # Query documents with pagination - response = collection.query.fetch_objects( - limit=limit, offset=offset, return_metadata=["creation_time"] + response = await self._run_sync( + collection.query.fetch_objects, + limit=limit, + offset=offset, + return_metadata=["creation_time"], ) - documents: list[dict[str, object]] = [] + documents: list[DocumentInfo] = [] for obj in response.objects: props = self._coerce_properties( obj.properties, @@ -13991,7 +15068,7 @@ class WeaviateStorage(BaseStorage): else: word_count = 0 - doc_info: dict[str, object] = { + doc_info: DocumentInfo = { "id": str(obj.uuid), "title": str(props.get("title", "Untitled")), "source_url": str(props.get("source_url", "")), @@ -14034,7 +15111,9 @@ class WeaviateStorage(BaseStorage): ) delete_filter = Filter.by_id().contains_any(document_ids) - response = collection.data.delete_many(where=delete_filter, verbose=True) + response = await self._run_sync( + collection.data.delete_many, where=delete_filter, verbose=True + ) if objects := getattr(response, "objects", None): for result_obj in objects: @@ -14076,20 +15155,22 @@ class WeaviateStorage(BaseStorage): # Get documents matching filter if where_filter: - response = collection.query.fetch_objects( + response = await self._run_sync( + collection.query.fetch_objects, filters=where_filter, limit=1000, # Max batch size ) else: - response = collection.query.fetch_objects( - limit=1000 # Max batch size + response = await self._run_sync( + collection.query.fetch_objects, + limit=1000, # Max batch size ) # Delete matching documents deleted_count = 0 for obj in response.objects: try: - collection.data.delete_by_id(obj.uuid) + await self._run_sync(collection.data.delete_by_id, obj.uuid) deleted_count += 1 except Exception: continue @@ -14113,7 +15194,7 @@ class WeaviateStorage(BaseStorage): target = self._normalize_collection_name(collection_name) # Delete the collection using the client's collections API - client = cast(weaviate.WeaviateClient, self.client) + client = self.client client.collections.delete(target) return True @@ -14135,20 +15216,29 @@ class WeaviateStorage(BaseStorage): await self.close() async def close(self) -> None: - """Close client connection.""" + """Close client connection and vectorizer HTTP client.""" if self.client: try: - client = cast(weaviate.WeaviateClient, self.client) + client = self.client client.close() - except Exception as e: + except (WeaviateConnectionError, AttributeError) as e: import logging + logging.warning(f"Error closing Weaviate client: {e}") + # Close vectorizer HTTP client to prevent resource leaks + try: + await self.vectorizer.close() + except (AttributeError, OSError) as e: + import logging + + logging.warning(f"Error closing vectorizer client: {e}") + def __del__(self) -> None: """Clean up client connection as fallback.""" if self.client: try: - client = cast(weaviate.WeaviateClient, self.client) + client = self.client client.close() except Exception: pass # Ignore errors in destructor diff --git a/tests/__pycache__/__init__.cpython-312.pyc b/tests/__pycache__/__init__.cpython-312.pyc index 72ead4f..7359d82 100644 Binary files a/tests/__pycache__/__init__.cpython-312.pyc and b/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc b/tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc index 8047a3d..44dc978 100644 Binary 
files a/tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc and b/tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/__pycache__/openapi_mocks.cpython-312.pyc b/tests/__pycache__/openapi_mocks.cpython-312.pyc index a9e4d4f..d67ba32 100644 Binary files a/tests/__pycache__/openapi_mocks.cpython-312.pyc and b/tests/__pycache__/openapi_mocks.cpython-312.pyc differ diff --git a/tests/conftest.py b/tests/conftest.py index 8b77ee4..ce9a4e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,7 @@ class StubbedResponse: def raise_for_status(self) -> None: if self.status_code < 400: return + # Create a minimal exception for testing - we don't need full httpx objects for tests class TestHTTPError(Exception): request: object | None @@ -232,7 +233,7 @@ class AsyncClientStub: # Convert params to the format expected by other methods converted_params: dict[str, object] | None = None if params: - converted_params = {k: v for k, v in params.items()} + converted_params = dict(params) method_upper = method.upper() if method_upper == "GET": @@ -441,7 +442,6 @@ def r2r_spec() -> OpenAPISpec: return OpenAPISpec.from_file(PROJECT_ROOT / "r2r.json") - @pytest.fixture(scope="session") def firecrawl_spec() -> OpenAPISpec: """Load Firecrawl OpenAPI specification.""" @@ -539,7 +539,9 @@ def firecrawl_client_stub( links = [SimpleNamespace(url=cast(str, item.get("url", ""))) for item in links_data] return SimpleNamespace(success=payload.get("success", True), links=links) - async def scrape(self, url: str, formats: list[str] | None = None, **_: object) -> SimpleNamespace: + async def scrape( + self, url: str, formats: list[str] | None = None, **_: object + ) -> SimpleNamespace: payload = cast(MockResponseData, self._service.scrape_response(url, formats)) data = cast(MockResponseData, payload.get("data", {})) metadata_payload = cast(MockResponseData, data.get("metadata", {})) diff --git a/tests/openapi_mocks.py b/tests/openapi_mocks.py index 0a418c7..b1c052e 100644 --- a/tests/openapi_mocks.py +++ b/tests/openapi_mocks.py @@ -158,7 +158,9 @@ class OpenAPISpec: return [] return [segment for segment in path.strip("/").split("/") if segment] - def find_operation(self, method: str, path: str) -> tuple[Mapping[str, Any] | None, dict[str, str]]: + def find_operation( + self, method: str, path: str + ) -> tuple[Mapping[str, Any] | None, dict[str, str]]: method = method.lower() normalized = "/" if path in {"", "/"} else "/" + path.strip("/") actual_segments = self._split_path(normalized) @@ -188,9 +190,7 @@ class OpenAPISpec: status: str | None = None, ) -> tuple[int, Any]: responses = operation.get("responses", {}) - target_status = status or ( - "200" if "200" in responses else next(iter(responses), "200") - ) + target_status = status or ("200" if "200" in responses else next(iter(responses), "200")) status_code = 200 try: status_code = int(target_status) @@ -222,7 +222,7 @@ class OpenAPISpec: base = self.generate({"$ref": ref}) if overrides is None or not isinstance(base, Mapping): return copy.deepcopy(base) - merged = copy.deepcopy(base) + merged: dict[str, Any] = dict(base) for key, value in overrides.items(): if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping): merged[key] = self.generate_from_mapping( @@ -237,8 +237,8 @@ class OpenAPISpec: self, base: Mapping[str, Any], overrides: Mapping[str, Any], - ) -> Mapping[str, Any]: - result = copy.deepcopy(base) + ) -> dict[str, Any]: + result: dict[str, Any] = dict(base) for key, value in overrides.items(): if 
isinstance(value, Mapping) and isinstance(result.get(key), Mapping): result[key] = self.generate_from_mapping(result[key], value) @@ -380,18 +380,20 @@ class OpenWebUIMockService(OpenAPIMockService): def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]: payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse") - payload.update({ - "id": knowledge["id"], - "user_id": knowledge["user_id"], - "name": knowledge["name"], - "description": knowledge.get("description", ""), - "data": copy.deepcopy(knowledge.get("data", {})), - "meta": copy.deepcopy(knowledge.get("meta", {})), - "access_control": copy.deepcopy(knowledge.get("access_control", {})), - "created_at": knowledge.get("created_at", self._timestamp()), - "updated_at": knowledge.get("updated_at", self._timestamp()), - "files": copy.deepcopy(knowledge.get("files", [])), - }) + payload.update( + { + "id": knowledge["id"], + "user_id": knowledge["user_id"], + "name": knowledge["name"], + "description": knowledge.get("description", ""), + "data": copy.deepcopy(knowledge.get("data", {})), + "meta": copy.deepcopy(knowledge.get("meta", {})), + "access_control": copy.deepcopy(knowledge.get("access_control", {})), + "created_at": knowledge.get("created_at", self._timestamp()), + "updated_at": knowledge.get("updated_at", self._timestamp()), + "files": copy.deepcopy(knowledge.get("files", [])), + } + ) return payload def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]: @@ -435,7 +437,9 @@ class OpenWebUIMockService(OpenAPIMockService): # Delegate to parent return super().handle(method=method, path=path, json=json, params=params, files=files) - def _handle_knowledge_endpoints(self, method: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]: + def _handle_knowledge_endpoints( + self, method: str, segments: list[str], json: Mapping[str, Any] | None + ) -> tuple[int, Any]: """Handle knowledge-related API endpoints.""" method_upper = method.upper() segment_count = len(segments) @@ -471,7 +475,9 @@ class OpenWebUIMockService(OpenAPIMockService): ) return 200, entry - def _handle_specific_knowledge(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]: + def _handle_specific_knowledge( + self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None + ) -> tuple[int, Any]: """Handle endpoints for specific knowledge entries.""" knowledge_id = segments[3] knowledge = self._knowledge.get(knowledge_id) @@ -486,7 +492,9 @@ class OpenWebUIMockService(OpenAPIMockService): # File operations if segment_count == 6 and segments[4] == "file": - return self._handle_knowledge_file_operations(method_upper, segments[5], knowledge, json or {}) + return self._handle_knowledge_file_operations( + method_upper, segments[5], knowledge, json or {} + ) # Delete knowledge if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete": @@ -495,7 +503,13 @@ class OpenWebUIMockService(OpenAPIMockService): return 404, {"detail": "Knowledge operation not found"} - def _handle_knowledge_file_operations(self, method_upper: str, operation: str, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]: + def _handle_knowledge_file_operations( + self, + method_upper: str, + operation: str, + knowledge: dict[str, Any], + payload: Mapping[str, Any], + ) -> tuple[int, Any]: """Handle file operations on knowledge entries.""" if method_upper != "POST": return 405, {"detail": "Method not allowed"} 
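The mock handlers above and below follow a `(status_code, payload)` tuple convention, returning 405 for unsupported methods and 404 for unknown operations. A self-contained sketch of that dispatch pattern under the same convention (the `handle_file_operation` function and the in-memory `kb` dict are illustrative only, not the project's mock API):

from typing import Any


def handle_file_operation(
    method: str, operation: str, knowledge: dict[str, Any], payload: dict[str, Any]
) -> tuple[int, Any]:
    """Illustrative dispatcher using the mock's (status, body) convention."""
    if method.upper() != "POST":
        return 405, {"detail": "Method not allowed"}
    if operation == "add":
        knowledge.setdefault("files", []).append({"id": payload.get("file_id", "")})
        return 200, {"files": knowledge["files"]}
    if operation == "remove":
        file_id = payload.get("file_id", "")
        knowledge["files"] = [f for f in knowledge.get("files", []) if f.get("id") != file_id]
        return 200, {"files": knowledge["files"]}
    return 404, {"detail": "File operation not found"}


# Example: add then remove a file on an in-memory knowledge entry.
kb: dict[str, Any] = {"id": "kb-1", "files": []}
print(handle_file_operation("POST", "add", kb, {"file_id": "f-1"}))
print(handle_file_operation("POST", "remove", kb, {"file_id": "f-1"}))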
@@ -507,7 +521,9 @@ class OpenWebUIMockService(OpenAPIMockService): return 404, {"detail": "File operation not found"} - def _add_file_to_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]: + def _add_file_to_knowledge( + self, knowledge: dict[str, Any], payload: Mapping[str, Any] + ) -> tuple[int, Any]: """Add a file to a knowledge entry.""" file_id = str(payload.get("file_id", "")) if not file_id: @@ -524,15 +540,21 @@ class OpenWebUIMockService(OpenAPIMockService): return 200, self._knowledge_files_response(knowledge) - def _remove_file_from_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]: + def _remove_file_from_knowledge( + self, knowledge: dict[str, Any], payload: Mapping[str, Any] + ) -> tuple[int, Any]: """Remove a file from a knowledge entry.""" file_id = str(payload.get("file_id", "")) - knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id] + knowledge["files"] = [ + item for item in knowledge.get("files", []) if item.get("id") != file_id + ] knowledge["updated_at"] = self._timestamp() return 200, self._knowledge_files_response(knowledge) - def _handle_file_endpoints(self, method: str, segments: list[str], files: Any) -> tuple[int, Any]: + def _handle_file_endpoints( + self, method: str, segments: list[str], files: Any + ) -> tuple[int, Any]: """Handle file-related API endpoints.""" method_upper = method.upper() @@ -561,7 +583,9 @@ class OpenWebUIMockService(OpenAPIMockService): """Delete a file and remove from all knowledge entries.""" self._files.pop(file_id, None) for knowledge in self._knowledge.values(): - knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id] + knowledge["files"] = [ + item for item in knowledge.get("files", []) if item.get("id") != file_id + ] return 200, {"deleted": True} @@ -620,10 +644,14 @@ class R2RMockService(OpenAPIMockService): None, ) - def _set_collection_document_ids(self, collection_id: str, document_id: str, *, add: bool) -> None: + def _set_collection_document_ids( + self, collection_id: str, document_id: str, *, add: bool + ) -> None: collection = self._collections.get(collection_id) if collection is None: - collection = self.create_collection(name=f"Collection {collection_id}", collection_id=collection_id) + collection = self.create_collection( + name=f"Collection {collection_id}", collection_id=collection_id + ) documents = collection.setdefault("documents", []) if add: if document_id not in documents: @@ -677,7 +705,9 @@ class R2RMockService(OpenAPIMockService): self._set_collection_document_ids(collection_id, document_id, add=False) return True - def append_document_metadata(self, document_id: str, metadata_list: list[dict[str, Any]]) -> dict[str, Any] | None: + def append_document_metadata( + self, document_id: str, metadata_list: list[dict[str, Any]] + ) -> dict[str, Any] | None: entry = self._documents.get(document_id) if entry is None: return None @@ -712,7 +742,9 @@ class R2RMockService(OpenAPIMockService): # Delegate to parent return super().handle(method=method, path=path, json=json, params=params, files=files) - def _handle_collections_endpoint(self, method_upper: str, json: Mapping[str, Any] | None) -> tuple[int, Any]: + def _handle_collections_endpoint( + self, method_upper: str, json: Mapping[str, Any] | None + ) -> tuple[int, Any]: """Handle collection-related endpoints.""" if method_upper == "GET": return self._list_collections() @@ -732,10 +764,12 @@ class 
R2RMockService(OpenAPIMockService): clone["document_count"] = len(clone.get("documents", [])) results.append(clone) - payload.update({ - "results": results, - "total_entries": len(results), - }) + payload.update( + { + "results": results, + "total_entries": len(results), + } + ) return 200, payload def _create_collection_endpoint(self, body: Mapping[str, Any]) -> tuple[int, Any]: @@ -748,7 +782,9 @@ class R2RMockService(OpenAPIMockService): payload.update({"results": entry}) return 200, payload - def _handle_documents_endpoint(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any) -> tuple[int, Any]: + def _handle_documents_endpoint( + self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any + ) -> tuple[int, Any]: """Handle document-related endpoints.""" if method_upper == "POST" and len(segments) == 2: return self._create_document_endpoint(files) @@ -774,11 +810,13 @@ class R2RMockService(OpenAPIMockService): "#/components/schemas/R2RResults_IngestionResponse_" ) ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse") - ingestion.update({ - "message": ingestion.get("message") or "Ingestion task queued successfully.", - "document_id": document["id"], - "task_id": ingestion.get("task_id") or str(uuid4()), - }) + ingestion.update( + { + "message": ingestion.get("message") or "Ingestion task queued successfully.", + "document_id": document["id"], + "task_id": ingestion.get("task_id") or str(uuid4()), + } + ) response_payload["results"] = ingestion return 202, response_payload @@ -835,7 +873,11 @@ class R2RMockService(OpenAPIMockService): def _extract_content_from_files(self, raw_text_entry: Any) -> str: """Extract content from files entry.""" - if isinstance(raw_text_entry, tuple) and len(raw_text_entry) >= 2 and raw_text_entry[1] is not None: + if ( + isinstance(raw_text_entry, tuple) + and len(raw_text_entry) >= 2 + and raw_text_entry[1] is not None + ): return str(raw_text_entry[1]) return "" @@ -886,7 +928,7 @@ class R2RMockService(OpenAPIMockService): results = response_payload.get("results") if isinstance(results, Mapping): - results = copy.deepcopy(results) + results = dict(results) results.update({"success": success}) else: results = {"success": success} @@ -894,7 +936,9 @@ class R2RMockService(OpenAPIMockService): response_payload["results"] = results return (200 if success else 404), response_payload - def _update_document_metadata(self, doc_id: str, json: Mapping[str, Any] | None) -> tuple[int, Any]: + def _update_document_metadata( + self, doc_id: str, json: Mapping[str, Any] | None + ) -> tuple[int, Any]: """Update document metadata.""" metadata_list = [dict(item) for item in json] if isinstance(json, list) else [] document = self.append_document_metadata(doc_id, metadata_list) @@ -919,7 +963,15 @@ class FirecrawlMockService(OpenAPIMockService): def register_map_result(self, origin: str, links: list[str]) -> None: self._maps[origin] = list(links) - def register_page(self, url: str, *, markdown: str | None = None, html: str | None = None, metadata: Mapping[str, Any] | None = None, links: list[str] | None = None) -> None: + def register_page( + self, + url: str, + *, + markdown: str | None = None, + html: str | None = None, + metadata: Mapping[str, Any] | None = None, + links: list[str] | None = None, + ) -> None: self._pages[url] = { "markdown": markdown, "html": html, diff --git a/tests/unit/__pycache__/__init__.cpython-312.pyc b/tests/unit/__pycache__/__init__.cpython-312.pyc index 
f7bf11b..6eff6e7 100644 Binary files a/tests/unit/__pycache__/__init__.cpython-312.pyc and b/tests/unit/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/unit/automations/conftest.py b/tests/unit/automations/conftest.py new file mode 100644 index 0000000..7e05a3b --- /dev/null +++ b/tests/unit/automations/conftest.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import pytest + +from ingest_pipeline.automations import AUTOMATION_TEMPLATES, get_automation_yaml_templates + + +@pytest.fixture(scope="function") +def automation_templates_copy() -> dict[str, str]: + return get_automation_yaml_templates() + + +@pytest.fixture(scope="function") +def automation_templates_reference() -> dict[str, str]: + return AUTOMATION_TEMPLATES diff --git a/tests/unit/automations/test_init.py b/tests/unit/automations/test_init.py new file mode 100644 index 0000000..8749439 --- /dev/null +++ b/tests/unit/automations/test_init.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import pytest + + +@pytest.mark.parametrize( + ("template_key", "expected_snippet"), + [ + ("cancel_long_running", "Cancel Long Running Ingestion Flows"), + ("retry_failed", "Retry Failed Ingestion Flows"), + ("resource_monitoring", "Manage Work Pool Based on Resources"), + ], +) +def test_automation_templates_include_expected_entries( + automation_templates_copy, + template_key: str, + expected_snippet: str, +) -> None: + assert expected_snippet in automation_templates_copy[template_key] + + +def test_get_automation_yaml_templates_returns_copy( + automation_templates_copy, + automation_templates_reference, +) -> None: + automation_templates_copy["cancel_long_running"] = "overridden" + + assert "overridden" not in automation_templates_reference["cancel_long_running"] diff --git a/tests/unit/cli/__pycache__/__init__.cpython-312.pyc b/tests/unit/cli/__pycache__/__init__.cpython-312.pyc index 014e303..cf1b2fd 100644 Binary files a/tests/unit/cli/__pycache__/__init__.cpython-312.pyc and b/tests/unit/cli/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/unit/cli/__pycache__/test_main_cli.cpython-312-pytest-8.4.2.pyc b/tests/unit/cli/__pycache__/test_main_cli.cpython-312-pytest-8.4.2.pyc index 4fc9d8a..c700639 100644 Binary files a/tests/unit/cli/__pycache__/test_main_cli.cpython-312-pytest-8.4.2.pyc and b/tests/unit/cli/__pycache__/test_main_cli.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/cli/test_main_cli.py b/tests/unit/cli/test_main_cli.py index 8cf49d2..ed5e3dd 100644 --- a/tests/unit/cli/test_main_cli.py +++ b/tests/unit/cli/test_main_cli.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import AsyncGenerator from types import SimpleNamespace from uuid import uuid4 @@ -22,7 +23,9 @@ from ingest_pipeline.core.models import ( ), ) @pytest.mark.asyncio -async def test_run_ingestion_collection_resolution(monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str) -> None: +async def test_run_ingestion_collection_resolution( + monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str +) -> None: recorded: dict[str, object] = {} async def fake_flow(**kwargs: object) -> IngestionResult: @@ -128,7 +131,7 @@ async def test_run_search_collects_results(monkeypatch: pytest.MonkeyPatch) -> N threshold: float = 0.7, *, collection_name: str | None = None, - ) -> object: + ) -> AsyncGenerator[SimpleNamespace, None]: yield SimpleNamespace(title="Title", content="Body text", score=0.91) dummy_settings = SimpleNamespace( diff --git 
a/tests/unit/flows/__pycache__/__init__.cpython-312.pyc b/tests/unit/flows/__pycache__/__init__.cpython-312.pyc index cf9577d..a06fa39 100644 Binary files a/tests/unit/flows/__pycache__/__init__.cpython-312.pyc and b/tests/unit/flows/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/unit/flows/__pycache__/test_ingestion_flow.cpython-312-pytest-8.4.2.pyc b/tests/unit/flows/__pycache__/test_ingestion_flow.cpython-312-pytest-8.4.2.pyc index 34f8b12..c10e2da 100644 Binary files a/tests/unit/flows/__pycache__/test_ingestion_flow.cpython-312-pytest-8.4.2.pyc and b/tests/unit/flows/__pycache__/test_ingestion_flow.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/flows/__pycache__/test_scheduler.cpython-312-pytest-8.4.2.pyc b/tests/unit/flows/__pycache__/test_scheduler.cpython-312-pytest-8.4.2.pyc index f1851be..775e9cc 100644 Binary files a/tests/unit/flows/__pycache__/test_scheduler.cpython-312-pytest-8.4.2.pyc and b/tests/unit/flows/__pycache__/test_scheduler.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/flows/test_ingestion_flow.py b/tests/unit/flows/test_ingestion_flow.py index 80c01bd..252d6d3 100644 --- a/tests/unit/flows/test_ingestion_flow.py +++ b/tests/unit/flows/test_ingestion_flow.py @@ -6,9 +6,11 @@ from unittest.mock import AsyncMock import pytest from ingest_pipeline.core.models import ( + Document, IngestionJob, IngestionSource, StorageBackend, + StorageConfig, ) from ingest_pipeline.flows import ingestion from ingest_pipeline.flows.ingestion import ( @@ -18,6 +20,7 @@ from ingest_pipeline.flows.ingestion import ( filter_existing_documents_task, ingest_documents_task, ) +from ingest_pipeline.storage.base import BaseStorage @pytest.mark.parametrize( @@ -52,15 +55,34 @@ async def test_filter_existing_documents_task_filters_known_urls( new_url = "https://new.example.com" existing_id = str(FirecrawlIngestor.compute_document_id(existing_url)) - class StubStorage: - display_name = "stub-storage" + class StubStorage(BaseStorage): + def __init__(self) -> None: + super().__init__(StorageConfig(backend=StorageBackend.WEAVIATE, endpoint="http://test.local")) + + @property + def display_name(self) -> str: + return "stub-storage" + + async def initialize(self) -> None: + pass + + async def store(self, document: Document, *, collection_name: str | None = None) -> str: + return "stub-id" + + async def store_batch( + self, documents: list[Document], *, collection_name: str | None = None + ) -> list[str]: + return ["stub-id"] * len(documents) + + async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: + return True async def check_exists( self, document_id: str, *, collection_name: str | None = None, - stale_after_days: int, + stale_after_days: int = 30, ) -> bool: return document_id == existing_id @@ -102,7 +124,9 @@ async def test_ingest_documents_task_invokes_helpers( ) -> object: return storage_sentinel - async def fake_create_ingestor(job_arg: IngestionJob, config_block_name: str | None = None) -> object: + async def fake_create_ingestor( + job_arg: IngestionJob, config_block_name: str | None = None + ) -> object: return ingestor_sentinel monkeypatch.setattr(ingestion, "_create_ingestor", fake_create_ingestor) diff --git a/tests/unit/flows/test_scheduler.py b/tests/unit/flows/test_scheduler.py index ed3b8ee..6ea8e1c 100644 --- a/tests/unit/flows/test_scheduler.py +++ b/tests/unit/flows/test_scheduler.py @@ -4,6 +4,8 @@ from datetime import timedelta from types import SimpleNamespace import pytest +from 
prefect.deployments.runner import RunnerDeployment +from prefect.schedules import Cron, Interval from ingest_pipeline.flows import scheduler @@ -17,7 +19,6 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No captured |= kwargs return SimpleNamespace(**kwargs) - monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow()) deployment = scheduler.create_scheduled_deployment( @@ -28,8 +29,14 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No cron_expression="0 * * * *", ) - assert captured["schedule"].cron == "0 * * * *" - assert captured["parameters"]["source_type"] == "web" + schedule = captured["schedule"] + # Check that it's a cron schedule by verifying it has a cron attribute + assert hasattr(schedule, "cron") + assert schedule.cron == "0 * * * *" + + parameters = captured["parameters"] + assert isinstance(parameters, dict) + assert parameters["source_type"] == "web" assert deployment.tags == ["web", "weaviate"] @@ -42,7 +49,6 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) - captured |= kwargs return SimpleNamespace(**kwargs) - monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow()) deployment = scheduler.create_scheduled_deployment( @@ -55,7 +61,11 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) - tags=["custom"], ) - assert captured["schedule"].interval == timedelta(minutes=15) + schedule = captured["schedule"] + # Check that it's an interval schedule by verifying it has an interval attribute + assert hasattr(schedule, "interval") + assert schedule.interval == timedelta(minutes=15) + assert captured["tags"] == ["custom"] assert deployment.parameters["storage_backend"] == "open_webui" @@ -63,12 +73,13 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) - def test_serve_deployments_invokes_prefect(monkeypatch: pytest.MonkeyPatch) -> None: called: dict[str, object] = {} - def fake_serve(*deployments: object, limit: int) -> None: + def fake_serve(*deployments: RunnerDeployment, limit: int) -> None: called["deployments"] = deployments called["limit"] = limit monkeypatch.setattr(scheduler, "prefect_serve", fake_serve) + # Create a mock deployment using SimpleNamespace to avoid Prefect complexity deployment = SimpleNamespace(name="only") scheduler.serve_deployments([deployment]) diff --git a/tests/unit/ingestors/__pycache__/__init__.cpython-312.pyc b/tests/unit/ingestors/__pycache__/__init__.cpython-312.pyc index cdae77c..fa23fc7 100644 Binary files a/tests/unit/ingestors/__pycache__/__init__.cpython-312.pyc and b/tests/unit/ingestors/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/unit/ingestors/__pycache__/test_firecrawl_ingestor.cpython-312-pytest-8.4.2.pyc b/tests/unit/ingestors/__pycache__/test_firecrawl_ingestor.cpython-312-pytest-8.4.2.pyc index 5e4e769..94ca013 100644 Binary files a/tests/unit/ingestors/__pycache__/test_firecrawl_ingestor.cpython-312-pytest-8.4.2.pyc and b/tests/unit/ingestors/__pycache__/test_firecrawl_ingestor.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/ingestors/__pycache__/test_repomix_ingestor.cpython-312-pytest-8.4.2.pyc b/tests/unit/ingestors/__pycache__/test_repomix_ingestor.cpython-312-pytest-8.4.2.pyc index 10b949b..1a9cfae 100644 Binary files a/tests/unit/ingestors/__pycache__/test_repomix_ingestor.cpython-312-pytest-8.4.2.pyc and b/tests/unit/ingestors/__pycache__/test_repomix_ingestor.cpython-312-pytest-8.4.2.pyc differ diff --git 
a/tests/unit/ingestors/test_repomix_ingestor.py b/tests/unit/ingestors/test_repomix_ingestor.py index 3090a6a..41f05a3 100644 --- a/tests/unit/ingestors/test_repomix_ingestor.py +++ b/tests/unit/ingestors/test_repomix_ingestor.py @@ -9,7 +9,7 @@ from ingest_pipeline.ingestors.repomix import RepomixIngestor @pytest.mark.parametrize( ("content", "expected_keys"), ( - ("## File: src/app.py\nprint(\"hi\")", ["src/app.py"]), + ('## File: src/app.py\nprint("hi")', ["src/app.py"]), ("plain content without markers", ["repository"]), ), ) @@ -44,7 +44,7 @@ def test_chunk_content_respects_max_size( ("file_path", "content", "expected"), ( ("src/app.py", "def feature():\n return True", "python"), - ("scripts/run", "#!/usr/bin/env python\nprint(\"ok\")", "python"), + ("scripts/run", '#!/usr/bin/env python\nprint("ok")', "python"), ("documentation.md", "# Title", "markdown"), ("unknown.ext", "text", None), ), @@ -81,5 +81,8 @@ def test_create_document_enriches_metadata() -> None: assert document.metadata["repository_name"] == "demo" assert document.metadata["branch_name"] == "main" assert document.metadata["commit_hash"] == "deadbeef" - assert document.metadata["title"].endswith("(chunk 1)") + + title = document.metadata["title"] + assert title is not None + assert title.endswith("(chunk 1)") assert document.collection == job.storage_backend.value diff --git a/tests/unit/storage/__pycache__/__init__.cpython-312.pyc b/tests/unit/storage/__pycache__/__init__.cpython-312.pyc index 80dba8b..a0f488d 100644 Binary files a/tests/unit/storage/__pycache__/__init__.cpython-312.pyc and b/tests/unit/storage/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/unit/storage/__pycache__/test_base_storage.cpython-312-pytest-8.4.2.pyc b/tests/unit/storage/__pycache__/test_base_storage.cpython-312-pytest-8.4.2.pyc index f7e88c5..122b065 100644 Binary files a/tests/unit/storage/__pycache__/test_base_storage.cpython-312-pytest-8.4.2.pyc and b/tests/unit/storage/__pycache__/test_base_storage.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/storage/__pycache__/test_openwebui.cpython-312-pytest-8.4.2.pyc b/tests/unit/storage/__pycache__/test_openwebui.cpython-312-pytest-8.4.2.pyc index 051c1f3..e3331f0 100644 Binary files a/tests/unit/storage/__pycache__/test_openwebui.cpython-312-pytest-8.4.2.pyc and b/tests/unit/storage/__pycache__/test_openwebui.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/storage/__pycache__/test_r2r_helpers.cpython-312-pytest-8.4.2.pyc b/tests/unit/storage/__pycache__/test_r2r_helpers.cpython-312-pytest-8.4.2.pyc index 6728bad..61b9b51 100644 Binary files a/tests/unit/storage/__pycache__/test_r2r_helpers.cpython-312-pytest-8.4.2.pyc and b/tests/unit/storage/__pycache__/test_r2r_helpers.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/storage/__pycache__/test_weaviate_helpers.cpython-312-pytest-8.4.2.pyc b/tests/unit/storage/__pycache__/test_weaviate_helpers.cpython-312-pytest-8.4.2.pyc index 8c7cdf2..7d35b90 100644 Binary files a/tests/unit/storage/__pycache__/test_weaviate_helpers.cpython-312-pytest-8.4.2.pyc and b/tests/unit/storage/__pycache__/test_weaviate_helpers.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/storage/conftest.py b/tests/unit/storage/conftest.py new file mode 100644 index 0000000..eed606e --- /dev/null +++ b/tests/unit/storage/conftest.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import importlib +import sys +import types + +import pytest + + +@pytest.fixture(scope="function") +def storage_module() -> 
types.ModuleType: + return importlib.import_module("ingest_pipeline.storage") + + +@pytest.fixture(scope="function") +def storage_module_without_r2r(request: pytest.FixtureRequest) -> types.ModuleType: + original_package = sys.modules.pop("ingest_pipeline.storage.r2r", None) + original_storage = sys.modules.pop("ingest_pipeline.storage.r2r.storage", None) + + dummy_package = types.ModuleType("ingest_pipeline.storage.r2r") + dummy_package.__path__ = [] + dummy_storage = types.ModuleType("ingest_pipeline.storage.r2r.storage") + + sys.modules["ingest_pipeline.storage.r2r"] = dummy_package + sys.modules["ingest_pipeline.storage.r2r.storage"] = dummy_storage + + module = importlib.reload(importlib.import_module("ingest_pipeline.storage")) + + def _restore() -> None: + if original_package is not None: + sys.modules["ingest_pipeline.storage.r2r"] = original_package + else: + sys.modules.pop("ingest_pipeline.storage.r2r", None) + if original_storage is not None: + sys.modules["ingest_pipeline.storage.r2r.storage"] = original_storage + else: + sys.modules.pop("ingest_pipeline.storage.r2r.storage", None) + importlib.reload(importlib.import_module("ingest_pipeline.storage")) + + request.addfinalizer(_restore) + return module diff --git a/tests/unit/storage/test_base_storage.py b/tests/unit/storage/test_base_storage.py index 6e9a860..7fffc8b 100644 --- a/tests/unit/storage/test_base_storage.py +++ b/tests/unit/storage/test_base_storage.py @@ -69,7 +69,9 @@ async def test_check_exists_uses_staleness(document_factory, storage_config, del """Return True only when document timestamp is within freshness window.""" timestamp = datetime.now(UTC) - timedelta(days=delta_days) - document = document_factory(content=f"doc-{delta_days}", metadata_updates={"timestamp": timestamp}) + document = document_factory( + content=f"doc-{delta_days}", metadata_updates={"timestamp": timestamp} + ) storage = StubStorage(storage_config, result=document) outcome = await storage.check_exists("identifier", stale_after_days=30) diff --git a/tests/unit/storage/test_openwebui.py b/tests/unit/storage/test_openwebui.py index 4055bbf..42f0b8e 100644 --- a/tests/unit/storage/test_openwebui.py +++ b/tests/unit/storage/test_openwebui.py @@ -99,7 +99,9 @@ async def test_store_batch_handles_multiple_documents( file_ids = await storage.store_batch([first, second]) assert len(file_ids) == 2 - files_payloads: list[Any] = [request["files"] for request in httpx_stub.requests if request["method"] == "POST"] + files_payloads: list[Any] = [ + request["files"] for request in httpx_stub.requests if request["method"] == "POST" + ] assert any(payload is not None for payload in files_payloads) knowledge_entry = openwebui_service.find_knowledge_by_name(storage_config.collection_name) assert knowledge_entry is not None diff --git a/tests/unit/storage/test_package_init.py b/tests/unit/storage/test_package_init.py new file mode 100644 index 0000000..bd8424d --- /dev/null +++ b/tests/unit/storage/test_package_init.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import pytest + + +@pytest.mark.parametrize( + "symbol_name", + ("BaseStorage", "WeaviateStorage", "OpenWebUIStorage", "R2RStorage"), +) +def test_storage_exports_expected_symbols(storage_module, symbol_name: str) -> None: + assert symbol_name in storage_module.__all__ + + +def test_storage_r2r_available_by_default(storage_module) -> None: + assert storage_module.R2RStorage is not None + + +def test_storage_optional_import_fallback(storage_module_without_r2r) -> None: + assert 
storage_module_without_r2r.R2RStorage is None diff --git a/tests/unit/storage/test_r2r_helpers.py b/tests/unit/storage/test_r2r_helpers.py index 6ef2af7..e6220a9 100644 --- a/tests/unit/storage/test_r2r_helpers.py +++ b/tests/unit/storage/test_r2r_helpers.py @@ -5,6 +5,7 @@ from datetime import UTC, datetime from types import SimpleNamespace from typing import Any, Self +import httpx import pytest from ingest_pipeline.core.models import StorageConfig @@ -23,7 +24,6 @@ def r2r_client_stub( monkeypatch: pytest.MonkeyPatch, r2r_service, ) -> object: - class DummyR2RException(Exception): def __init__(self, message: str, status_code: int | None = None) -> None: super().__init__(message) @@ -39,10 +39,13 @@ def r2r_client_stub( def raise_for_status(self) -> None: if self.status_code >= 400: - from httpx import HTTPStatusError - raise HTTPStatusError("HTTP error", request=None, response=self) - - + # Create minimal mock request and response for HTTPStatusError + mock_request = httpx.Request("GET", "http://test.local") + mock_response = httpx.Response( + status_code=self.status_code, + request=mock_request, + ) + raise httpx.HTTPStatusError("HTTP error", request=mock_request, response=mock_response) class MockAsyncClient: def __init__(self, service: Any) -> None: @@ -53,15 +56,23 @@ def r2r_client_stub( # Return existing collections collections = [] for collection_id, collection_data in self._service._collections.items(): - collections.append({ - "id": collection_id, - "name": collection_data["name"], - "description": collection_data.get("description", ""), - }) + collections.append( + { + "id": collection_id, + "name": collection_data["name"], + "description": collection_data.get("description", ""), + } + ) return MockResponse({"results": collections}) return MockResponse({}) - async def post(self, url: str, *, json: dict[str, Any] | None = None, files: dict[str, Any] | None = None) -> MockResponse: + async def post( + self, + url: str, + *, + json: dict[str, Any] | None = None, + files: dict[str, Any] | None = None, + ) -> MockResponse: if "/v3/collections" in url and json: return self._handle_collection_creation(json) if "/v3/documents" in url and files: @@ -76,13 +87,15 @@ def r2r_client_stub( collection_id=new_collection_id, description=json.get("description", ""), ) - return MockResponse({ - "results": { - "id": new_collection_id, - "name": json["name"], - "description": json.get("description", ""), + return MockResponse( + { + "results": { + "id": new_collection_id, + "name": json["name"], + "description": json.get("description", ""), + } } - }) + ) def _handle_document_creation(self, files: dict[str, Any]) -> MockResponse: """Handle document creation via POST with files.""" @@ -102,12 +115,14 @@ def r2r_client_stub( # Update collection document count if needed self._update_collection_document_count(files) - return MockResponse({ - "results": { - "document_id": document_id, - "message": "Document created successfully", + return MockResponse( + { + "results": { + "document_id": document_id, + "message": "Document created successfully", + } } - }) + ) def _extract_document_id(self, files: dict[str, Any]) -> str: """Extract document ID from files.""" @@ -120,6 +135,7 @@ def r2r_client_stub( def _extract_metadata(self, files: dict[str, Any]) -> dict[str, Any]: """Extract and parse metadata from files.""" import json as json_lib + metadata_str = files.get("metadata", (None, "{}"))[1] return json_lib.loads(metadata_str) if metadata_str else {} @@ -137,9 +153,15 @@ def r2r_client_stub( """Parse 
collection IDs from files entry.""" import json as json_lib - collection_ids_str = collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids + collection_ids_str = ( + collection_ids[1] if isinstance(collection_ids, tuple) else collection_ids + ) try: - collection_list = json_lib.loads(collection_ids_str) if isinstance(collection_ids_str, str) else collection_ids_str + collection_list = ( + json_lib.loads(collection_ids_str) + if isinstance(collection_ids_str, str) + else collection_ids_str + ) if isinstance(collection_list, list) and collection_list: first_collection = collection_list[0] if isinstance(first_collection, dict) and "id" in first_collection: @@ -159,7 +181,6 @@ def r2r_client_stub( async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: return None - class DocumentsAPI: def __init__(self, service: Any) -> None: self._service = service @@ -186,10 +207,7 @@ def r2r_client_stub( self._service = service async def search(self, query: str, search_settings: Mapping[str, Any]) -> dict[str, Any]: - results = [ - {"document_id": doc_id, "score": 1.0} - for doc_id in self._service._documents - ] + results = [{"document_id": doc_id, "score": 1.0} for doc_id in self._service._documents] return {"results": results} class DummyClient: diff --git a/tests/unit/tui/__pycache__/__init__.cpython-312.pyc b/tests/unit/tui/__pycache__/__init__.cpython-312.pyc index 0316e93..4568bc8 100644 Binary files a/tests/unit/tui/__pycache__/__init__.cpython-312.pyc and b/tests/unit/tui/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc b/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc index 0e464ef..b0e86f2 100644 Binary files a/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc and b/tests/unit/tui/__pycache__/test_dashboard_screen.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc b/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc index 4ce1491..d325884 100644 Binary files a/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc and b/tests/unit/tui/__pycache__/test_storage_manager.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/tui/test_dashboard_screen.py b/tests/unit/tui/test_dashboard_screen.py index a919494..8ae15f2 100644 --- a/tests/unit/tui/test_dashboard_screen.py +++ b/tests/unit/tui/test_dashboard_screen.py @@ -12,6 +12,7 @@ from ingest_pipeline.cli.tui.screens.documents import DocumentManagementScreen from ingest_pipeline.cli.tui.widgets.tables import EnhancedDataTable from ingest_pipeline.core.models import Document, StorageBackend, StorageConfig from ingest_pipeline.storage.base import BaseStorage +from ingest_pipeline.storage.types import DocumentInfo if TYPE_CHECKING: from ingest_pipeline.cli.tui.utils.storage_manager import StorageManager @@ -35,7 +36,9 @@ class StorageManagerStub: backends: dict[StorageBackend, BaseStorage] is_initialized: bool - def __init__(self, collections: list[CollectionInfo], backends: dict[StorageBackend, BaseStorage]) -> None: + def __init__( + self, collections: list[CollectionInfo], backends: dict[StorageBackend, BaseStorage] + ) -> None: self._collections = collections self._available = list(backends.keys()) self.backends = backends @@ -71,13 +74,19 @@ class PassiveStorage(BaseStorage): async def initialize(self) -> None: return None - async def 
store(self, document: Document, *, collection_name: str | None = None) -> str: # pragma: no cover - unused + async def store( + self, document: Document, *, collection_name: str | None = None + ) -> str: # pragma: no cover - unused raise NotImplementedError - async def store_batch(self, documents: list[Document], *, collection_name: str | None = None) -> list[str]: # pragma: no cover - unused + async def store_batch( + self, documents: list[Document], *, collection_name: str | None = None + ) -> list[str]: # pragma: no cover - unused raise NotImplementedError - async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: # pragma: no cover - unused + async def delete( + self, document_id: str, *, collection_name: str | None = None + ) -> bool: # pragma: no cover - unused raise NotImplementedError @@ -165,6 +174,7 @@ async def test_collection_overview_table_reflects_collections() -> None: assert screen.total_documents == 1292 assert screen.active_backends == 2 + class DocumentStorageStub(BaseStorage): def __init__(self) -> None: super().__init__( @@ -178,16 +188,24 @@ class DocumentStorageStub(BaseStorage): async def initialize(self) -> None: return None - async def store(self, document: Document, *, collection_name: str | None = None) -> str: # pragma: no cover - unused + async def store( + self, document: Document, *, collection_name: str | None = None + ) -> str: # pragma: no cover - unused raise NotImplementedError - async def store_batch(self, documents: list[Document], *, collection_name: str | None = None) -> list[str]: # pragma: no cover - unused + async def store_batch( + self, documents: list[Document], *, collection_name: str | None = None + ) -> list[str]: # pragma: no cover - unused raise NotImplementedError - async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool: # pragma: no cover - unused + async def delete( + self, document_id: str, *, collection_name: str | None = None + ) -> bool: # pragma: no cover - unused raise NotImplementedError - async def list_documents(self, limit: int = 100, offset: int = 0, *, collection_name: str | None = None) -> list[dict[str, object]]: + async def list_documents( + self, limit: int = 100, offset: int = 0, *, collection_name: str | None = None + ) -> list[DocumentInfo]: return [ { "id": "doc-1234567890", diff --git a/tests/unit/tui/test_storage_manager.py b/tests/unit/tui/test_storage_manager.py index f95833d..e506f4a 100644 --- a/tests/unit/tui/test_storage_manager.py +++ b/tests/unit/tui/test_storage_manager.py @@ -4,14 +4,23 @@ from types import SimpleNamespace import pytest -from ingest_pipeline.cli.tui.utils.storage_manager import MultiStorageAdapter, StorageManager +from ingest_pipeline.cli.tui.utils.storage_manager import ( + MultiStorageAdapter, + StorageCapabilities, + StorageManager, +) +from typing import cast + +from ingest_pipeline.config.settings import Settings from ingest_pipeline.core.exceptions import StorageError from ingest_pipeline.core.models import Document, StorageBackend, StorageConfig from ingest_pipeline.storage.base import BaseStorage class StubStorage(BaseStorage): - def __init__(self, config: StorageConfig, *, documents: list[Document] | None = None, fail: bool = False) -> None: + def __init__( + self, config: StorageConfig, *, documents: list[Document] | None = None, fail: bool = False + ) -> None: super().__init__(config) self.documents = documents or [] self.fail = fail @@ -26,7 +35,9 @@ class StubStorage(BaseStorage): raise RuntimeError("store 
failed") return f"{self.config.backend.value}-single" - async def store_batch(self, documents: list[Document], *, collection_name: str | None = None) -> list[str]: + async def store_batch( + self, documents: list[Document], *, collection_name: str | None = None + )) -> list[str]: self.stored.extend(documents) if self.fail: raise RuntimeError("batch failed") @@ -50,7 +61,7 @@ class StubStorage(BaseStorage): threshold: float = 0.7, *, collection_name: str | None = None, - ): + )): for document in self.documents: yield document @@ -58,18 +69,58 @@ class StubStorage(BaseStorage): return None +class CollectionStubStorage(StubStorage): + def __init__( + self, + config: StorageConfig, + *, + collections: list[str], + counts: dict[str, int], + )) -> None: + super().__init__(config) + self.collections = collections + self.counts = counts + + async def list_collections(self) -> list[str]: + return self.collections + + async def count(self, *, collection_name: str | None = None) -> int: + if collection_name is None: + raise ValueError("collection name required") + return self.counts[collection_name] + + +class FailingStatusStorage(StubStorage): + async def list_collections(self) -> list[str]: + raise RuntimeError("status unavailable") + + +class ClosableStubStorage(StubStorage): + def __init__(self, config: StorageConfig) -> None: + super().__init__(config) + self.closed = False + + async def close(self) -> None: + self.closed = True + + +class FailingCloseStorage(StubStorage): + async def close(self) -> None: + raise RuntimeError("close failure") + + @pytest.mark.asyncio async def test_multi_storage_adapter_reports_replication_failure(document_factory) -> None: primary_config = StorageConfig( backend=StorageBackend.WEAVIATE, endpoint="http://weaviate.local", collection_name="primary", - ) + )) secondary_config = StorageConfig( backend=StorageBackend.OPEN_WEBUI, endpoint="http://chat.local", collection_name="secondary", - ) + )) primary = StubStorage(primary_config) secondary = StubStorage(secondary_config, fail=True) @@ -82,33 +133,33 @@ async def test_multi_storage_adapter_reports_replication_failure(document_factor def test_storage_manager_build_multi_storage_adapter_deduplicates(document_factory) -> None: - settings = SimpleNamespace( + settings = cast(Settings, SimpleNamespace( weaviate_endpoint="http://weaviate.local", weaviate_api_key=None, openwebui_endpoint="http://chat.local", openwebui_api_key=None, r2r_endpoint=None, r2r_api_key=None, - ) + ))) manager = StorageManager(settings) weaviate_config = StorageConfig( backend=StorageBackend.WEAVIATE, endpoint="http://weaviate.local", collection_name="primary", - ) + )) openwebui_config = StorageConfig( backend=StorageBackend.OPEN_WEBUI, endpoint="http://chat.local", collection_name="secondary", - ) + )) manager.backends[StorageBackend.WEAVIATE] = StubStorage(weaviate_config) manager.backends[StorageBackend.OPEN_WEBUI] = StubStorage(openwebui_config) adapter = manager.build_multi_storage_adapter( [StorageBackend.WEAVIATE, StorageBackend.WEAVIATE, StorageBackend.OPEN_WEBUI] - ) + )) assert len(adapter._storages) == 2 assert adapter._storages[0].config.backend == StorageBackend.WEAVIATE @@ -116,14 +167,14 @@ def test_storage_manager_build_multi_storage_adapter_deduplicates(document_facto def test_storage_manager_build_multi_storage_adapter_missing_backend() -> None: - settings = SimpleNamespace( + settings = cast(Settings, SimpleNamespace( weaviate_endpoint="http://weaviate.local", weaviate_api_key=None, openwebui_endpoint="http://chat.local", 
openwebui_api_key=None, r2r_endpoint=None, r2r_api_key=None, - ) + )) manager = StorageManager(settings) with pytest.raises(ValueError): @@ -132,41 +183,294 @@ def test_storage_manager_build_multi_storage_adapter_missing_backend() -> None: @pytest.mark.asyncio async def test_storage_manager_search_across_backends_groups_results(document_factory) -> None: - settings = SimpleNamespace( + settings = cast(Settings, SimpleNamespace( weaviate_endpoint="http://weaviate.local", weaviate_api_key=None, openwebui_endpoint="http://chat.local", openwebui_api_key=None, r2r_endpoint=None, r2r_api_key=None, - ) + )) manager = StorageManager(settings) - document_weaviate = document_factory(content="alpha", metadata_updates={"source_url": "https://alpha"}) - document_openwebui = document_factory(content="beta", metadata_updates={"source_url": "https://beta"}) + document_weaviate = document_factory( + content="alpha", metadata_updates={"source_url": "https://alpha"} + ) + document_openwebui = document_factory( + content="beta", metadata_updates={"source_url": "https://beta"} + ) manager.backends[StorageBackend.WEAVIATE] = StubStorage( StorageConfig( backend=StorageBackend.WEAVIATE, endpoint="http://weaviate.local", collection_name="primary", - ), + ), documents=[document_weaviate], - ) + ) manager.backends[StorageBackend.OPEN_WEBUI] = StubStorage( StorageConfig( backend=StorageBackend.OPEN_WEBUI, endpoint="http://chat.local", collection_name="secondary", - ), + ), documents=[document_openwebui], - ) + ) results = await manager.search_across_backends( "query", limit=5, backends=[StorageBackend.WEAVIATE, StorageBackend.OPEN_WEBUI], - ) + ) assert results[StorageBackend.WEAVIATE][0].content == "alpha" assert results[StorageBackend.OPEN_WEBUI][0].content == "beta" + + +@pytest.mark.asyncio +async def test_multi_storage_adapter_store_batch_replicates_to_all_backends(document_factory) -> None: + primary_config = StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint="http://weaviate.local", + collection_name="primary", + ) + secondary_config = StorageConfig( + backend=StorageBackend.OPEN_WEBUI, + endpoint="http://chat.local", + collection_name="secondary", + ) + + primary = StubStorage(primary_config) + secondary = StubStorage(secondary_config) + adapter = MultiStorageAdapter([primary, secondary, secondary]) + + first_document = document_factory(content="first") + second_document = document_factory(content="second") + + document_ids = await adapter.store_batch([first_document, second_document]) + + assert document_ids == ["weaviate-0", "weaviate-1"] + assert adapter._storages[0] is primary + assert primary.stored[0].content == "first" + assert secondary.stored[1].content == "second" + + +@pytest.mark.asyncio +async def test_multi_storage_adapter_delete_reports_secondary_failures() -> None: + primary_config = StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint="http://weaviate.local", + collection_name="primary", + ) + secondary_config = StorageConfig( + backend=StorageBackend.OPEN_WEBUI, + endpoint="http://chat.local", + collection_name="secondary", + ) + + primary = StubStorage(primary_config) + secondary = StubStorage(secondary_config, fail=True) + adapter = MultiStorageAdapter([primary, secondary]) + + with pytest.raises(StorageError) as exc_info: + await adapter.delete("identifier") + + assert "open_webui" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_storage_manager_initialize_all_backends_registers_capabilities(monkeypatch) -> None: + settings = cast(Settings, 
SimpleNamespace( + weaviate_endpoint="http://weaviate.local", + weaviate_api_key="key", + openwebui_endpoint="http://chat.local", + openwebui_api_key="token", + r2r_endpoint="http://r2r.local", + r2r_api_key="secret", + )) + manager = StorageManager(settings) + + monkeypatch.setattr( + "ingest_pipeline.cli.tui.utils.storage_manager.WeaviateStorage", + StubStorage, + ) + monkeypatch.setattr( + "ingest_pipeline.cli.tui.utils.storage_manager.OpenWebUIStorage", + StubStorage, + ) + monkeypatch.setattr( + "ingest_pipeline.cli.tui.utils.storage_manager.R2RStorage", + StubStorage, + ) + + results = await manager.initialize_all_backends() + + assert results[StorageBackend.WEAVIATE] is True + assert results[StorageBackend.OPEN_WEBUI] is True + assert results[StorageBackend.R2R] is True + assert manager.get_available_backends() == [ + StorageBackend.WEAVIATE, + StorageBackend.OPEN_WEBUI, + StorageBackend.R2R, + ] + assert manager.capabilities[StorageBackend.WEAVIATE] == StorageCapabilities.VECTOR_SEARCH + assert manager.capabilities[StorageBackend.OPEN_WEBUI] == StorageCapabilities.KNOWLEDGE_BASE + assert manager.capabilities[StorageBackend.R2R] == StorageCapabilities.FULL_FEATURED + assert manager.supports_advanced_features(StorageBackend.R2R) is True + assert manager.supports_advanced_features(StorageBackend.WEAVIATE) is False + assert manager.is_initialized is True + assert isinstance(manager.get_backend(StorageBackend.R2R), StubStorage) + + +@pytest.mark.asyncio +async def test_storage_manager_initialize_all_backends_handles_missing_config() -> None: + settings = cast(Settings, SimpleNamespace( + weaviate_endpoint=None, + weaviate_api_key=None, + openwebui_endpoint="http://chat.local", + openwebui_api_key=None, + r2r_endpoint=None, + r2r_api_key=None, + )) + manager = StorageManager(settings) + + results = await manager.initialize_all_backends() + + assert results[StorageBackend.WEAVIATE] is False + assert results[StorageBackend.OPEN_WEBUI] is False + assert results[StorageBackend.R2R] is False + assert manager.get_available_backends() == [] + assert manager.is_initialized is True + + +@pytest.mark.asyncio +async def test_storage_manager_get_all_collections_merges_counts_and_backends() -> None: + settings = cast(Settings, SimpleNamespace( + weaviate_endpoint="http://weaviate.local", + weaviate_api_key=None, + openwebui_endpoint="http://chat.local", + openwebui_api_key=None, + r2r_endpoint=None, + r2r_api_key=None, + )) + manager = StorageManager(settings) + + weaviate_storage = CollectionStubStorage( + StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint="http://weaviate.local", + collection_name="shared", + ), + collections=["shared", ""], + counts={"shared": 2}, + ) + openwebui_storage = CollectionStubStorage( + StorageConfig( + backend=StorageBackend.OPEN_WEBUI, + endpoint="http://chat.local", + collection_name="secondary", + ), + collections=["shared"], + counts={"shared": -1}, + ) + manager.backends = { + StorageBackend.WEAVIATE: weaviate_storage, + StorageBackend.OPEN_WEBUI: openwebui_storage, + } + + collections = await manager.get_all_collections() + + assert len(collections) == 1 + assert collections[0]["name"] == "shared" + assert collections[0]["count"] == 2 + assert collections[0]["backend"] == ["weaviate", "open_webui"] + assert collections[0]["type"] == "weaviate" + assert collections[0]["size_mb"] == pytest.approx(0.02) + + +@pytest.mark.asyncio +async def test_storage_manager_get_backend_status_reports_failures() -> None: + settings = cast(Settings, 
SimpleNamespace( + weaviate_endpoint="http://weaviate.local", + weaviate_api_key=None, + openwebui_endpoint="http://chat.local", + openwebui_api_key=None, + r2r_endpoint=None, + r2r_api_key=None, + )) + manager = StorageManager(settings) + + healthy_storage = CollectionStubStorage( + StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint="http://weaviate.local", + collection_name="primary", + ), + collections=["collection", "archive"], + counts={"collection": 2, "archive": 1}, + ) + failing_storage = FailingStatusStorage( + StorageConfig( + backend=StorageBackend.OPEN_WEBUI, + endpoint="http://chat.local", + collection_name="secondary", + ) + ) + manager.backends = { + StorageBackend.WEAVIATE: healthy_storage, + StorageBackend.OPEN_WEBUI: failing_storage, + } + manager.capabilities[StorageBackend.WEAVIATE] = StorageCapabilities.VECTOR_SEARCH + + status = await manager.get_backend_status() + + assert status[StorageBackend.WEAVIATE]["available"] is True + assert status[StorageBackend.WEAVIATE]["collections"] == 2 + assert status[StorageBackend.WEAVIATE]["total_documents"] == 3 + assert status[StorageBackend.WEAVIATE]["capabilities"] == StorageCapabilities.VECTOR_SEARCH + assert str(status[StorageBackend.WEAVIATE]["endpoint"]) == "http://weaviate.local/" + assert status[StorageBackend.OPEN_WEBUI]["available"] is False + assert status[StorageBackend.OPEN_WEBUI]["capabilities"] == StorageCapabilities.NONE + assert "status unavailable" in str(status[StorageBackend.OPEN_WEBUI]["error"]) + + +@pytest.mark.asyncio +async def test_storage_manager_close_all_clears_state() -> None: + settings = cast(Settings, SimpleNamespace( + weaviate_endpoint="http://weaviate.local", + weaviate_api_key=None, + openwebui_endpoint="http://chat.local", + openwebui_api_key=None, + r2r_endpoint=None, + r2r_api_key=None, + )) + manager = StorageManager(settings) + + closable_storage = ClosableStubStorage( + StorageConfig( + backend=StorageBackend.WEAVIATE, + endpoint="http://weaviate.local", + collection_name="primary", + ) + ) + failing_close_storage = FailingCloseStorage( + StorageConfig( + backend=StorageBackend.OPEN_WEBUI, + endpoint="http://chat.local", + collection_name="secondary", + ) + ) + manager.backends = { + StorageBackend.WEAVIATE: closable_storage, + StorageBackend.OPEN_WEBUI: failing_close_storage, + } + manager.capabilities[StorageBackend.WEAVIATE] = StorageCapabilities.VECTOR_SEARCH + manager.capabilities[StorageBackend.OPEN_WEBUI] = StorageCapabilities.KNOWLEDGE_BASE + manager._initialized = True + + await manager.close_all() + + assert closable_storage.closed is True + assert manager.backends == {} + assert manager.capabilities == {} + assert manager.is_initialized is False diff --git a/tests/unit/utils/__pycache__/__init__.cpython-312.pyc b/tests/unit/utils/__pycache__/__init__.cpython-312.pyc index 481ffeb..7fb1140 100644 Binary files a/tests/unit/utils/__pycache__/__init__.cpython-312.pyc and b/tests/unit/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc b/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc index f5e51d2..ab4fce7 100644 Binary files a/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc and b/tests/unit/utils/__pycache__/test_metadata_tagger.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc 
b/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc index d3d034c..d5aad3d 100644 Binary files a/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc and b/tests/unit/utils/__pycache__/test_vectorizer.cpython-312-pytest-8.4.2.pyc differ diff --git a/tests/unit/utils/conftest.py b/tests/unit/utils/conftest.py new file mode 100644 index 0000000..9bdce8b --- /dev/null +++ b/tests/unit/utils/conftest.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Protocol + +import pytest + +from ingest_pipeline.utils.async_helpers import AsyncTaskManager + +Delays = tuple[float, ...] +Results = tuple[float, ...] + + +class ProcessorFactory(Protocol): + def __call__(self, value: int) -> Awaitable[int]: ... + + +@pytest.fixture(scope="function") +def async_manager_factory() -> Callable[[int], AsyncTaskManager]: + def _factory(limit: int) -> AsyncTaskManager: + return AsyncTaskManager(max_concurrent=limit) + + return _factory + + +@pytest.fixture(scope="function") +def concurrency_measure() -> Callable[[AsyncTaskManager, Delays], Awaitable[tuple[Results, int]]]: + async def _measure(manager: AsyncTaskManager, delays: Delays) -> tuple[Results, int]: + active = 0 + peak = 0 + lock = asyncio.Lock() + + async def tracked(delay: float) -> float: + nonlocal active, peak + async with lock: + active += 1 + peak = peak if peak >= active else active + await asyncio.sleep(delay) + async with lock: + active -= 1 + return delay + + tasks = tuple(tracked(delay) for delay in delays) + results = await manager.run_tasks(tasks) + # Filter out any exceptions and only keep float results + float_results = [r for r in results if isinstance(r, (int, float))] + return tuple(float_results), peak + + return _measure + + +@pytest.fixture(scope="function") +def success_failure_coroutines() -> tuple[Awaitable[str], Awaitable[str], Awaitable[str]]: + async def success() -> str: + await asyncio.sleep(0) + return "ok" + + async def failure() -> str: + await asyncio.sleep(0) + raise RuntimeError("boom") + + return (success(), failure(), success()) + + +@pytest.fixture(scope="function") +def batch_inputs() -> tuple[int, ...]: + return (1, 2, 3, 4, 5) + + +@pytest.fixture(scope="function") +def batch_processor_record() -> tuple[list[int], ProcessorFactory]: + recorded: list[int] = [] + + async def processor(value: int) -> int: + recorded.append(value) + await asyncio.sleep(0) + return value * 2 + + return recorded, processor + + +@pytest.fixture(scope="function") +def error_processor() -> ProcessorFactory: + async def processor(_: int) -> int: + await asyncio.sleep(0) + raise ValueError("failure") + + # Type checker doesn't understand this function raises, which is expected for error testing + return processor # type: ignore[return-value] + + +@pytest.fixture(scope="function") +def semaphore_case() -> tuple[asyncio.Semaphore, Callable[[], Awaitable[str]], list[str]]: + semaphore = asyncio.Semaphore(0) + events: list[str] = [] + + async def tracked() -> str: + events.append("entered") + await asyncio.sleep(0) + events.append("completed") + return "done" + + def factory() -> Awaitable[str]: + return tracked() + + return semaphore, factory, events + + +@pytest.fixture(scope="function") +def batch_expected_results(batch_inputs: tuple[int, ...]) -> tuple[int, ...]: + return tuple(value * 2 for value in batch_inputs) diff --git a/tests/unit/utils/test_async_helpers.py b/tests/unit/utils/test_async_helpers.py new 
file mode 100644 index 0000000..0cce241 --- /dev/null +++ b/tests/unit/utils/test_async_helpers.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Callable + +import pytest + +from ingest_pipeline.utils.async_helpers import AsyncTaskManager, batch_process, run_with_semaphore + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("limit", "delays"), + [(2, (0.02, 0.01, 0.03, 0.02))], +) +async def test_async_task_manager_limits_concurrency( + async_manager_factory: Callable[[int], AsyncTaskManager], + concurrency_measure, + limit: int, + delays: tuple[float, ...], +) -> None: + manager = async_manager_factory(limit) + results, peak = await concurrency_measure(manager, delays) + + assert results == delays + assert peak <= limit + + +@pytest.mark.asyncio +async def test_async_task_manager_collects_exceptions( + async_manager_factory: Callable[[int], AsyncTaskManager], + success_failure_coroutines, +) -> None: + manager = async_manager_factory(3) + outcomes = await manager.run_tasks(success_failure_coroutines, return_exceptions=True) + + assert outcomes[0] == "ok" + assert isinstance(outcomes[1], RuntimeError) + assert outcomes[2] == "ok" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("batch_size", "max_concurrent"), + [(2, 1)], +) +async def test_batch_process_handles_multiple_batches( + batch_inputs, + batch_processor_record, + batch_expected_results, + batch_size: int, + max_concurrent: int, +) -> None: + recorded, processor = batch_processor_record + + results = await batch_process( + list(batch_inputs), + processor, + batch_size=batch_size, + max_concurrent=max_concurrent, + ) + + assert tuple(results) == batch_expected_results + assert recorded == list(batch_inputs) + + +@pytest.mark.asyncio +async def test_batch_process_propagates_errors(error_processor) -> None: + with pytest.raises(ValueError): + await batch_process([1, 2], error_processor, batch_size=1) + + +@pytest.mark.asyncio +async def test_run_with_semaphore_blocks_until_available(semaphore_case) -> None: + semaphore, factory, events = semaphore_case + pending = asyncio.create_task(run_with_semaphore(semaphore, factory())) + + await asyncio.sleep(0.01) + assert events == [] + + semaphore.release() + result = await pending + + assert result == "done" + assert events == ["entered", "completed"] diff --git a/tests/unit/utils/test_metadata_tagger.py b/tests/unit/utils/test_metadata_tagger.py index 6056a00..3012839 100644 --- a/tests/unit/utils/test_metadata_tagger.py +++ b/tests/unit/utils/test_metadata_tagger.py @@ -86,14 +86,10 @@ async def test_tag_batch_processes_documents( """Apply tagging to each document in order.""" first_payload = { - "choices": [ - {"message": {"content": json.dumps({"summary": "First summary"})}} - ] + "choices": [{"message": {"content": json.dumps({"summary": "First summary"})}}] } second_payload = { - "choices": [ - {"message": {"content": json.dumps({"summary": "Second summary"})}} - ] + "choices": [{"message": {"content": json.dumps({"summary": "Second summary"})}}] } httpx_stub.queue_json(first_payload) httpx_stub.queue_json(second_payload) @@ -113,11 +109,7 @@ async def test_tag_batch_processes_documents( async def test_tag_document_raises_on_invalid_json(httpx_stub, document_factory) -> None: """Raise ingestion error when the LLM response is malformed.""" - payload = { - "choices": [ - {"message": {"content": "not-json"}} - ] - } + payload = {"choices": [{"message": {"content": "not-json"}}]} httpx_stub.queue_json(payload) document = 
document_factory(content="broken payload") diff --git a/typings/__init__.py b/typings/__init__.py index 1efb184..e4997b7 100644 --- a/typings/__init__.py +++ b/typings/__init__.py @@ -5,9 +5,11 @@ from typing import TypedDict class EmbeddingData(TypedDict): """Structure for embedding data from API response.""" + embedding: list[float] class EmbeddingResponse(TypedDict): """Structure for OpenAI-compatible embedding API response.""" + data: list[EmbeddingData] diff --git a/typings/__init__.pyi b/typings/__init__.pyi index 929db06..06e608f 100644 --- a/typings/__init__.pyi +++ b/typings/__init__.pyi @@ -4,9 +4,10 @@ from typing import TypedDict class EmbeddingData(TypedDict): """Structure for embedding data from API response.""" - embedding: list[float] + embedding: list[float] class EmbeddingResponse(TypedDict): """Structure for OpenAI-compatible embedding API response.""" + data: list[EmbeddingData] diff --git a/typings/__pycache__/__init__.cpython-312.pyc b/typings/__pycache__/__init__.cpython-312.pyc index 903cd04..a5e78ad 100644 Binary files a/typings/__pycache__/__init__.cpython-312.pyc and b/typings/__pycache__/__init__.cpython-312.pyc differ diff --git a/typings/httpx.pyi b/typings/httpx.pyi index 32fbc0a..6008c98 100644 --- a/typings/httpx.pyi +++ b/typings/httpx.pyi @@ -8,33 +8,17 @@ class Response: def raise_for_status(self) -> None: ... def json(self) -> dict[str, object]: ... - class AsyncClient: """Async HTTP client.""" def __init__( - self, - *, - timeout: float | None = None, - headers: dict[str, str] | None = None + self, *, timeout: float | None = None, headers: dict[str, str] | None = None ) -> None: ... - async def get( - self, - url: str, - *, - params: dict[str, object] | dict[str, int] | None = None + self, url: str, *, params: dict[str, object] | dict[str, int] | None = None ) -> Response: ... - - async def post( - self, - url: str, - *, - json: dict[str, object] | None = None - ) -> Response: ... - + async def post(self, url: str, *, json: dict[str, object] | None = None) -> Response: ... async def aclose(self) -> None: ... - # Make AsyncClient available for import __all__ = ["AsyncClient", "Response"] diff --git a/typings/prefect.pyi b/typings/prefect.pyi index 6ed4156..2f62582 100644 --- a/typings/prefect.pyi +++ b/typings/prefect.pyi @@ -1,11 +1,14 @@ """Type stubs for Prefect to fix missing type information.""" -from typing import Any, Awaitable, Callable, TypeVar, overload +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar, overload + from typing_extensions import ParamSpec # Prefect-specific types for cache key functions class TaskRunContext: """Prefect task run context.""" + pass P = ParamSpec("P") @@ -37,7 +40,6 @@ def task( retry_condition_fn: Any = None, viz_return_value: Any = None, ) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: ... - @overload def task( *, @@ -64,17 +66,14 @@ def task( retry_condition_fn: Any = None, viz_return_value: Any = None, ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... - @overload def task( fn: Callable[P, Awaitable[T]], ) -> Callable[P, Awaitable[T]]: ... - @overload def task( fn: Callable[P, T], ) -> Callable[P, T]: ... - @overload def flow( *, @@ -99,7 +98,6 @@ def flow( on_cancellation: list[Any] | None = None, on_running: list[Any] | None = None, ) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: ... 
- @overload def flow( fn: Callable[P, Awaitable[T]], @@ -111,4 +109,4 @@ class Logger: def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ... def error(self, msg: str, *args: Any, **kwargs: Any) -> None: ... -def get_run_logger() -> Logger: ... \ No newline at end of file +def get_run_logger() -> Logger: ... diff --git a/typings/prefect/blocks/core.pyi b/typings/prefect/blocks/core.pyi index 44a878a..b0b4cbd 100644 --- a/typings/prefect/blocks/core.pyi +++ b/typings/prefect/blocks/core.pyi @@ -1,7 +1,6 @@ """Type stubs for Prefect blocks core module.""" from typing import Any, ClassVar, TypeVar -from typing_extensions import Self T = TypeVar("T", bound="Block") @@ -12,8 +11,6 @@ class Block: @classmethod async def aload(cls: type[T], name: str, validate: bool = True, client: Any = None) -> T: ... - @classmethod def load(cls: type[T], name: str, validate: bool = True, client: Any = None) -> T: ... - - def save(self, name: str, overwrite: bool = False) -> None: ... \ No newline at end of file + def save(self, name: str, overwrite: bool = False) -> None: ... diff --git a/typings/r2r/__init__.pyi b/typings/r2r/__init__.pyi index c4bce07..9309f29 100644 --- a/typings/r2r/__init__.pyi +++ b/typings/r2r/__init__.pyi @@ -5,18 +5,15 @@ from typing import Generic, Protocol, TypeVar, overload _T_co = TypeVar("_T_co", covariant=True) - class SupportsModelDump(Protocol): def model_dump(self) -> dict[str, object]: ... - class CollectionResponse(SupportsModelDump): id: str name: str description: str | None owner_id: str | None - class CollectionSequence(SupportsModelDump, Protocol): @overload def __getitem__(self, index: int) -> CollectionResponse: ... @@ -26,64 +23,52 @@ class CollectionSequence(SupportsModelDump, Protocol): def __len__(self) -> int: ... def __iter__(self) -> Iterator[CollectionResponse]: ... - class DocumentSequence(SupportsModelDump, Protocol): def model_dump(self) -> dict[str, object]: ... - class UserSequence(SupportsModelDump, Protocol): def model_dump(self) -> dict[str, object]: ... - class GenericBooleanResponse(SupportsModelDump): success: bool - class GenericMessageResponse(SupportsModelDump): message: str | None - class DocumentCreateResponse(SupportsModelDump): """Response from document creation.""" + id: str message: str | None = None - class DocumentRetrieveResponse(SupportsModelDump): """Response from document retrieval.""" + id: str content: str | None = None metadata: dict[str, object] | None = None - class SearchResult(SupportsModelDump): """Individual search result.""" + id: str content: str | None = None metadata: dict[str, object] | None = None score: float | None = None - class SearchResponse(SupportsModelDump): """Response from search operations.""" - results: list[SearchResult] + results: list[SearchResult] class R2RResults(Generic[_T_co], SupportsModelDump): results: _T_co def model_dump(self) -> dict[str, object]: ... - class WrappedCollectionsResponse(R2RResults[CollectionSequence]): ... - - class WrappedDocumentsResponse(R2RResults[DocumentSequence]): ... - - class WrappedUsersResponse(R2RResults[UserSequence]): ... - class R2RException(Exception): ... - class R2RClientException(Exception): ... class _SystemClient: @@ -97,8 +82,12 @@ class _CollectionsClient: limit: int | None = ..., owner_only: bool | None = ..., ) -> WrappedCollectionsResponse: ... - async def create(self, name: str, description: str | None = ...) -> R2RResults[CollectionResponse]: ... 
- async def add_document(self, id: str, document_id: str) -> R2RResults[GenericMessageResponse]: ... + async def create( + self, name: str, description: str | None = ... + ) -> R2RResults[CollectionResponse]: ... + async def add_document( + self, id: str, document_id: str + ) -> R2RResults[GenericMessageResponse]: ... async def delete(self, id: str) -> None: ... async def retrieve(self, id: str) -> R2RResults[CollectionResponse]: ... async def update(