yooohoo
@@ -8,7 +8,8 @@
       "Bash(cat:*)",
       "mcp__firecrawl__firecrawl_search",
       "mcp__firecrawl__firecrawl_scrape",
-      "Bash(python:*)"
+      "Bash(python:*)",
+      "mcp__coder__coder_report_task"
     ],
     "deny": [],
     "ask": []
docs/feeds.md (432 changed lines)
@@ -1,263 +1,255 @@
## Codebase Analysis Report: RAG Manager Ingestion Pipeline

TL;DR / Highest-impact fixes (do these first)

**Status:** Validated against the current codebase implementation
**Target:** Enhanced implementation guidance for efficient agent execution

**Event-loop blocking in the TUI (user-visible stutter):** `time.sleep(2)` is called at the end of an ingestion run on the `IngestionScreen`. Even though this runs in a thread worker, it blocks that worker and delays the UI transition; prefer a scheduled UI callback instead.

This analysis has been validated against the actual codebase structure and provides implementation-specific details for executing the recommended improvements. The codebase demonstrates solid architecture with a clear separation of concerns between ingestion flows, storage adapters, and TUI components.
### Architecture Overview

- **Storage Backends**: Weaviate, OpenWebUI, and R2R behind a unified `BaseStorage` interface
- **TUI Framework**: Textual-based, with reactive components and async worker patterns
- **Orchestration**: Prefect flows with retry logic and progress callbacks
- **Configuration**: Pydantic-based settings with environment-variable support

**Blocking Weaviate client calls inside async methods:** Several async methods in `WeaviateStorage` call the synchronous Weaviate client directly (`connect()`, `collections.create`, queries, inserts, deletes). Wrap those calls in `asyncio.to_thread(...)` (or equivalent) so the event loop is not frozen when they take time.
### Validated Implementation Analysis

**Embedding at scale: use batch vectorization.** `store_batch` vectorizes each document one by one; you already have `Vectorizer.vectorize_batch`. Switching to the batch call reduces HTTP round trips and improves throughput under backpressure.
### 1. Bug Fixes & Potential Issues

**Connection lifecycle: close the vectorizer client.** `WeaviateStorage.close()` closes the Weaviate client but not the `httpx.AsyncClient` inside the `Vectorizer`. Await a close on it as well to prevent leaked connections and sockets under heavy usage.

These are areas where the code may not function as intended or could lead to errors.

**Broadened exception handling in UI utilities:** Multiple places catch `Exception` broadly, which makes failures ambiguous and harder to surface to users (e.g., the storage manager and list builders). Narrow these to the expected exceptions and fail fast with user-friendly messages where appropriate.
* <details>
<summary>
<b>HIGH PRIORITY: `R2RStorage.store_batch` inefficient looping (Lines 161-179)</b>
</summary>

* **File:** `ingest_pipeline/storage/r2r/storage.py:161-179`
* **Issue:** CONFIRMED - Method loops through documents calling `_store_single_document` individually
* **Impact:** ~5-10x performance degradation for batch operations
* **Implementation:** Check the R2R v3 API for bulk endpoints; the current implementation calls `/v3/documents` per document
* **Effort:** Medium (API research + refactor)
* **Priority:** High - affects all R2R ingestion workflows
</details>
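
If the R2R v3 API turns out not to expose a bulk endpoint, an interim improvement is to keep the per-document calls but run them with bounded concurrency. A minimal sketch, assuming `_store_single_document(doc, collection_name=...)` returns the stored document id; the concurrency limit is an illustrative choice:

```python
import asyncio

from ingest_pipeline.core.models import Document  # import path as used elsewhere in this commit


async def store_batch(self, documents: list[Document], *, collection_name: str | None = None) -> list[str]:
    """Store documents with bounded concurrency instead of strictly one by one (interim fix)."""
    semaphore = asyncio.Semaphore(8)  # illustrative cap; tune against R2R rate limits

    async def store_one(doc: Document) -> str:
        async with semaphore:
            return await self._store_single_document(doc, collection_name=collection_name)

    # gather propagates the first failure rather than silently dropping documents
    return list(await asyncio.gather(*(store_one(doc) for doc in documents)))
```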
Correctness / Bugs

**Blocking sleep in TUI:** `time.sleep(2)` after posting notifications and before `app.pop_screen()` blocks the worker thread; use Textual timers instead.

* <details>
<summary>
<b>MEDIUM PRIORITY: Mixed HTTP client usage in `R2RStorage` (Lines 80, 99, 258)</b>
</summary>

* **File:** `ingest_pipeline/storage/r2r/storage.py:80,99,258`
* **Issue:** VALIDATED - Mixes `R2RAsyncClient` (line 80) with a direct `httpx.AsyncClient` (lines 99, 258)
* **Specific Methods:** `initialize()`, `_ensure_collection()`, `_attempt_document_creation()`
* **Impact:** Inconsistent auth/header handling, connection pooling inefficiency
* **Implementation:** Extend `R2RAsyncClient` or create an adapter pattern for the missing endpoints
* **Test Coverage:** Check whether the affected methods have unit tests before refactoring
* **Effort:** Medium (requires SDK analysis)
</details>
**Synchronous SDK in async contexts (Weaviate):** `initialize()` calls `self.client.connect()`, which is synchronous. Wrap it with `asyncio.to_thread(self.client.connect)`.

* <details>
<summary>
<b>MEDIUM PRIORITY: TUI blocking during storage init (Line 91)</b>
</summary>

* **File:** `ingest_pipeline/cli/tui/utils/runners.py:91`
* **Issue:** CONFIRMED - `await storage_manager.initialize_all_backends()` blocks TUI startup
* **Current Implementation:** 30s timeout per backend in `StorageManager.initialize_all_backends()`
* **User Impact:** Frozen terminal for up to 90s if all backends time out
* **Solution:** Move to `CollectionOverviewScreen.on_mount()` as a `@work` task
* **Dependencies:** `dashboard.py:304` already has a worker pattern for `refresh_collections`
* **Implementation:** Use the existing loading indicators and status updates (lines 308-312)
* **Effort:** Low (pattern exists, needs relocation)
</details>
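A minimal sketch of moving backend initialization into a Textual worker on the dashboard screen. It assumes the screen already holds a `storage_manager` attribute and a `#loading` indicator; the widget id and the `refresh_collections()` follow-up are taken from the bullets above, everything else is illustrative:

```python
from textual import work
from textual.screen import Screen
from textual.widgets import LoadingIndicator


class CollectionOverviewScreen(Screen[None]):  # sketch; the real screen already exists
    async def on_mount(self) -> None:
        # Kick off backend initialization without blocking screen composition
        self.initialize_backends()

    @work(exclusive=True)
    async def initialize_backends(self) -> None:
        loading = self.query_one("#loading", LoadingIndicator)  # assumed widget id
        loading.display = True
        try:
            await self.storage_manager.initialize_all_backends()
            self.refresh_collections()  # existing worker, per dashboard.py:304
        finally:
            loading.display = False
```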
* <details>
<summary>
<b>LOW PRIORITY: Weak URL validation in `IngestionScreen` (Lines 240-260)</b>
</summary>

* **File:** `ingest_pipeline/cli/tui/screens/ingestion.py:240-260`
* **Issue:** CONFIRMED - The method accepts `foo/bar` as valid (line 258)
* **Security Risk:** Medium - malicious URLs could be passed to ingestors
* **Current Logic:** Basic prefix checks only (http/https/file://)
* **Enhancement:** Add `pathlib.Path.exists()` for file:// paths, `.git` directory check for repos
* **Dependencies:** Import `pathlib` and add proper regex validation
* **Alternative:** Use the `validators` library (not currently imported)
* **Effort:** Low (validation logic only)
</details>
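A sketch of a stricter `_validate_url` along the lines suggested above, using only the standard library; the `.git` check and `pathlib` usage are the proposed enhancements, not existing code:

```python
from pathlib import Path
from urllib.parse import urlparse


def _validate_url(self, url: str) -> bool:
    """Stricter source validation: real http(s) URLs, existing file:// paths, or local repo checkouts."""
    parsed = urlparse(url)

    if parsed.scheme in ("http", "https"):
        # Require an actual hostname, not just a scheme prefix
        return bool(parsed.netloc)

    if parsed.scheme == "file":
        return Path(parsed.path).exists()

    # Bare local paths are only accepted when they look like a repository checkout
    candidate = Path(url)
    return candidate.is_dir() and (candidate / ".git").exists()
```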
Object operations such as `collection.data.insert(...)`, `collection.query.fetch_objects(...)`, `collection.data.delete_many(...)`, and `client.collections.delete(...)` are sync calls invoked from async methods. These can stall the event loop under latency; use `asyncio.to_thread(...)` (see snippet below).
### 2. Code Redundancy & Refactoring Opportunities

**HTTP client lifecycle (Vectorizer):** `Vectorizer` owns an `httpx.AsyncClient`, but `WeaviateStorage.close()` doesn't close it; add a close call to avoid resource leaks.

These suggestions aim to make the code more concise, maintainable, and reusable (D.R.Y. - Don't Repeat Yourself).

**Heuristic `word_count` for OpenWebUI file listing:** `word_count` is estimated as `size / 6`. That's fine as a placeholder, but it can mislead paging and UI logic downstream; consider a sentinel or a clearer label (e.g., `estimated_word_count`).
* <details>
<summary>
<b>HIGH IMPACT: Redundant collection logic in dashboard (Lines 356-424)</b>
</summary>

* **File:** `ingest_pipeline/cli/tui/screens/dashboard.py:356-424`
* **Issue:** CONFIRMED - `list_weaviate_collections()` and `list_openwebui_collections()` duplicate `StorageManager.get_all_collections()`
* **Code Duplication:** ~70 lines of redundant collection-listing logic
* **Architecture Violation:** UI layer coupled to specific storage implementations
* **Current Usage:** `refresh_collections()` calls `get_all_collections()` (line 327), making the methods obsolete
* **Action:** DELETE the methods `list_weaviate_collections` and `list_openwebui_collections`
* **Impact:** ~70 fewer lines, improved maintainability
* **Risk:** Low - the methods appear unused in the current flow
* **Effort:** Low (deletion only)
</details>

**Wide `except Exception`:** Broad catches appear in places like the storage manager and the TUI screen update closures; they hide actionable errors. Prefer catching `StorageError`, `IngestionError`, or specific SDK exceptions, and surface toasts with actionable details.
* <details>
<summary>
<b>MEDIUM IMPACT: Repetitive backend init pattern (Lines 255-291)</b>
</summary>

* **File:** `ingest_pipeline/cli/tui/utils/storage_manager.py:255-291`
* **Issue:** CONFIRMED - The same pattern is repeated three times, once per backend type
* **Code Structure:** Check settings → create config → add task (12 lines × 3 backends)
* **Current Backends:** Weaviate (258-267), OpenWebUI (270-279), R2R (282-291)
* **Refactor Pattern:** Create a `BackendConfig` dataclass with `(backend_type, endpoint_setting, api_key_setting, storage_class)` (see the sketch below)
* **Implementation:** Loop over a config list, reducing ~36 lines to ~15 lines
* **Extensibility:** Adding a new backend becomes a one-line config addition
* **Testing:** Ensure the `asyncio.gather()` behavior is unchanged (line 296)
* **Effort:** Medium (requires dataclass design + testing)
</details>
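A minimal sketch of that dataclass-driven loop. The `StorageConfig` fields and the `weaviate_*` / `openwebui_*` setting names mirror the CLI code elsewhere in this commit; the `StorageBackend.R2R` member and the `r2r_*` setting names are assumptions and should be matched to the real settings class:

```python
from dataclasses import dataclass

from pydantic import SecretStr

from ingest_pipeline.core.models import StorageBackend, StorageConfig
from ingest_pipeline.storage.base import BaseStorage
from ingest_pipeline.storage.openwebui import OpenWebUIStorage
from ingest_pipeline.storage.r2r.storage import R2RStorage
from ingest_pipeline.storage.weaviate import WeaviateStorage


@dataclass(frozen=True)
class BackendSpec:
    """Declarative description of one backend the StorageManager can initialize."""
    backend: StorageBackend
    endpoint_setting: str   # attribute name on Settings, e.g. "weaviate_endpoint"
    api_key_setting: str    # attribute name on Settings, e.g. "weaviate_api_key"
    storage_class: type[BaseStorage]


BACKEND_SPECS = (
    BackendSpec(StorageBackend.WEAVIATE, "weaviate_endpoint", "weaviate_api_key", WeaviateStorage),
    BackendSpec(StorageBackend.OPEN_WEBUI, "openwebui_endpoint", "openwebui_api_key", OpenWebUIStorage),
    BackendSpec(StorageBackend.R2R, "r2r_endpoint", "r2r_api_key", R2RStorage),  # R2R names assumed
)


def build_storages(settings: object) -> list[BaseStorage]:
    """Replace the three hand-written init blocks with one loop over BACKEND_SPECS."""
    storages: list[BaseStorage] = []
    for spec in BACKEND_SPECS:
        endpoint = getattr(settings, spec.endpoint_setting, None)
        if not endpoint:
            continue  # backend not configured
        api_key = getattr(settings, spec.api_key_setting, None)
        config = StorageConfig(
            backend=spec.backend,
            endpoint=endpoint,
            api_key=SecretStr(api_key) if api_key is not None else None,
            collection_name="default",
        )
        storages.append(spec.storage_class(config))
    return storages
```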
Performance & Scalability
**Batch embeddings:** In `WeaviateStorage.store_batch` you conditionally vectorize each doc inline. Use your existing `vectorize_batch` and map the results back to the documents. This cuts request overhead and enables controlled concurrency (you already have `AsyncTaskManager` to help).

* <details>
<summary>
<b>MEDIUM IMPACT: Repeated Prefect block loading pattern (Lines 266-311)</b>
</summary>

* **File:** `ingest_pipeline/flows/ingestion.py:266-311`
* **Issue:** CONFIRMED - The pattern appears in both `_create_ingestor()` and `_create_storage()`
* **Duplication:** `Block.aload()` + fallback logic repeated 4x across both methods
* **Variable Resolution:** The batch-size logic (lines 244-255) also needs abstraction
* **Helper Functions Needed:**
  - `load_block_with_fallback(block_slug: str, default_config: T) -> T`
  - `resolve_prefect_variable(var_name: str, default: T, type_cast: Type[T]) -> T`
* **Impact:** Cleaner flow logic, better error handling, type safety
* **Lines Reduced:** ~20 lines of repetitive code
* **Effort:** Medium (requires generic typing)
</details>
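One possible shape for those helpers, shown as a sketch only. `Block.aload` is what the flow code already uses per the bullets above; the variable helper is shown taking the already-fetched raw value rather than the variable name, to stay agnostic about the Prefect `Variable` API in use:

```python
from typing import TypeVar

from prefect.blocks.core import Block

T = TypeVar("T")


async def load_block_with_fallback(block_slug: str, default_config: T) -> T:
    """Load a Prefect block by slug, falling back to the provided default config."""
    try:
        return await Block.aload(block_slug)  # type: ignore[return-value]
    except ValueError:
        # Block document not registered on the server/workspace
        return default_config


def resolve_prefect_variable(raw_value: object, default: T, type_cast: type[T]) -> T:
    """Coerce a raw Prefect variable value, returning the default when unset or malformed."""
    if raw_value is None:
        return default
    try:
        return type_cast(raw_value)  # e.g. int("32")
    except (TypeError, ValueError):
        return default
```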
**Async-friendly Weaviate calls:** Offload sync SDK operations to a thread so your Textual UI remains responsive while bulk inserting, querying, or deleting. (See the patch below.)
### 3. User Experience (UX) Enhancements

**Retry & backoff are solid in your HTTP base client:** `TypedHttpClient.request` adds exponential backoff plus jitter; this is great and worth keeping consistent across adapters.

These are suggestions to make your TUI more powerful, intuitive, and enjoyable for the user.
UX notes (Textual TUI)

* <details>
<summary>
<b>HIGH IMPACT: Document content viewer modal (Add to documents.py)</b>
</summary>

* **Target File:** `ingest_pipeline/cli/tui/screens/documents.py`
* **Current State:** READY - `DocumentManagementScreen` has table selection (line 212)
* **Implementation:**
  - Add `Binding("v", "view_document", "View")` to BINDINGS (line 27)
  - Create `DocumentContentModal(ModalScreen)` with `ScrollableContainer` + `Markdown`
  - Use the existing `get_current_document()` method (line 212)
  - Fetch full content via `storage.retrieve(document_id)`
* **Dependencies:** Import `ModalScreen`, `ScrollableContainer`, and `Markdown` from Textual
* **User Value:** HIGH - essential for the content-inspection workflow
* **Effort:** Low-Medium (~50 lines of modal code)
* **Pattern:** Follow the existing modal patterns in the codebase
</details>

**Notifications are good; make them consistent:** `app.safe_notify(...)` exists; use it everywhere instead of `self.notify(...)` to normalize markup handling and safety.
* <details>
<summary>
<b>HIGH IMPACT: Analytics tab visualization (Lines 164-189)</b>
</summary>

* **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py:164-189`
* **Current State:** PLACEHOLDER - Static widgets with dummy content
* **Data Source:** Use the existing `self.collections` (line 65) populated by `refresh_collections()`
* **Implementation Options:**
  1. **Simple Text Chart:** ASCII bar chart using the existing collections data
  2. **textual-plotext:** Add the dependency + a bar chart widget
  3. **Custom Widget:** Simple bar visualization with Static widgets
* **Metrics to Show:**
  - Documents per collection (data available)
  - Storage usage per backend (calculated in `_calculate_metrics()`)
  - Ingestion timeline (requires timestamp tracking)
* **Effort:** Low-Medium (depends on visualization complexity)
* **Dependencies:** Consider `textual-plotext` or a pure ASCII approach
</details>

**Avoid visible "freeze" scenarios:** Replace the final `time.sleep(2)` with a timer and transition immediately; users see their actions complete without lag. (Patch below.)
* <details>
<summary>
<b>MEDIUM IMPACT: Global search implementation (Button exists, needs screen)</b>
</summary>

* **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py`
* **Current State:** READY - The "Search All" button exists (line 122), handler stubbed
* **Backend Support:** `StorageManager.search_across_backends()` exists (lines 413-441)
* **Implementation:**
  - Create `GlobalSearchScreen(ModalScreen)` with a search input + results table
  - Use the existing `search_across_backends()` method for data
  - Add a "Backend" column to the results table showing the data source
  - Handle async search with loading indicators
* **Current Limitation:** Search only works for Weaviate (line 563); needs to be extended
* **Data Flow:** Input → `storage_manager.search_across_backends()` → results display
* **Effort:** Medium (~100 lines for the new screen + search logic)
</details>

**Search table quick find:** You have `EnhancedDataTable` with quick-search messages; wire a shortcut (e.g., `/`) to focus a search input and filter rows live, as sketched below.
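A sketch of the `/` shortcut on a screen that hosts an `EnhancedDataTable`. The `#quick_search` input id, the `filter_rows` call, and the import path are illustrative; align them with the widget's real quick-search messages:

```python
from textual.binding import Binding
from textual.screen import Screen
from textual.widgets import Input

from ..widgets.enhanced_table import EnhancedDataTable  # import path assumed


class DocumentManagementScreen(Screen[None]):  # sketch; the screen already exists
    BINDINGS = [
        Binding("/", "focus_quick_search", "Quick search", show=True),
    ]

    def action_focus_quick_search(self) -> None:
        # Jump the cursor into the filter box without reaching for the mouse
        self.query_one("#quick_search", Input).focus()

    def on_input_changed(self, event: Input.Changed) -> None:
        if event.input.id == "quick_search":
            table = self.query_one(EnhancedDataTable)
            table.filter_rows(event.value)  # illustrative API on EnhancedDataTable
```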
* <details>
<summary>
<b>MEDIUM IMPACT: R2R advanced features integration (Widgets ready)</b>
</summary>

* **Target File:** `ingest_pipeline/cli/tui/screens/documents.py`
* **Available Widgets:** CONFIRMED - `ChunkViewer`, `EntityGraph`, `CollectionStats`, `DocumentOverview` in `r2r_widgets.py`
* **Current Implementation:** Basic document table only; R2R-specific features unused
* **Integration Points:**
  - Add an "R2R Details" button when `collection["type"] == "r2r"` (conditional UI)
  - Create `R2RDocumentDetailsScreen` using the existing widgets
  - Use the `StorageManager.get_r2r_storage()` method (exists at line 442)
* **R2R Methods Available:**
  - `get_document_chunks()`, `extract_entities()`, `get_document_overview()`
* **User Value:** Medium-High for R2R users; showcases advanced features
* **Effort:** Low-Medium (widgets exist, need screen integration)
</details>

**Theme file size & maintainability:** The theming system is thorough; consider splitting `styles.py` into smaller modules or generating CSS at build time to keep the Python file lean. (The responsive CSS generators are consolidated here.)
* <details>
<summary>
<b>LOW IMPACT: Create collection dialog (Backend methods exist)</b>
</summary>

* **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py`
* **Backend Support:** CONFIRMED - A `create_collection()` method exists for R2R storage (line 690)
* **Current State:** No "Create Collection" button in the existing UI
* **Implementation:**
  - Add a "New Collection" button to the dashboard action buttons
  - Create a `CreateCollectionModal` with a name input + backend checkboxes
  - Iterate over `storage_manager.get_available_backends()` for backend selection
  - Call `storage.create_collection()` on the selected backends
* **Backend Compatibility:** Check which storage backends support collection creation
* **User Value:** Low-Medium (manual workflow, not critical)
* **Effort:** Low-Medium (~75 lines for the modal + integration)
</details>

Modularity / Redundancy
## Implementation Priority Matrix

**Converge repeated property mapping:** `WeaviateStorage.store` and `store_batch` build the same properties dict; factor this into a small helper to keep the schema in one place (less drift, easier to extend).

### Quick Wins (High Impact, Low Effort)

1. **Delete redundant collection methods** (dashboard.py:356-424) - 5 min
2. **Fix TUI startup blocking** (runners.py:91) - 15 min
3. **Document content viewer modal** (documents.py) - 30 min
### High Impact Fixes (Medium Effort)

1. **R2R batch operation optimization** (storage.py:161-179) - Research the R2R v3 API + implementation
2. **Analytics tab visualization** (dashboard.py:164-189) - Choose a visualization approach + implement
3. **Backend initialization refactoring** (storage_manager.py:255-291) - Dataclass design + testing

**Common "describe/list/count" patterns across storages:** `R2RStorage`, `OpenWebUIStorage`, and `WeaviateStorage` present similar collection/document listing and counting methods. Consider a small "collection view" mixin with shared helpers; each backend implements only the API-specific steps.
### Technical Debt (Long-term)

1. **R2R client consistency** (storage.py) - SDK analysis + refactoring
2. **Prefect block loading helpers** (ingestion.py:266-311) - Generic typing + testing
3. **URL validation enhancement** (ingestion.py:240-260) - Security + validation logic

Security & Reliability
### Feature Enhancements (User Value)

1. **Global search implementation** - Medium effort; requires extending the search backends
2. **R2R advanced features integration** - Showcase existing widget capabilities
3. **Create collection dialog** - Nice-to-have administrative feature

**Input sanitization for LLM metadata:** Your `MetadataTagger` sanitizes and bounds fields (e.g., max lengths, a language whitelist). This is a strong pattern; keep it.
## Agent Execution Notes

**Context Efficiency Tips:**

- Focus on one priority tier at a time
- Read the specific file ranges mentioned in the line numbers
- Use existing patterns (worker decorators, modal screens, async methods)
- Test changes incrementally, especially async operations
- Verify import dependencies before implementation

**Timeouts and typed HTTP clients:** You standardize HTTP clients, headers, and timeouts. This is a good foundation for consistent behavior and observability.
**Architecture Constraints:**

- Maintain async/await patterns throughout
- Follow Textual reactive widget patterns
- Preserve the Prefect flow structure for orchestration
- Keep the storage backend abstraction intact

The codebase demonstrates excellent architectural foundations; these enhancements build upon existing strengths rather than requiring structural changes.

Suggested patches (drop-in)
1) Don't block the UI when closing the ingestion screen

Current (blocking):

```python
import time

time.sleep(2)
cast("CollectionManagementApp", self.app).pop_screen()
```

Safer (schedule via the app's timer):

```python
def _pop() -> None:
    try:
        self.app.pop_screen()
    except Exception:
        pass

# schedule from the worker thread
cast("CollectionManagementApp", self.app).call_from_thread(
    lambda: self.app.set_timer(2.0, _pop)
)
```
2) Offload Weaviate sync calls from async methods

Example – insert (from `WeaviateStorage.store`):

```python
# before (sync call inside an async method)
collection.data.insert(properties=properties, vector=vector)

# after
await asyncio.to_thread(collection.data.insert, properties=properties, vector=vector)
```

Example – fetch/delete by filter (from `delete_by_filter`):

```python
response = await asyncio.to_thread(
    collection.query.fetch_objects, filters=where_filter, limit=1000
)
for obj in response.objects:
    await asyncio.to_thread(collection.data.delete_by_id, obj.uuid)
```

Example – connect during `initialize`:

```python
await asyncio.to_thread(self.client.connect)
```
3) Batch embeddings in store_batch

Current (per-doc):

```python
for doc in documents:
    if doc.vector is None:
        doc.vector = await self.vectorizer.vectorize(doc.content)
```

Proposed (batch):

```python
# collect the contents that still need vectors
to_embed_idxs = [i for i, d in enumerate(documents) if d.vector is None]
if to_embed_idxs:
    contents = [documents[i].content for i in to_embed_idxs]
    vectors = await self.vectorizer.vectorize_batch(contents)
    for j, idx in enumerate(to_embed_idxs):
        documents[idx].vector = vectors[j]
```
4) Close the Vectorizer HTTP client when storage closes

Current: `WeaviateStorage.close()` only closes the Weaviate client.

Add:

```python
async def close(self) -> None:
    if self.client:
        try:
            cast(weaviate.WeaviateClient, self.client).close()
        except Exception as e:
            import logging

            logging.warning("Error closing Weaviate client: %s", e)
    # NEW: close the vectorizer HTTP client too
    try:
        await self.vectorizer.client.aclose()
    except Exception:
        pass
```

(Your `Vectorizer` owns an `httpx.AsyncClient` with headers and timeouts set.)
5) Prefer safe_notify consistently

Replace direct `self.notify(...)` calls inside TUI screens with `cast(AppType, self.app).safe_notify(...)` (same severity, `markup=False` by default), centralizing styling and sanitization.
Smaller improvements (quality of life)

**Quick search focus:** Wire the `EnhancedDataTable` quick-search to a keyboard binding (e.g., `/`) so users can filter rows without reaching for the mouse (see the binding sketch above).

**Refactor the repeated properties dict:** Extract the property construction in `WeaviateStorage.store`/`store_batch` into a helper, as sketched below, to avoid drift and reduce line count.
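A minimal helper sketch; the property names below are illustrative (based on the document fields referenced elsewhere in this report) and should mirror whatever `store()`/`store_batch()` currently write so both call sites stay in sync:

```python
from ingest_pipeline.core.models import Document  # import path as used elsewhere in this commit


def _build_properties(self, document: Document) -> dict[str, object]:
    """Single place that maps a Document onto Weaviate object properties."""
    return {
        "content": document.content,          # confirmed field
        "title": document.title,              # remaining names are assumptions
        "source_url": document.source_url,
        "content_type": document.content_type,
        "word_count": document.word_count,
        "timestamp": document.timestamp,
    }
```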
**Styles organization:** `styles.py` is hefty by design. Consider splitting the "theme palette", "components", and "responsive" generators into separate modules to keep diffs small and reviews easier.

Architecture & Modularity
**Storage adapters:** The three adapters share "describe/list/count" concepts. Introduce a tiny shared "CollectionIntrospection" mixin (interfaces + default fallbacks), as sketched below, and keep only the API specifics in each adapter. This will simplify the TUI's `StorageManager` as well.
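One possible shape for that mixin, shown as a sketch; the method names follow the "describe/list/count" wording above rather than the adapters' current signatures:

```python
from abc import ABC, abstractmethod


class CollectionIntrospectionMixin(ABC):
    """Shared describe/list/count helpers; adapters implement only the API-specific calls."""

    @abstractmethod
    async def list_collections(self) -> list[str]:
        """Return the collection names known to this backend."""

    @abstractmethod
    async def count_documents(self, collection_name: str) -> int:
        """Return the number of documents in a collection."""

    async def describe_collection(self, collection_name: str) -> dict[str, object]:
        # Default fallback built from the two primitives; a backend can override
        # this with a richer, API-specific implementation.
        return {
            "name": collection_name,
            "document_count": await self.count_documents(collection_name),
        }
```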
**Ingestion flows:** Good use of Prefect tasks/flows with retries, variables, and tagging. The Firecrawl→R2R specialization cleanly reuses the common steps. The batch boundaries are clear, and progress updates are threaded back to the UI.

DX / Testing

**Unit tests:** In this export I don't see tests. Add lightweight tests for:
- Vector extraction and format parsing in `Vectorizer` (covers multiple providers).
- Weaviate adapters: property building, name normalization, vector extraction.
- `MetadataTagger` sanitization rules.
- TUI: use Textual's `Pilot` to verify that notifications and timers trigger and that screens transition without blocking (verifies the sleep fix).
**Static analysis:** You already have excellent typing. Add `ruff` and `mypy --strict` plus a pre-commit config to keep it consistent across contributors.
@@ -19,7 +19,6 @@ actions:
|
||||
source: inferred
|
||||
enabled: true
|
||||
""",
|
||||
|
||||
"retry_failed": """
|
||||
name: Retry Failed Ingestion Flows
|
||||
description: Retries failed ingestion flows with original parameters
|
||||
@@ -39,7 +38,6 @@ actions:
|
||||
validate_first: false
|
||||
enabled: true
|
||||
""",
|
||||
|
||||
"resource_monitoring": """
|
||||
name: Manage Work Pool Based on Resources
|
||||
description: Pauses work pool when system resources are constrained
|
||||
|
||||
@@ -139,6 +139,7 @@ def ingest(
|
||||
# If we're already in an event loop (e.g., in Jupyter), use nest_asyncio
|
||||
try:
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
result = asyncio.run(run_with_progress())
|
||||
except ImportError:
|
||||
@@ -449,7 +450,9 @@ def search(
|
||||
def blocks_command() -> None:
|
||||
"""🧩 List and manage Prefect Blocks."""
|
||||
console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
|
||||
console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks")
|
||||
console.print(
|
||||
"Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks"
|
||||
)
|
||||
console.print("Use 'prefect block ls' to list available blocks")
|
||||
|
||||
|
||||
@@ -507,7 +510,9 @@ async def run_list_collections() -> None:
|
||||
weaviate_config = StorageConfig(
|
||||
backend=StorageBackend.WEAVIATE,
|
||||
endpoint=settings.weaviate_endpoint,
|
||||
api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
|
||||
api_key=SecretStr(settings.weaviate_api_key)
|
||||
if settings.weaviate_api_key is not None
|
||||
else None,
|
||||
collection_name="default",
|
||||
)
|
||||
weaviate = WeaviateStorage(weaviate_config)
|
||||
@@ -528,7 +533,9 @@ async def run_list_collections() -> None:
|
||||
openwebui_config = StorageConfig(
|
||||
backend=StorageBackend.OPEN_WEBUI,
|
||||
endpoint=settings.openwebui_endpoint,
|
||||
api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None,
|
||||
api_key=SecretStr(settings.openwebui_api_key)
|
||||
if settings.openwebui_api_key is not None
|
||||
else None,
|
||||
collection_name="default",
|
||||
)
|
||||
openwebui = OpenWebUIStorage(openwebui_config)
|
||||
@@ -593,7 +600,9 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in
|
||||
weaviate_config = StorageConfig(
|
||||
backend=StorageBackend.WEAVIATE,
|
||||
endpoint=settings.weaviate_endpoint,
|
||||
api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
|
||||
api_key=SecretStr(settings.weaviate_api_key)
|
||||
if settings.weaviate_api_key is not None
|
||||
else None,
|
||||
collection_name=collection or "default",
|
||||
)
|
||||
weaviate = WeaviateStorage(weaviate_config)
|
||||
|
||||
@@ -31,7 +31,6 @@ else: # pragma: no cover - optional dependency fallback
|
||||
R2RStorage = BaseStorage
|
||||
|
||||
|
||||
|
||||
class CollectionManagementApp(App[None]):
|
||||
"""Enhanced modern Textual application with comprehensive keyboard navigation."""
|
||||
|
||||
|
||||
@@ -57,7 +57,9 @@ class ResponsiveGrid(Container):
|
||||
markup: bool = True,
|
||||
) -> None:
|
||||
"""Initialize responsive grid."""
|
||||
super().__init__(*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
||||
super().__init__(
|
||||
*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup
|
||||
)
|
||||
self._columns: int = columns
|
||||
self._auto_fit: bool = auto_fit
|
||||
self._compact: bool = compact
|
||||
@@ -327,7 +329,9 @@ class CardLayout(ResponsiveGrid):
|
||||
) -> None:
|
||||
"""Initialize card layout with default settings for cards."""
|
||||
# Default to auto-fit cards with minimum width
|
||||
super().__init__(auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
|
||||
super().__init__(
|
||||
auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup
|
||||
)
|
||||
|
||||
|
||||
class SplitPane(Container):
|
||||
@@ -396,7 +400,9 @@ class SplitPane(Container):
|
||||
if self._vertical:
|
||||
cast(Widget, self).add_class("vertical")
|
||||
|
||||
pane_classes = ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
|
||||
pane_classes = (
|
||||
("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
|
||||
)
|
||||
|
||||
yield Container(self._left_content, classes=pane_classes[0])
|
||||
yield Static("", classes="splitter")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Data models and TypedDict definitions for the TUI."""
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Any, TypedDict
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class StorageCapabilities(IntEnum):
|
||||
@@ -47,7 +47,7 @@ class ChunkInfo(TypedDict):
|
||||
content: str
|
||||
start_index: int
|
||||
end_index: int
|
||||
metadata: dict[str, Any]
|
||||
metadata: dict[str, object]
|
||||
|
||||
|
||||
class EntityInfo(TypedDict):
|
||||
@@ -57,7 +57,7 @@ class EntityInfo(TypedDict):
|
||||
name: str
|
||||
type: str
|
||||
confidence: float
|
||||
metadata: dict[str, Any]
|
||||
metadata: dict[str, object]
|
||||
|
||||
|
||||
class FirecrawlOptions(TypedDict, total=False):
|
||||
@@ -77,7 +77,7 @@ class FirecrawlOptions(TypedDict, total=False):
|
||||
max_depth: int
|
||||
|
||||
# Extraction options
|
||||
extract_schema: dict[str, Any] | None
|
||||
extract_schema: dict[str, object] | None
|
||||
extract_prompt: str | None
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class BaseScreen(Screen[object]):
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
**kwargs: object
|
||||
**kwargs: object,
|
||||
) -> None:
|
||||
"""Initialize base screen."""
|
||||
super().__init__(name=name, id=id, classes=classes)
|
||||
@@ -55,7 +55,7 @@ class CRUDScreen(BaseScreen, Generic[T]):
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
**kwargs: object
|
||||
**kwargs: object,
|
||||
) -> None:
|
||||
"""Initialize CRUD screen."""
|
||||
super().__init__(storage_manager, name=name, id=id, classes=classes)
|
||||
@@ -311,7 +311,7 @@ class FormScreen(ModalScreen[T], Generic[T]):
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
**kwargs: object
|
||||
**kwargs: object,
|
||||
) -> None:
|
||||
"""Initialize form screen."""
|
||||
super().__init__(name=name, id=id, classes=classes)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Main dashboard screen with collections overview."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
from textual import work
|
||||
@@ -357,7 +356,6 @@ class CollectionOverviewScreen(Screen[None]):
|
||||
self.is_loading = False
|
||||
loading_indicator.display = False
|
||||
|
||||
|
||||
async def update_collections_table(self) -> None:
|
||||
"""Update the collections table with enhanced formatting."""
|
||||
table = self.query_one("#collections_table", EnhancedDataTable)
|
||||
|
||||
@@ -82,7 +82,10 @@ class ConfirmDeleteScreen(Screen[None]):
|
||||
try:
|
||||
if self.collection["type"] == "weaviate" and self.parent_screen.weaviate:
|
||||
# Delete Weaviate collection
|
||||
if self.parent_screen.weaviate.client and self.parent_screen.weaviate.client.collections:
|
||||
if (
|
||||
self.parent_screen.weaviate.client
|
||||
and self.parent_screen.weaviate.client.collections
|
||||
):
|
||||
self.parent_screen.weaviate.client.collections.delete(self.collection["name"])
|
||||
self.notify(
|
||||
f"Deleted Weaviate collection: {self.collection['name']}",
|
||||
@@ -100,7 +103,7 @@ class ConfirmDeleteScreen(Screen[None]):
|
||||
return
|
||||
|
||||
# Check if the storage backend supports collection deletion
|
||||
if not hasattr(storage_backend, 'delete_collection'):
|
||||
if not hasattr(storage_backend, "delete_collection"):
|
||||
self.notify(
|
||||
f"❌ Collection deletion not supported for {self.collection['type']} backend",
|
||||
severity="error",
|
||||
@@ -113,10 +116,13 @@ class ConfirmDeleteScreen(Screen[None]):
|
||||
collection_name = str(self.collection["name"])
|
||||
collection_type = str(self.collection["type"])
|
||||
|
||||
self.notify(f"Deleting {collection_type} collection: {collection_name}...", severity="information")
|
||||
self.notify(
|
||||
f"Deleting {collection_type} collection: {collection_name}...",
|
||||
severity="information",
|
||||
)
|
||||
|
||||
# Use the standard delete_collection method for all backends
|
||||
if hasattr(storage_backend, 'delete_collection'):
|
||||
if hasattr(storage_backend, "delete_collection"):
|
||||
success = await storage_backend.delete_collection(collection_name)
|
||||
else:
|
||||
self.notify("❌ Backend does not support collection deletion", severity="error")
|
||||
@@ -149,7 +155,6 @@ class ConfirmDeleteScreen(Screen[None]):
|
||||
self.parent_screen.refresh_collections()
|
||||
|
||||
|
||||
|
||||
class ConfirmDocumentDeleteScreen(Screen[None]):
|
||||
"""Screen for confirming document deletion."""
|
||||
|
||||
@@ -223,11 +228,11 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
|
||||
|
||||
try:
|
||||
results: dict[str, bool] = {}
|
||||
if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage:
|
||||
if hasattr(self.parent_screen, "storage") and self.parent_screen.storage:
|
||||
# Delete documents via storage
|
||||
# The storage should have delete_documents method for weaviate
|
||||
storage = self.parent_screen.storage
|
||||
if hasattr(storage, 'delete_documents'):
|
||||
if hasattr(storage, "delete_documents"):
|
||||
results = await storage.delete_documents(
|
||||
self.doc_ids,
|
||||
collection_name=self.collection["name"],
|
||||
@@ -280,7 +285,9 @@ class LogViewerScreen(ModalScreen[None]):
|
||||
yield Header(show_clock=True)
|
||||
yield Container(
|
||||
Static("📜 Live Application Logs", classes="title"),
|
||||
Static("Logs update in real time. Press S to reveal the log file path.", classes="subtitle"),
|
||||
Static(
|
||||
"Logs update in real time. Press S to reveal the log file path.", classes="subtitle"
|
||||
),
|
||||
RichLog(id="log_stream", classes="log-stream", wrap=True, highlight=False),
|
||||
Static("", id="log_file_path", classes="subtitle"),
|
||||
classes="main_container log-viewer-container",
|
||||
@@ -291,13 +298,13 @@ class LogViewerScreen(ModalScreen[None]):
|
||||
"""Attach this viewer to the parent application once mounted."""
|
||||
self._log_widget = self.query_one(RichLog)
|
||||
|
||||
if hasattr(self.app, 'attach_log_viewer'):
|
||||
if hasattr(self.app, "attach_log_viewer"):
|
||||
self.app.attach_log_viewer(self) # type: ignore[arg-type]
|
||||
|
||||
def on_unmount(self) -> None:
|
||||
"""Detach from the parent application when closed."""
|
||||
|
||||
if hasattr(self.app, 'detach_log_viewer'):
|
||||
if hasattr(self.app, "detach_log_viewer"):
|
||||
self.app.detach_log_viewer(self) # type: ignore[arg-type]
|
||||
|
||||
def _get_log_widget(self) -> RichLog:
|
||||
@@ -340,4 +347,6 @@ class LogViewerScreen(ModalScreen[None]):
|
||||
if self._log_file is None:
|
||||
self.notify("File logging is disabled for this session.", severity="warning")
|
||||
else:
|
||||
self.notify(f"Log file available at: {self._log_file}", severity="information", markup=False)
|
||||
self.notify(
|
||||
f"Log file available at: {self._log_file}", severity="information", markup=False
|
||||
)
|
||||
|
||||
@@ -116,7 +116,9 @@ class DocumentManagementScreen(Screen[None]):
|
||||
content_type=str(doc.get("content_type", "text/plain")),
|
||||
content_preview=str(doc.get("content_preview", "")),
|
||||
word_count=(
|
||||
lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0
|
||||
lambda wc_val: int(wc_val)
|
||||
if isinstance(wc_val, (int, str)) and str(wc_val).isdigit()
|
||||
else 0
|
||||
)(doc.get("word_count", 0)),
|
||||
timestamp=str(doc.get("timestamp", "")),
|
||||
)
|
||||
@@ -126,7 +128,7 @@ class DocumentManagementScreen(Screen[None]):
|
||||
# For storage backends that don't support document listing, show a message
|
||||
self.notify(
|
||||
f"Document listing not supported for {self.storage.__class__.__name__}",
|
||||
severity="information"
|
||||
severity="information",
|
||||
)
|
||||
self.documents = []
|
||||
|
||||
@@ -330,7 +332,9 @@ class DocumentManagementScreen(Screen[None]):
|
||||
"""View the content of the currently selected document."""
|
||||
if doc := self.get_current_document():
|
||||
if self.storage:
|
||||
self.app.push_screen(DocumentContentModal(doc, self.storage, self.collection["name"]))
|
||||
self.app.push_screen(
|
||||
DocumentContentModal(doc, self.storage, self.collection["name"])
|
||||
)
|
||||
else:
|
||||
self.notify("No storage backend available", severity="error")
|
||||
else:
|
||||
@@ -381,13 +385,13 @@ class DocumentContentModal(ModalScreen[None]):
|
||||
yield Container(
|
||||
Static(
|
||||
f"📄 Document: {self.document['title'][:60]}{'...' if len(self.document['title']) > 60 else ''}",
|
||||
classes="modal-header"
|
||||
classes="modal-header",
|
||||
),
|
||||
ScrollableContainer(
|
||||
Markdown("Loading document content...", id="document_content"),
|
||||
LoadingIndicator(id="content_loading"),
|
||||
classes="modal-content"
|
||||
)
|
||||
classes="modal-content",
|
||||
),
|
||||
)
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
@@ -398,30 +402,29 @@ class DocumentContentModal(ModalScreen[None]):
|
||||
try:
|
||||
# Get full document content
|
||||
doc_content = await self.storage.retrieve(
|
||||
self.document["id"],
|
||||
collection_name=self.collection_name
|
||||
self.document["id"], collection_name=self.collection_name
|
||||
)
|
||||
|
||||
# Format content for display
|
||||
if isinstance(doc_content, str):
|
||||
formatted_content = f"""# {self.document['title']}
|
||||
formatted_content = f"""# {self.document["title"]}
|
||||
|
||||
**Source:** {self.document.get('source_url', 'N/A')}
|
||||
**Type:** {self.document.get('content_type', 'text/plain')}
|
||||
**Words:** {self.document.get('word_count', 0):,}
|
||||
**Timestamp:** {self.document.get('timestamp', 'N/A')}
|
||||
**Source:** {self.document.get("source_url", "N/A")}
|
||||
**Type:** {self.document.get("content_type", "text/plain")}
|
||||
**Words:** {self.document.get("word_count", 0):,}
|
||||
**Timestamp:** {self.document.get("timestamp", "N/A")}
|
||||
|
||||
---
|
||||
|
||||
{doc_content}
|
||||
"""
|
||||
else:
|
||||
formatted_content = f"""# {self.document['title']}
|
||||
formatted_content = f"""# {self.document["title"]}
|
||||
|
||||
**Source:** {self.document.get('source_url', 'N/A')}
|
||||
**Type:** {self.document.get('content_type', 'text/plain')}
|
||||
**Words:** {self.document.get('word_count', 0):,}
|
||||
**Timestamp:** {self.document.get('timestamp', 'N/A')}
|
||||
**Source:** {self.document.get("source_url", "N/A")}
|
||||
**Type:** {self.document.get("content_type", "text/plain")}
|
||||
**Words:** {self.document.get("word_count", 0):,}
|
||||
**Timestamp:** {self.document.get("timestamp", "N/A")}
|
||||
|
||||
---
|
||||
|
||||
@@ -431,6 +434,8 @@ class DocumentContentModal(ModalScreen[None]):
|
||||
content_widget.update(formatted_content)
|
||||
|
||||
except Exception as e:
|
||||
content_widget.update(f"# Error Loading Document\n\nFailed to load document content: {e}")
|
||||
content_widget.update(
|
||||
f"# Error Loading Document\n\nFailed to load document content: {e}"
|
||||
)
|
||||
finally:
|
||||
loading.display = False
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from textual import work
|
||||
from textual.app import ComposeResult
|
||||
@@ -105,12 +105,23 @@ class IngestionScreen(ModalScreen[None]):
|
||||
Label("📋 Source Type (Press 1/2/3):", classes="input-label"),
|
||||
Horizontal(
|
||||
Button("🌐 Web (1)", id="web_btn", variant="primary", classes="type-button"),
|
||||
Button("📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"),
|
||||
Button("📖 Documentation (3)", id="docs_btn", variant="default", classes="type-button"),
|
||||
Button(
|
||||
"📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"
|
||||
),
|
||||
Button(
|
||||
"📖 Documentation (3)",
|
||||
id="docs_btn",
|
||||
variant="default",
|
||||
classes="type-button",
|
||||
),
|
||||
classes="type_buttons",
|
||||
),
|
||||
Rule(line_style="dashed"),
|
||||
Label(f"🗄️ Target Storages ({len(self.available_backends)} available):", classes="input-label", id="backend_label"),
|
||||
Label(
|
||||
f"🗄️ Target Storages ({len(self.available_backends)} available):",
|
||||
classes="input-label",
|
||||
id="backend_label",
|
||||
),
|
||||
Container(
|
||||
*self._create_backend_checkbox_widgets(),
|
||||
classes="backend-selection",
|
||||
@@ -139,7 +150,6 @@ class IngestionScreen(ModalScreen[None]):
|
||||
|
||||
yield LoadingIndicator(id="loading", classes="pulse")
|
||||
|
||||
|
||||
def _create_backend_checkbox_widgets(self) -> list[Checkbox]:
|
||||
"""Create checkbox widgets for each available backend."""
|
||||
checkboxes: list[Checkbox] = [
|
||||
@@ -219,19 +229,26 @@ class IngestionScreen(ModalScreen[None]):
|
||||
collection_name = collection_input.value.strip()
|
||||
|
||||
if not source_url:
|
||||
self.notify("🔍 Please enter a source URL", severity="error")
|
||||
cast("CollectionManagementApp", self.app).safe_notify(
|
||||
"🔍 Please enter a source URL", severity="error"
|
||||
)
|
||||
url_input.focus()
|
||||
return
|
||||
|
||||
# Validate URL format
|
||||
if not self._validate_url(source_url):
|
||||
self.notify("❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path", severity="error")
|
||||
cast("CollectionManagementApp", self.app).safe_notify(
|
||||
"❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path",
|
||||
severity="error",
|
||||
)
|
||||
url_input.focus()
|
||||
return
|
||||
|
||||
resolved_backends = self._resolve_selected_backends()
|
||||
if not resolved_backends:
|
||||
self.notify("⚠️ Select at least one storage backend", severity="warning")
|
||||
cast("CollectionManagementApp", self.app).safe_notify(
|
||||
"⚠️ Select at least one storage backend", severity="warning"
|
||||
)
|
||||
return
|
||||
|
||||
self.selected_backends = resolved_backends
|
||||
@@ -246,18 +263,16 @@ class IngestionScreen(ModalScreen[None]):
|
||||
url_lower = url.lower()
|
||||
|
||||
# Allow HTTP/HTTPS URLs
|
||||
if url_lower.startswith(('http://', 'https://')):
|
||||
if url_lower.startswith(("http://", "https://")):
|
||||
# Additional validation could be added here
|
||||
return True
|
||||
|
||||
# Allow file:// URLs for repository paths
|
||||
if url_lower.startswith('file://'):
|
||||
if url_lower.startswith("file://"):
|
||||
return True
|
||||
|
||||
# Allow local file paths that look like repositories
|
||||
return '/' in url and not url_lower.startswith(
|
||||
('javascript:', 'data:', 'vbscript:')
|
||||
)
|
||||
return "/" in url and not url_lower.startswith(("javascript:", "data:", "vbscript:"))
|
||||
|
||||
def _resolve_selected_backends(self) -> list[StorageBackend]:
|
||||
selected: list[StorageBackend] = []
|
||||
@@ -288,13 +303,14 @@ class IngestionScreen(ModalScreen[None]):
|
||||
if not self.selected_backends:
|
||||
status_widget.update("📋 Selected: None")
|
||||
elif len(self.selected_backends) == 1:
|
||||
backend_name = BACKEND_LABELS.get(self.selected_backends[0], self.selected_backends[0].value)
|
||||
backend_name = BACKEND_LABELS.get(
|
||||
self.selected_backends[0], self.selected_backends[0].value
|
||||
)
|
||||
status_widget.update(f"📋 Selected: {backend_name}")
|
||||
else:
|
||||
# Multiple backends selected
|
||||
backend_names = [
|
||||
BACKEND_LABELS.get(backend, backend.value)
|
||||
for backend in self.selected_backends
|
||||
BACKEND_LABELS.get(backend, backend.value) for backend in self.selected_backends
|
||||
]
|
||||
if len(backend_names) <= 3:
|
||||
# Show all names if 3 or fewer
|
||||
@@ -319,7 +335,10 @@ class IngestionScreen(ModalScreen[None]):
|
||||
for backend in BACKEND_ORDER:
|
||||
if backend not in self.available_backends:
|
||||
continue
|
||||
if backend.value.lower() == backend_name_lower or backend.name.lower() == backend_name_lower:
|
||||
if (
|
||||
backend.value.lower() == backend_name_lower
|
||||
or backend.name.lower() == backend_name_lower
|
||||
):
|
||||
matched_backends.append(backend)
|
||||
break
|
||||
return matched_backends or [self.available_backends[0]]
|
||||
@@ -351,6 +370,7 @@ class IngestionScreen(ModalScreen[None]):
|
||||
loading.display = False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(_update)
|
||||
|
||||
def progress_reporter(percent: int, message: str) -> None:
|
||||
@@ -362,6 +382,7 @@ class IngestionScreen(ModalScreen[None]):
|
||||
progress_text.update(message)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(_update_progress)
|
||||
|
||||
try:
|
||||
@@ -377,13 +398,18 @@ class IngestionScreen(ModalScreen[None]):
|
||||
|
||||
for i, backend in enumerate(backends):
|
||||
progress_percent = 20 + (60 * i) // len(backends)
|
||||
progress_reporter(progress_percent, f"🔗 Processing {backend.value} backend ({i+1}/{len(backends)})...")
|
||||
progress_reporter(
|
||||
progress_percent,
|
||||
f"🔗 Processing {backend.value} backend ({i + 1}/{len(backends)})...",
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the Prefect flow for this backend using asyncio.run with timeout
|
||||
import asyncio
|
||||
|
||||
async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult:
|
||||
async def run_flow_with_timeout(
|
||||
current_backend: StorageBackend = backend,
|
||||
) -> IngestionResult:
|
||||
return await asyncio.wait_for(
|
||||
create_ingestion_flow(
|
||||
source_url=source_url,
|
||||
@@ -392,7 +418,7 @@ class IngestionScreen(ModalScreen[None]):
|
||||
collection_name=final_collection_name,
|
||||
progress_callback=progress_reporter,
|
||||
),
|
||||
timeout=600.0 # 10 minute timeout
|
||||
timeout=600.0, # 10 minute timeout
|
||||
)
|
||||
|
||||
result = asyncio.run(run_flow_with_timeout())
|
||||
@@ -401,25 +427,33 @@ class IngestionScreen(ModalScreen[None]):
|
||||
total_failed += result.documents_failed
|
||||
|
||||
if result.error_messages:
|
||||
flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages])
|
||||
flow_errors.extend(
|
||||
[f"{backend.value}: {err}" for err in result.error_messages]
|
||||
)
|
||||
|
||||
except TimeoutError:
|
||||
error_msg = f"{backend.value}: Timeout after 10 minutes"
|
||||
flow_errors.append(error_msg)
|
||||
progress_reporter(0, f"❌ {backend.value} timed out")
|
||||
def notify_timeout(msg: str = f"⏰ {backend.value} flow timed out after 10 minutes") -> None:
|
||||
|
||||
def notify_timeout(
|
||||
msg: str = f"⏰ {backend.value} flow timed out after 10 minutes",
|
||||
) -> None:
|
||||
try:
|
||||
self.notify(msg, severity="error", markup=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(notify_timeout)
|
||||
except Exception as exc:
|
||||
flow_errors.append(f"{backend.value}: {exc}")
|
||||
|
||||
def notify_error(msg: str = f"❌ {backend.value} flow failed: {exc}") -> None:
|
||||
try:
|
||||
self.notify(msg, severity="error", markup=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
|
||||
|
||||
successful = total_successful
|
||||
@@ -445,20 +479,37 @@ class IngestionScreen(ModalScreen[None]):
|
||||
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(notify_results)
|
||||
|
||||
import time
|
||||
time.sleep(2)
|
||||
cast("CollectionManagementApp", self.app).pop_screen()
|
||||
def _pop() -> None:
|
||||
try:
|
||||
self.app.pop_screen()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Schedule screen pop via timer instead of blocking
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(
|
||||
lambda: self.app.set_timer(2.0, _pop)
|
||||
)
|
||||
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
progress_reporter(0, f"❌ Prefect flows error: {exc}")
|
||||
|
||||
def notify_error(msg: str = f"❌ Prefect flows failed: {exc}") -> None:
|
||||
try:
|
||||
self.notify(msg, severity="error")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
|
||||
import time
|
||||
time.sleep(2)
|
||||
|
||||
def _pop_on_error() -> None:
|
||||
try:
|
||||
self.app.pop_screen()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Schedule screen pop via timer for error case too
|
||||
cast("CollectionManagementApp", self.app).call_from_thread(
|
||||
lambda: self.app.set_timer(2.0, _pop_on_error)
|
||||
)
|
||||
finally:
|
||||
update_ui("hide_loading")
|
||||
|
||||
|
||||
@@ -43,12 +43,24 @@ class SearchScreen(Screen[None]):
|
||||
@override
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
# Check if search is supported for this backend
|
||||
backends = self.collection["backend"]
|
||||
if isinstance(backends, str):
|
||||
backends = [backends]
|
||||
search_supported = "weaviate" in backends
|
||||
search_indicator = "✅ Search supported" if search_supported else "❌ Search not supported"
|
||||
|
||||
yield Container(
|
||||
Static(
|
||||
f"🔍 Search in: {self.collection['name']} ({self.collection['backend']})",
|
||||
f"🔍 Search in: {self.collection['name']} ({', '.join(backends)}) - {search_indicator}",
|
||||
classes="title",
|
||||
),
|
||||
Static("Press / or Ctrl+F to focus search, Enter to search", classes="subtitle"),
|
||||
Static(
|
||||
"Press / or Ctrl+F to focus search, Enter to search"
|
||||
if search_supported
|
||||
else "Search functionality not available for this backend",
|
||||
classes="subtitle",
|
||||
),
|
||||
Input(placeholder="Enter search query... (press Enter to search)", id="search_input"),
|
||||
Button("🔍 Search", id="search_btn", variant="primary"),
|
||||
Button("🗑️ Clear Results", id="clear_btn", variant="default"),
|
||||
@@ -127,7 +139,9 @@ class SearchScreen(Screen[None]):
|
||||
finally:
|
||||
loading.display = False
|
||||
|
||||
def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None:
|
||||
def _setup_search_ui(
|
||||
self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str
|
||||
) -> None:
|
||||
"""Setup the search UI elements."""
|
||||
loading.display = True
|
||||
status.update(f"🔍 Searching for '{query}'...")
|
||||
@@ -139,10 +153,18 @@ class SearchScreen(Screen[None]):
|
||||
if self.collection["type"] == "weaviate" and self.weaviate:
|
||||
return await self.search_weaviate(query)
|
||||
elif self.collection["type"] == "openwebui" and self.openwebui:
|
||||
return await self.search_openwebui(query)
|
||||
# OpenWebUI search is not yet implemented
|
||||
self.notify("Search not supported for OpenWebUI collections", severity="warning")
|
||||
return []
|
||||
elif self.collection["type"] == "r2r":
|
||||
# R2R search would go here when implemented
|
||||
self.notify("Search not supported for R2R collections", severity="warning")
|
||||
return []
|
||||
return []
|
||||
|
||||
def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None:
|
||||
def _populate_results_table(
|
||||
self, table: EnhancedDataTable, results: list[dict[str, str | float]]
|
||||
) -> None:
|
||||
"""Populate the results table with search results."""
|
||||
for result in results:
|
||||
row_data = self._format_result_row(result)
|
||||
@@ -193,7 +215,11 @@ class SearchScreen(Screen[None]):
|
||||
return str(score)
|
||||
|
||||
def _update_search_status(
|
||||
self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable
|
||||
self,
|
||||
status: Static,
|
||||
query: str,
|
||||
results: list[dict[str, str | float]],
|
||||
table: EnhancedDataTable,
|
||||
) -> None:
|
||||
"""Update search status and notifications based on results."""
|
||||
if not results:
|
||||
|
||||
@@ -13,6 +13,8 @@ TextualApp = App[object]
|
||||
class AppProtocol(Protocol):
|
||||
"""Protocol for apps that support CSS and refresh."""
|
||||
|
||||
CSS: str
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""Refresh the app."""
|
||||
...
|
||||
@@ -1122,16 +1124,16 @@ def get_css_for_theme(theme_type: ThemeType) -> str:
|
||||
def apply_theme_to_app(app: TextualApp | AppProtocol, theme_type: ThemeType) -> None:
|
||||
"""Apply a theme to a Textual app instance."""
|
||||
try:
|
||||
css = set_theme(theme_type)
|
||||
# Set CSS using the standard Textual approach
|
||||
if hasattr(app, "CSS") or isinstance(app, App):
|
||||
setattr(app, "CSS", css)
|
||||
# Refresh the app to apply new CSS
|
||||
if hasattr(app, "refresh"):
|
||||
app.refresh()
|
||||
# Note: CSS class variable cannot be changed at runtime
|
||||
# This function would need to be called during app initialization
|
||||
# or implement a different approach for dynamic theming
|
||||
_ = set_theme(theme_type) # Keep for future implementation
|
||||
if hasattr(app, "refresh"):
|
||||
app.refresh()
|
||||
except Exception as e:
|
||||
# Graceful fallback - log but don't crash the UI
|
||||
import logging
|
||||
|
||||
logging.debug(f"Failed to apply theme to app: {e}")
|
||||
|
||||
|
||||
@@ -1185,11 +1187,11 @@ class ThemeSwitcher:
|
||||
|
||||
# Responsive breakpoints for dynamic layout adaptation
|
||||
RESPONSIVE_BREAKPOINTS = {
|
||||
"xs": 40, # Extra small terminals
|
||||
"sm": 60, # Small terminals
|
||||
"md": 100, # Medium terminals
|
||||
"lg": 140, # Large terminals
|
||||
"xl": 180, # Extra large terminals
|
||||
"xs": 40, # Extra small terminals
|
||||
"sm": 60, # Small terminals
|
||||
"md": 100, # Medium terminals
|
||||
"lg": 140, # Large terminals
|
||||
"xl": 180, # Extra large terminals
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -10,8 +10,9 @@ from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import NamedTuple
|
||||
|
||||
import platformdirs
|
||||
|
||||
from ....config import configure_prefect, get_settings
|
||||
from ....core.models import StorageBackend
|
||||
from .storage_manager import StorageManager
|
||||
|
||||
|
||||
@@ -53,21 +54,36 @@ def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext:

log_file: Path | None = None
try:
# Try current directory first for development
log_dir = Path.cwd() / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / "tui.log"
file_handler = RotatingFileHandler(
log_file,
maxBytes=2_000_000,
backupCount=5,
encoding="utf-8",
)
file_handler.setLevel(resolved_level)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
except OSError as exc: # pragma: no cover - filesystem specific
fallback = logging.getLogger(__name__)
fallback.warning("Failed to configure file logging for TUI: %s", exc)
except OSError:
# Fall back to user log directory
try:
log_dir = Path(platformdirs.user_log_dir("ingest-pipeline", "ingest-pipeline"))
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / "tui.log"
except OSError as exc:
fallback = logging.getLogger(__name__)
fallback.warning("Failed to create log directory, file logging disabled: %s", exc)
log_file = None

if log_file:
try:
file_handler = RotatingFileHandler(
log_file,
maxBytes=2_000_000,
backupCount=5,
encoding="utf-8",
)
file_handler.setLevel(resolved_level)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
except OSError as exc:
fallback = logging.getLogger(__name__)
fallback.warning("Failed to configure file logging for TUI: %s", exc)
log_file = None

_logging_context = _TuiLoggingContext(log_queue, formatter, log_file)
return _logging_context
@@ -93,6 +109,7 @@ async def run_textual_tui() -> None:

# Import here to avoid circular import
from ..app import CollectionManagementApp

app = CollectionManagementApp(
storage_manager,
None, # weaviate - will be available after initialization
@@ -1,6 +1,5 @@
"""Storage management utilities for TUI applications."""

from __future__ import annotations

import asyncio
@@ -11,12 +10,11 @@ from pydantic import SecretStr

from ....core.exceptions import StorageError
from ....core.models import Document, StorageBackend, StorageConfig
from ..models import CollectionInfo, StorageCapabilities

from ....storage.base import BaseStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.weaviate import WeaviateStorage
from ..models import CollectionInfo, StorageCapabilities

if TYPE_CHECKING:
from ....config.settings import Settings
@@ -39,7 +37,6 @@ class StorageBackendProtocol(Protocol):
async def close(self) -> None: ...

class MultiStorageAdapter(BaseStorage):
"""Mirror writes to multiple storage backends."""

@@ -70,7 +67,10 @@ class MultiStorageAdapter(BaseStorage):

# Replicate to secondary backends concurrently
if len(self._storages) > 1:
async def replicate_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:

async def replicate_to_backend(
storage: BaseStorage,
) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.store(document, collection_name=collection_name)
return storage, True, None
@@ -106,11 +106,16 @@ class MultiStorageAdapter(BaseStorage):
self, documents: list[Document], *, collection_name: str | None = None
) -> list[str]:
# Store in primary backend first
primary_ids: list[str] = await self._primary.store_batch(documents, collection_name=collection_name)
primary_ids: list[str] = await self._primary.store_batch(
documents, collection_name=collection_name
)

# Replicate to secondary backends concurrently
if len(self._storages) > 1:
async def replicate_batch_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:

async def replicate_batch_to_backend(
storage: BaseStorage,
) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.store_batch(documents, collection_name=collection_name)
return storage, True, None
@@ -135,7 +140,9 @@ class MultiStorageAdapter(BaseStorage):

if failures:
backends = ", ".join(failures)
primary_error = errors[0] if errors else Exception("Unknown batch replication error")
primary_error = (
errors[0] if errors else Exception("Unknown batch replication error")
)
raise StorageError(
f"Batch stored in primary backend but replication failed for: {backends}"
) from primary_error
@@ -144,11 +151,16 @@ class MultiStorageAdapter(BaseStorage):

async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
# Delete from primary backend first
primary_deleted: bool = await self._primary.delete(document_id, collection_name=collection_name)
primary_deleted: bool = await self._primary.delete(
document_id, collection_name=collection_name
)

# Delete from secondary backends concurrently
if len(self._storages) > 1:
async def delete_from_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:

async def delete_from_backend(
storage: BaseStorage,
) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.delete(document_id, collection_name=collection_name)
return storage, True, None
@@ -222,7 +234,6 @@ class MultiStorageAdapter(BaseStorage):
return class_name

class StorageManager:
"""Centralized manager for all storage backend operations."""

@@ -237,7 +248,9 @@ class StorageManager:
"""Initialize all available storage backends with timeout protection."""
results: dict[StorageBackend, bool] = {}

async def init_backend(backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]) -> bool:
async def init_backend(
backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]
) -> bool:
"""Initialize a single backend with timeout."""
try:
storage = storage_class(config)
@@ -261,10 +274,17 @@ class StorageManager:
config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=self.settings.weaviate_endpoint,
api_key=SecretStr(self.settings.weaviate_api_key) if self.settings.weaviate_api_key else None,
api_key=SecretStr(self.settings.weaviate_api_key)
if self.settings.weaviate_api_key
else None,
collection_name="default",
)
tasks.append((StorageBackend.WEAVIATE, init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage)))
tasks.append(
(
StorageBackend.WEAVIATE,
init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage),
)
)
else:
results[StorageBackend.WEAVIATE] = False
@@ -273,10 +293,17 @@ class StorageManager:
|
||||
config = StorageConfig(
|
||||
backend=StorageBackend.OPEN_WEBUI,
|
||||
endpoint=self.settings.openwebui_endpoint,
|
||||
api_key=SecretStr(self.settings.openwebui_api_key) if self.settings.openwebui_api_key else None,
|
||||
api_key=SecretStr(self.settings.openwebui_api_key)
|
||||
if self.settings.openwebui_api_key
|
||||
else None,
|
||||
collection_name="default",
|
||||
)
|
||||
tasks.append((StorageBackend.OPEN_WEBUI, init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage)))
|
||||
tasks.append(
|
||||
(
|
||||
StorageBackend.OPEN_WEBUI,
|
||||
init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage),
|
||||
)
|
||||
)
|
||||
else:
|
||||
results[StorageBackend.OPEN_WEBUI] = False
|
||||
|
||||
@@ -295,7 +322,9 @@ class StorageManager:
|
||||
# Execute initialization tasks concurrently
|
||||
if tasks:
|
||||
backend_types, task_coroutines = zip(*tasks, strict=False)
|
||||
task_results: Sequence[bool | BaseException] = await asyncio.gather(*task_coroutines, return_exceptions=True)
|
||||
task_results: Sequence[bool | BaseException] = await asyncio.gather(
|
||||
*task_coroutines, return_exceptions=True
|
||||
)
|
||||
|
||||
for backend_type, task_result in zip(backend_types, task_results, strict=False):
|
||||
results[backend_type] = task_result if isinstance(task_result, bool) else False
|
||||
@@ -312,7 +341,9 @@ class StorageManager:
|
||||
storages: list[BaseStorage] = []
|
||||
seen: set[StorageBackend] = set()
|
||||
for backend in backends:
|
||||
backend_enum = backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
|
||||
backend_enum = (
|
||||
backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
|
||||
)
|
||||
if backend_enum in seen:
|
||||
continue
|
||||
seen.add(backend_enum)
|
||||
@@ -350,12 +381,18 @@ class StorageManager:
|
||||
except StorageError as e:
|
||||
# Storage-specific errors - log and use 0 count
|
||||
import logging
|
||||
logging.warning(f"Failed to get count for {collection_name} on {backend_type.value}: {e}")
|
||||
|
||||
logging.warning(
|
||||
f"Failed to get count for {collection_name} on {backend_type.value}: {e}"
|
||||
)
|
||||
count = 0
|
||||
except Exception as e:
|
||||
# Unexpected errors - log and skip this collection from this backend
|
||||
import logging
|
||||
logging.warning(f"Unexpected error counting {collection_name} on {backend_type.value}: {e}")
|
||||
|
||||
logging.warning(
|
||||
f"Unexpected error counting {collection_name} on {backend_type.value}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
size_mb = count * 0.01 # Rough estimate: 10KB per document
|
||||
@@ -446,7 +483,9 @@ class StorageManager:
|
||||
storage = self.backends.get(StorageBackend.R2R)
|
||||
return storage if isinstance(storage, R2RStorage) else None
|
||||
|
||||
async def get_backend_status(self) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
|
||||
async def get_backend_status(
|
||||
self,
|
||||
) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
|
||||
"""Get comprehensive status for all backends."""
|
||||
status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -158,9 +158,7 @@ class ScrapeOptionsForm(Widget):
|
||||
formats.append("screenshot")
|
||||
options: dict[str, object] = {
|
||||
"formats": formats,
|
||||
"only_main_content": self.query_one(
|
||||
"#only_main_content", Switch
|
||||
).value,
|
||||
"only_main_content": self.query_one("#only_main_content", Switch).value,
|
||||
}
|
||||
include_tags_input = self.query_one("#include_tags", Input).value
|
||||
if include_tags_input.strip():
|
||||
@@ -195,11 +193,15 @@ class ScrapeOptionsForm(Widget):
|
||||
|
||||
if include_tags := options.get("include_tags", []):
|
||||
include_list = include_tags if isinstance(include_tags, list) else []
|
||||
self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list)
|
||||
self.query_one("#include_tags", Input).value = ", ".join(
|
||||
str(tag) for tag in include_list
|
||||
)
|
||||
|
||||
if exclude_tags := options.get("exclude_tags", []):
|
||||
exclude_list = exclude_tags if isinstance(exclude_tags, list) else []
|
||||
self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list)
|
||||
self.query_one("#exclude_tags", Input).value = ", ".join(
|
||||
str(tag) for tag in exclude_list
|
||||
)
|
||||
|
||||
# Set performance
|
||||
wait_for = options.get("wait_for")
|
||||
@@ -503,7 +505,9 @@ class ExtractOptionsForm(Widget):
|
||||
"content": "string",
|
||||
"tags": ["string"]
|
||||
}"""
|
||||
prompt_widget.text = "Extract article title, author, publication date, main content, and associated tags"
|
||||
prompt_widget.text = (
|
||||
"Extract article title, author, publication date, main content, and associated tags"
|
||||
)
|
||||
|
||||
elif event.button.id == "preset_product":
|
||||
schema_widget.text = """{
|
||||
@@ -513,7 +517,9 @@ class ExtractOptionsForm(Widget):
|
||||
"category": "string",
|
||||
"availability": "string"
|
||||
}"""
|
||||
prompt_widget.text = "Extract product name, price, description, category, and availability status"
|
||||
prompt_widget.text = (
|
||||
"Extract product name, price, description, category, and availability status"
|
||||
)
|
||||
|
||||
elif event.button.id == "preset_contact":
|
||||
schema_widget.text = """{
|
||||
@@ -523,7 +529,9 @@ class ExtractOptionsForm(Widget):
|
||||
"company": "string",
|
||||
"position": "string"
|
||||
}"""
|
||||
prompt_widget.text = "Extract contact information including name, email, phone, company, and position"
|
||||
prompt_widget.text = (
|
||||
"Extract contact information including name, email, phone, company, and position"
|
||||
)
|
||||
|
||||
elif event.button.id == "preset_data":
|
||||
schema_widget.text = """{
|
||||
|
||||
@@ -26,8 +26,12 @@ class StatusIndicator(Static):
|
||||
|
||||
status_lower = status.lower()
|
||||
|
||||
if (status_lower in {"active", "online", "connected", "✓ active"} or
|
||||
status_lower.endswith("active") or "✓" in status_lower and "active" in status_lower):
|
||||
if (
|
||||
status_lower in {"active", "online", "connected", "✓ active"}
|
||||
or status_lower.endswith("active")
|
||||
or "✓" in status_lower
|
||||
and "active" in status_lower
|
||||
):
|
||||
self.add_class("status-active")
|
||||
self.add_class("glow")
|
||||
self.update(f"🟢 {status}")
|
||||
|
||||
@@ -99,10 +99,17 @@ class ChunkViewer(Widget):
|
||||
"id": str(chunk_data.get("id", "")),
|
||||
"document_id": self.document_id,
|
||||
"content": str(chunk_data.get("text", "")),
|
||||
"start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)),
|
||||
"end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)),
|
||||
"start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(
|
||||
chunk_data.get("start_index", 0)
|
||||
),
|
||||
"end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(
|
||||
chunk_data.get("end_index", 0)
|
||||
),
|
||||
"metadata": (
|
||||
dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {}
|
||||
dict(metadata_val)
|
||||
if (metadata_val := chunk_data.get("metadata"))
|
||||
and isinstance(metadata_val, dict)
|
||||
else {}
|
||||
),
|
||||
}
|
||||
self.chunks.append(chunk_info)
|
||||
@@ -275,10 +282,10 @@ class EntityGraph(Widget):
|
||||
"""Show detailed information about an entity."""
|
||||
details_widget = self.query_one("#entity_details", Static)
|
||||
|
||||
details_text = f"""**Entity:** {entity['name']}
|
||||
**Type:** {entity['type']}
|
||||
**Confidence:** {entity['confidence']:.2%}
|
||||
**ID:** {entity['id']}
|
||||
details_text = f"""**Entity:** {entity["name"]}
|
||||
**Type:** {entity["type"]}
|
||||
**Confidence:** {entity["confidence"]:.2%}
|
||||
**ID:** {entity["id"]}
|
||||
|
||||
**Metadata:**
|
||||
"""
|
||||
|
||||
@@ -26,9 +26,15 @@ def _setup_prefect_settings() -> tuple[object, object, object]:
|
||||
if registry is not None:
|
||||
Setting = getattr(ps, "Setting", None)
|
||||
if Setting is not None:
|
||||
api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None)
|
||||
api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None)
|
||||
work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None)
|
||||
api_key = registry.get("PREFECT_API_KEY") or Setting(
|
||||
"PREFECT_API_KEY", type_=str, default=None
|
||||
)
|
||||
api_url = registry.get("PREFECT_API_URL") or Setting(
|
||||
"PREFECT_API_URL", type_=str, default=None
|
||||
)
|
||||
work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting(
|
||||
"PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None
|
||||
)
|
||||
return api_key, api_url, work_pool
|
||||
|
||||
except ImportError:
|
||||
@@ -37,6 +43,7 @@ def _setup_prefect_settings() -> tuple[object, object, object]:
|
||||
# Ultimate fallback
|
||||
return None, None, None
|
||||
|
||||
|
||||
PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()
|
||||
|
||||
# Import after Prefect settings setup to avoid circular dependencies
|
||||
@@ -53,20 +60,30 @@ def configure_prefect(settings: Settings) -> None:
|
||||
|
||||
overrides: dict[Setting, str] = {}
|
||||
|
||||
if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting):
|
||||
if (
|
||||
settings.prefect_api_url is not None
|
||||
and PREFECT_API_URL is not None
|
||||
and isinstance(PREFECT_API_URL, Setting)
|
||||
):
|
||||
overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
|
||||
if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting):
|
||||
if (
|
||||
settings.prefect_api_key
|
||||
and PREFECT_API_KEY is not None
|
||||
and isinstance(PREFECT_API_KEY, Setting)
|
||||
):
|
||||
overrides[PREFECT_API_KEY] = settings.prefect_api_key
|
||||
if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting):
|
||||
if (
|
||||
settings.prefect_work_pool
|
||||
and PREFECT_DEFAULT_WORK_POOL_NAME is not None
|
||||
and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting)
|
||||
):
|
||||
overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool
|
||||
|
||||
if not overrides:
|
||||
return
|
||||
|
||||
filtered_overrides = {
|
||||
setting: value
|
||||
for setting, value in overrides.items()
|
||||
if setting.value() != value
|
||||
setting: value for setting, value in overrides.items() if setting.value() != value
|
||||
}
|
||||
|
||||
if not filtered_overrides:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -139,10 +139,11 @@ class Settings(BaseSettings):
|
||||
key_name, key_value = required_keys[backend]
|
||||
if not key_value:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"{key_name} not set - authentication may fail for {backend} backend",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self
|
||||
@@ -165,16 +166,24 @@ class PrefectVariableConfig:
|
||||
def __init__(self) -> None:
|
||||
self._settings: Settings = get_settings()
|
||||
self._variable_names: list[str] = [
|
||||
"default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages",
|
||||
"default_storage_backend", "default_collection_prefix", "max_concurrent_tasks",
|
||||
"request_timeout", "default_schedule_interval"
|
||||
"default_batch_size",
|
||||
"max_file_size",
|
||||
"max_crawl_depth",
|
||||
"max_crawl_pages",
|
||||
"default_storage_backend",
|
||||
"default_collection_prefix",
|
||||
"max_concurrent_tasks",
|
||||
"request_timeout",
|
||||
"default_schedule_interval",
|
||||
]
|
||||
|
||||
def _get_fallback_value(self, name: str, default_value: object = None) -> object:
|
||||
"""Get fallback value from settings or default."""
|
||||
return default_value or getattr(self._settings, name, default_value)
|
||||
|
||||
def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
|
||||
def get_with_fallback(
|
||||
self, name: str, default_value: str | int | float | None = None
|
||||
) -> str | int | float | None:
|
||||
"""Get variable value with fallback synchronously."""
|
||||
fallback = self._get_fallback_value(name, default_value)
|
||||
# Ensure fallback is a type that Variable expects
|
||||
@@ -195,7 +204,9 @@ class PrefectVariableConfig:
|
||||
return fallback
|
||||
return str(fallback) if fallback is not None else None
|
||||
|
||||
async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
|
||||
async def get_with_fallback_async(
|
||||
self, name: str, default_value: str | int | float | None = None
|
||||
) -> str | int | float | None:
|
||||
"""Get variable value with fallback asynchronously."""
|
||||
fallback = self._get_fallback_value(name, default_value)
|
||||
variable_fallback = str(fallback) if fallback is not None else None
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -12,35 +12,36 @@ from ..config import get_settings
|
||||
|
||||
|
||||
def _default_embedding_model() -> str:
|
||||
return get_settings().embedding_model
|
||||
return str(get_settings().embedding_model)
|
||||
|
||||
|
||||
def _default_embedding_endpoint() -> HttpUrl:
|
||||
return get_settings().llm_endpoint
|
||||
endpoint = get_settings().llm_endpoint
|
||||
return endpoint if isinstance(endpoint, HttpUrl) else HttpUrl(str(endpoint))
|
||||
|
||||
|
||||
def _default_embedding_dimension() -> int:
|
||||
return get_settings().embedding_dimension
|
||||
return int(get_settings().embedding_dimension)
|
||||
|
||||
|
||||
def _default_batch_size() -> int:
|
||||
return get_settings().default_batch_size
|
||||
return int(get_settings().default_batch_size)
|
||||
|
||||
|
||||
def _default_collection_name() -> str:
|
||||
return get_settings().default_collection_prefix
|
||||
return str(get_settings().default_collection_prefix)
|
||||
|
||||
|
||||
def _default_max_crawl_depth() -> int:
|
||||
return get_settings().max_crawl_depth
|
||||
return int(get_settings().max_crawl_depth)
|
||||
|
||||
|
||||
def _default_max_crawl_pages() -> int:
|
||||
return get_settings().max_crawl_pages
|
||||
return int(get_settings().max_crawl_pages)
|
||||
|
||||
|
||||
def _default_max_file_size() -> int:
|
||||
return get_settings().max_file_size
|
||||
return int(get_settings().max_file_size)
|
||||
|
||||
|
||||
class IngestionStatus(str, Enum):
|
||||
@@ -84,13 +85,16 @@ class StorageConfig(Block):
|
||||
|
||||
_block_type_name: ClassVar[str | None] = "Storage Configuration"
|
||||
_block_type_slug: ClassVar[str | None] = "storage-config"
|
||||
_description: ClassVar[str | None] = "Configures storage backend connections and settings for document ingestion"
|
||||
_description: ClassVar[str | None] = (
|
||||
"Configures storage backend connections and settings for document ingestion"
|
||||
)
|
||||
|
||||
backend: StorageBackend
|
||||
endpoint: HttpUrl
|
||||
api_key: SecretStr | None = Field(default=None)
|
||||
collection_name: str = Field(default_factory=_default_collection_name)
|
||||
batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size)
|
||||
grpc_port: int | None = Field(default=None, description="gRPC port for Weaviate connections")
|
||||
|
||||
|
||||
class FirecrawlConfig(Block):
|
||||
@@ -112,7 +116,9 @@ class RepomixConfig(Block):
|
||||
|
||||
_block_type_name: ClassVar[str | None] = "Repomix Configuration"
|
||||
_block_type_slug: ClassVar[str | None] = "repomix-config"
|
||||
_description: ClassVar[str | None] = "Configures repository ingestion patterns and file processing settings"
|
||||
_description: ClassVar[str | None] = (
|
||||
"Configures repository ingestion patterns and file processing settings"
|
||||
)
|
||||
|
||||
include_patterns: list[str] = Field(
|
||||
default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"]
|
||||
@@ -129,7 +135,9 @@ class R2RConfig(Block):
|
||||
|
||||
_block_type_name: ClassVar[str | None] = "R2R Configuration"
|
||||
_block_type_slug: ClassVar[str | None] = "r2r-config"
|
||||
_description: ClassVar[str | None] = "Configures R2R-specific ingestion settings including chunking and graph enrichment"
|
||||
_description: ClassVar[str | None] = (
|
||||
"Configures R2R-specific ingestion settings including chunking and graph enrichment"
|
||||
)
|
||||
|
||||
chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000
|
||||
chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200
|
||||
@@ -139,6 +147,7 @@ class R2RConfig(Block):
|
||||
|
||||
class DocumentMetadataRequired(TypedDict):
|
||||
"""Required metadata fields for a document."""
|
||||
|
||||
source_url: str
|
||||
timestamp: datetime
|
||||
content_type: str
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -103,8 +103,13 @@ async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
|
||||
return storage
|
||||
|
||||
|
||||
@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"],
|
||||
cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"))
|
||||
@task(
|
||||
name="map_firecrawl_site",
|
||||
retries=2,
|
||||
retry_delay_seconds=15,
|
||||
tags=["firecrawl", "map"],
|
||||
cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"),
|
||||
)
|
||||
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
|
||||
"""Map a site using Firecrawl and return discovered URLs."""
|
||||
# Load block if string provided
|
||||
@@ -118,8 +123,13 @@ async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str
|
||||
return mapped or [source_url]
|
||||
|
||||
|
||||
@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"],
|
||||
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list
|
||||
@task(
|
||||
name="filter_existing_documents",
|
||||
retries=1,
|
||||
retry_delay_seconds=5,
|
||||
tags=["dedup"],
|
||||
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls"),
|
||||
) # Cache based on URL list
|
||||
async def filter_existing_documents_task(
|
||||
urls: list[str],
|
||||
storage_client: BaseStorage,
|
||||
@@ -128,19 +138,40 @@ async def filter_existing_documents_task(
collection_name: str | None = None,
) -> list[str]:
"""Filter URLs to only those that need scraping (missing or stale in storage)."""
import asyncio

logger = get_run_logger()
eligible: list[str] = []

for url in urls:
document_id = str(FirecrawlIngestor.compute_document_id(url))
exists = await storage_client.check_exists(
document_id,
collection_name=collection_name,
stale_after_days=stale_after_days
)
# Use semaphore to limit concurrent existence checks
semaphore = asyncio.Semaphore(20)

if not exists:
eligible.append(url)
async def check_url_exists(url: str) -> tuple[str, bool]:
async with semaphore:
try:
document_id = str(FirecrawlIngestor.compute_document_id(url))
exists = await storage_client.check_exists(
document_id, collection_name=collection_name, stale_after_days=stale_after_days
)
return url, exists
except Exception as e:
logger.warning("Error checking existence for URL %s: %s", url, e)
# Assume doesn't exist on error to ensure we scrape it
return url, False

# Check all URLs in parallel - use return_exceptions=True for partial failure handling
results = await asyncio.gather(*[check_url_exists(url) for url in urls], return_exceptions=True)

# Collect URLs that need scraping, handling any exceptions
eligible = []
for result in results:
if isinstance(result, Exception):
logger.error("Unexpected error in parallel existence check: %s", result)
continue
# Type narrowing: result is now known to be tuple[str, bool]
if isinstance(result, tuple) and len(result) == 2:
url, exists = result
if not exists:
eligible.append(url)

skipped = len(urls) - len(eligible)
if skipped > 0:
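The hunk above swaps the sequential `check_exists` loop for semaphore-bounded parallel checks. The same pattern in isolation, as a minimal sketch; the helper name and the concurrency limit are illustrative, not taken from the diff:

```python
import asyncio
from collections.abc import Awaitable, Iterable


async def bounded_gather(coros: Iterable[Awaitable[object]], limit: int = 20) -> list[object]:
    """Run awaitables concurrently while capping in-flight work at `limit`."""
    semaphore = asyncio.Semaphore(limit)

    async def run(coro: Awaitable[object]) -> object:
        async with semaphore:
            return await coro

    # return_exceptions=True mirrors the partial-failure handling in the task above
    return await asyncio.gather(*(run(c) for c in coros), return_exceptions=True)
```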
@@ -260,7 +291,9 @@ async def ingest_documents_task(
|
||||
if progress_callback:
|
||||
progress_callback(40, "Starting document processing...")
|
||||
|
||||
return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback)
|
||||
return await _process_documents(
|
||||
ingestor, storage, job, batch_size, collection_name, progress_callback
|
||||
)
|
||||
|
||||
|
||||
async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
|
||||
@@ -287,7 +320,9 @@ async def _create_ingestor(job: IngestionJob, config_block_name: str | None = No
|
||||
raise ValueError(f"Unsupported source: {job.source_type}")
|
||||
|
||||
|
||||
async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage:
|
||||
async def _create_storage(
|
||||
job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None
|
||||
) -> BaseStorage:
|
||||
"""Create and initialize storage client."""
|
||||
if collection_name is None:
|
||||
# Use variable for default collection prefix
|
||||
@@ -303,6 +338,7 @@ async def _create_storage(job: IngestionJob, collection_name: str | None, storag
|
||||
else:
|
||||
# Fallback to building config from settings
|
||||
from ..config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
storage_config = _build_storage_config(job, settings, collection_name)
|
||||
|
||||
@@ -391,7 +427,7 @@ async def _process_documents(
|
||||
progress_callback(45, "Ingesting documents from source...")
|
||||
|
||||
# Use smart ingestion with deduplication if storage supports it
|
||||
if hasattr(storage, 'check_exists'):
|
||||
if hasattr(storage, "check_exists"):
|
||||
try:
|
||||
# Try to use the smart ingestion method
|
||||
document_generator = ingestor.ingest_with_dedup(
|
||||
@@ -412,7 +448,7 @@ async def _process_documents(
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
45 + min(35, (batch_count * 10)),
|
||||
f"Processing batch {batch_count} ({total_documents} documents so far)..."
|
||||
f"Processing batch {batch_count} ({total_documents} documents so far)...",
|
||||
)
|
||||
|
||||
batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
|
||||
@@ -482,7 +518,9 @@ async def _store_batch(
|
||||
log_prints=True,
|
||||
)
|
||||
async def firecrawl_to_r2r_flow(
|
||||
job: IngestionJob, collection_name: str | None = None, progress_callback: Callable[[int, str], None] | None = None
|
||||
job: IngestionJob,
|
||||
collection_name: str | None = None,
|
||||
progress_callback: Callable[[int, str], None] | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""Specialized flow for Firecrawl ingestion into R2R."""
|
||||
logger = get_run_logger()
|
||||
@@ -549,10 +587,8 @@ async def firecrawl_to_r2r_flow(
|
||||
|
||||
# Use asyncio.gather for concurrent scraping
|
||||
import asyncio
|
||||
scrape_tasks = [
|
||||
scrape_firecrawl_batch_task(batch, firecrawl_config)
|
||||
for batch in url_batches
|
||||
]
|
||||
|
||||
scrape_tasks = [scrape_firecrawl_batch_task(batch, firecrawl_config) for batch in url_batches]
|
||||
batch_results = await asyncio.gather(*scrape_tasks)
|
||||
|
||||
scraped_pages: list[FirecrawlPage] = []
|
||||
@@ -680,9 +716,13 @@ async def create_ingestion_flow(
|
||||
progress_callback(30, "Starting document ingestion...")
|
||||
print("Ingesting documents...")
|
||||
if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R:
|
||||
processed, failed = await firecrawl_to_r2r_flow(job, collection_name, progress_callback=progress_callback)
|
||||
processed, failed = await firecrawl_to_r2r_flow(
|
||||
job, collection_name, progress_callback=progress_callback
|
||||
)
|
||||
else:
|
||||
processed, failed = await ingest_documents_task(job, collection_name, progress_callback=progress_callback)
|
||||
processed, failed = await ingest_documents_task(
|
||||
job, collection_name, progress_callback=progress_callback
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(90, "Finalizing ingestion...")
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -191,9 +191,10 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
"http://localhost"
|
||||
):
|
||||
# Self-hosted instance - try with api_url if supported
|
||||
self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(
|
||||
api_key=api_key, api_url=str(settings.firecrawl_endpoint)
|
||||
))
|
||||
self.client = cast(
|
||||
AsyncFirecrawlClient,
|
||||
AsyncFirecrawl(api_key=api_key, api_url=str(settings.firecrawl_endpoint)),
|
||||
)
|
||||
else:
|
||||
# Cloud instance - use standard initialization
|
||||
self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key))
|
||||
@@ -256,9 +257,7 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
for check_url in site_map:
|
||||
document_id = str(self.compute_document_id(check_url))
|
||||
exists = await storage_client.check_exists(
|
||||
document_id,
|
||||
collection_name=collection_name,
|
||||
stale_after_days=stale_after_days
|
||||
document_id, collection_name=collection_name, stale_after_days=stale_after_days
|
||||
)
|
||||
if not exists:
|
||||
eligible_urls.append(check_url)
|
||||
@@ -394,7 +393,9 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
|
||||
# Extract basic metadata
|
||||
title: str | None = getattr(metadata, "title", None) if metadata else None
|
||||
description: str | None = getattr(metadata, "description", None) if metadata else None
|
||||
description: str | None = (
|
||||
getattr(metadata, "description", None) if metadata else None
|
||||
)
|
||||
|
||||
# Extract enhanced metadata if available
|
||||
author: str | None = getattr(metadata, "author", None) if metadata else None
|
||||
@@ -402,26 +403,38 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
sitemap_last_modified: str | None = (
|
||||
getattr(metadata, "sitemap_last_modified", None) if metadata else None
|
||||
)
|
||||
source_url: str | None = getattr(metadata, "sourceURL", None) if metadata else None
|
||||
keywords: str | list[str] | None = getattr(metadata, "keywords", None) if metadata else None
|
||||
source_url: str | None = (
|
||||
getattr(metadata, "sourceURL", None) if metadata else None
|
||||
)
|
||||
keywords: str | list[str] | None = (
|
||||
getattr(metadata, "keywords", None) if metadata else None
|
||||
)
|
||||
robots: str | None = getattr(metadata, "robots", None) if metadata else None
|
||||
|
||||
# Open Graph metadata
|
||||
og_title: str | None = getattr(metadata, "ogTitle", None) if metadata else None
|
||||
og_description: str | None = getattr(metadata, "ogDescription", None) if metadata else None
|
||||
og_description: str | None = (
|
||||
getattr(metadata, "ogDescription", None) if metadata else None
|
||||
)
|
||||
og_url: str | None = getattr(metadata, "ogUrl", None) if metadata else None
|
||||
og_image: str | None = getattr(metadata, "ogImage", None) if metadata else None
|
||||
|
||||
# Twitter metadata
|
||||
twitter_card: str | None = getattr(metadata, "twitterCard", None) if metadata else None
|
||||
twitter_site: str | None = getattr(metadata, "twitterSite", None) if metadata else None
|
||||
twitter_card: str | None = (
|
||||
getattr(metadata, "twitterCard", None) if metadata else None
|
||||
)
|
||||
twitter_site: str | None = (
|
||||
getattr(metadata, "twitterSite", None) if metadata else None
|
||||
)
|
||||
twitter_creator: str | None = (
|
||||
getattr(metadata, "twitterCreator", None) if metadata else None
|
||||
)
|
||||
|
||||
# Additional metadata
|
||||
favicon: str | None = getattr(metadata, "favicon", None) if metadata else None
|
||||
status_code: int | None = getattr(metadata, "statusCode", None) if metadata else None
|
||||
status_code: int | None = (
|
||||
getattr(metadata, "statusCode", None) if metadata else None
|
||||
)
|
||||
|
||||
return FirecrawlPage(
|
||||
url=url,
|
||||
@@ -594,10 +607,16 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
"site_name": domain_info["site_name"],
|
||||
# Document structure
|
||||
"heading_hierarchy": (
|
||||
list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else []
|
||||
list(hierarchy_val)
|
||||
if (hierarchy_val := structure_info.get("heading_hierarchy"))
|
||||
and isinstance(hierarchy_val, (list, tuple))
|
||||
else []
|
||||
),
|
||||
"section_depth": (
|
||||
int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0
|
||||
int(depth_val)
|
||||
if (depth_val := structure_info.get("section_depth"))
|
||||
and isinstance(depth_val, (int, str))
|
||||
else 0
|
||||
),
|
||||
"has_code_blocks": bool(structure_info.get("has_code_blocks", False)),
|
||||
"has_images": bool(structure_info.get("has_images", False)),
|
||||
@@ -632,7 +651,11 @@ class FirecrawlIngestor(BaseIngestor):
|
||||
await self.client.close()
|
||||
except Exception as e:
|
||||
logging.debug(f"Error closing Firecrawl client: {e}")
|
||||
elif hasattr(self.client, "_session") and self.client._session and hasattr(self.client._session, "close"):
|
||||
elif (
|
||||
hasattr(self.client, "_session")
|
||||
and self.client._session
|
||||
and hasattr(self.client._session, "close")
|
||||
):
|
||||
try:
|
||||
await self.client._session.close()
|
||||
except Exception as e:
|
||||
|
||||
@@ -80,6 +80,7 @@ class RepomixIngestor(BaseIngestor):
|
||||
return result.returncode == 0
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.warning(f"Failed to validate repository {source_url}: {e}")
|
||||
return False
|
||||
|
||||
@@ -97,9 +98,7 @@ class RepomixIngestor(BaseIngestor):
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Shallow clone to get file count
|
||||
repo_path = await self._clone_repository(
|
||||
source_url, temp_dir, shallow=True
|
||||
)
|
||||
repo_path = await self._clone_repository(source_url, temp_dir, shallow=True)
|
||||
|
||||
# Count files matching patterns
|
||||
file_count = 0
|
||||
@@ -111,6 +110,7 @@ class RepomixIngestor(BaseIngestor):
|
||||
return file_count
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.warning(f"Failed to estimate size for repository {source_url}: {e}")
|
||||
return 0
|
||||
|
||||
@@ -179,9 +179,7 @@ class RepomixIngestor(BaseIngestor):
|
||||
|
||||
return output_file
|
||||
|
||||
async def _parse_repomix_output(
|
||||
self, output_file: Path, job: IngestionJob
|
||||
) -> list[Document]:
|
||||
async def _parse_repomix_output(self, output_file: Path, job: IngestionJob) -> list[Document]:
|
||||
"""
|
||||
Parse repomix output into documents.
|
||||
|
||||
@@ -210,14 +208,17 @@ class RepomixIngestor(BaseIngestor):
|
||||
chunks = self._chunk_content(file_content)
|
||||
for i, chunk in enumerate(chunks):
|
||||
doc = self._create_document(
|
||||
file_path, chunk, job, chunk_index=i,
|
||||
git_metadata=git_metadata, repo_info=repo_info
|
||||
file_path,
|
||||
chunk,
|
||||
job,
|
||||
chunk_index=i,
|
||||
git_metadata=git_metadata,
|
||||
repo_info=repo_info,
|
||||
)
|
||||
documents.append(doc)
|
||||
else:
|
||||
doc = self._create_document(
|
||||
file_path, file_content, job,
|
||||
git_metadata=git_metadata, repo_info=repo_info
|
||||
file_path, file_content, job, git_metadata=git_metadata, repo_info=repo_info
|
||||
)
|
||||
documents.append(doc)
|
||||
|
||||
@@ -300,64 +301,64 @@ class RepomixIngestor(BaseIngestor):
|
||||
|
||||
# Map common extensions to languages
|
||||
ext_map = {
|
||||
'.py': 'python',
|
||||
'.js': 'javascript',
|
||||
'.ts': 'typescript',
|
||||
'.jsx': 'javascript',
|
||||
'.tsx': 'typescript',
|
||||
'.java': 'java',
|
||||
'.c': 'c',
|
||||
'.cpp': 'cpp',
|
||||
'.cc': 'cpp',
|
||||
'.cxx': 'cpp',
|
||||
'.h': 'c',
|
||||
'.hpp': 'cpp',
|
||||
'.cs': 'csharp',
|
||||
'.go': 'go',
|
||||
'.rs': 'rust',
|
||||
'.php': 'php',
|
||||
'.rb': 'ruby',
|
||||
'.swift': 'swift',
|
||||
'.kt': 'kotlin',
|
||||
'.scala': 'scala',
|
||||
'.sh': 'shell',
|
||||
'.bash': 'shell',
|
||||
'.zsh': 'shell',
|
||||
'.sql': 'sql',
|
||||
'.html': 'html',
|
||||
'.css': 'css',
|
||||
'.scss': 'scss',
|
||||
'.less': 'less',
|
||||
'.yaml': 'yaml',
|
||||
'.yml': 'yaml',
|
||||
'.json': 'json',
|
||||
'.xml': 'xml',
|
||||
'.md': 'markdown',
|
||||
'.txt': 'text',
|
||||
'.cfg': 'config',
|
||||
'.ini': 'config',
|
||||
'.toml': 'toml',
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".jsx": "javascript",
|
||||
".tsx": "typescript",
|
||||
".java": "java",
|
||||
".c": "c",
|
||||
".cpp": "cpp",
|
||||
".cc": "cpp",
|
||||
".cxx": "cpp",
|
||||
".h": "c",
|
||||
".hpp": "cpp",
|
||||
".cs": "csharp",
|
||||
".go": "go",
|
||||
".rs": "rust",
|
||||
".php": "php",
|
||||
".rb": "ruby",
|
||||
".swift": "swift",
|
||||
".kt": "kotlin",
|
||||
".scala": "scala",
|
||||
".sh": "shell",
|
||||
".bash": "shell",
|
||||
".zsh": "shell",
|
||||
".sql": "sql",
|
||||
".html": "html",
|
||||
".css": "css",
|
||||
".scss": "scss",
|
||||
".less": "less",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".json": "json",
|
||||
".xml": "xml",
|
||||
".md": "markdown",
|
||||
".txt": "text",
|
||||
".cfg": "config",
|
||||
".ini": "config",
|
||||
".toml": "toml",
|
||||
}
|
||||
|
||||
if extension in ext_map:
|
||||
return ext_map[extension]
|
||||
|
||||
# Try to detect from shebang
|
||||
if content.startswith('#!'):
|
||||
first_line = content.split('\n')[0]
|
||||
if 'python' in first_line:
|
||||
return 'python'
|
||||
elif 'node' in first_line or 'javascript' in first_line:
|
||||
return 'javascript'
|
||||
elif 'bash' in first_line or 'sh' in first_line:
|
||||
return 'shell'
|
||||
if content.startswith("#!"):
|
||||
first_line = content.split("\n")[0]
|
||||
if "python" in first_line:
|
||||
return "python"
|
||||
elif "node" in first_line or "javascript" in first_line:
|
||||
return "javascript"
|
||||
elif "bash" in first_line or "sh" in first_line:
|
||||
return "shell"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _analyze_code_structure(content: str, language: str | None) -> dict[str, object]:
|
||||
"""Analyze code structure and extract metadata."""
|
||||
lines = content.split('\n')
|
||||
lines = content.split("\n")
|
||||
|
||||
# Basic metrics
|
||||
has_functions = False
|
||||
@@ -366,29 +367,35 @@ class RepomixIngestor(BaseIngestor):
|
||||
has_comments = False
|
||||
|
||||
# Language-specific patterns
|
||||
if language == 'python':
|
||||
has_functions = bool(re.search(r'^\s*def\s+\w+', content, re.MULTILINE))
|
||||
has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE))
|
||||
has_imports = bool(re.search(r'^\s*(import|from)\s+', content, re.MULTILINE))
|
||||
has_comments = bool(re.search(r'^\s*#', content, re.MULTILINE))
|
||||
elif language in ['javascript', 'typescript']:
|
||||
has_functions = bool(re.search(r'(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)', content, re.MULTILINE))
|
||||
has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE))
|
||||
has_imports = bool(re.search(r'^\s*(import|require)', content, re.MULTILINE))
|
||||
has_comments = bool(re.search(r'//|/\*', content))
|
||||
elif language == 'java':
|
||||
has_functions = bool(re.search(r'(public|private|protected).*\w+\s*\(', content, re.MULTILINE))
|
||||
has_classes = bool(re.search(r'(public|private)?\s*class\s+\w+', content, re.MULTILINE))
|
||||
has_imports = bool(re.search(r'^\s*import\s+', content, re.MULTILINE))
|
||||
has_comments = bool(re.search(r'//|/\*', content))
|
||||
if language == "python":
|
||||
has_functions = bool(re.search(r"^\s*def\s+\w+", content, re.MULTILINE))
|
||||
has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE))
|
||||
has_imports = bool(re.search(r"^\s*(import|from)\s+", content, re.MULTILINE))
|
||||
has_comments = bool(re.search(r"^\s*#", content, re.MULTILINE))
|
||||
elif language in ["javascript", "typescript"]:
|
||||
has_functions = bool(
|
||||
re.search(
|
||||
r"(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)", content, re.MULTILINE
|
||||
)
|
||||
)
|
||||
has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE))
|
||||
has_imports = bool(re.search(r"^\s*(import|require)", content, re.MULTILINE))
|
||||
has_comments = bool(re.search(r"//|/\*", content))
|
||||
elif language == "java":
|
||||
has_functions = bool(
|
||||
re.search(r"(public|private|protected).*\w+\s*\(", content, re.MULTILINE)
|
||||
)
|
||||
has_classes = bool(re.search(r"(public|private)?\s*class\s+\w+", content, re.MULTILINE))
|
||||
has_imports = bool(re.search(r"^\s*import\s+", content, re.MULTILINE))
|
||||
has_comments = bool(re.search(r"//|/\*", content))
|
||||
|
||||
return {
|
||||
'has_functions': has_functions,
|
||||
'has_classes': has_classes,
|
||||
'has_imports': has_imports,
|
||||
'has_comments': has_comments,
|
||||
'line_count': len(lines),
|
||||
'non_empty_lines': len([line for line in lines if line.strip()]),
|
||||
"has_functions": has_functions,
|
||||
"has_classes": has_classes,
|
||||
"has_imports": has_imports,
|
||||
"has_comments": has_comments,
|
||||
"line_count": len(lines),
|
||||
"non_empty_lines": len([line for line in lines if line.strip()]),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -399,18 +406,16 @@ class RepomixIngestor(BaseIngestor):
|
||||
org_name = None
|
||||
|
||||
# Handle different URL formats
|
||||
if 'github.com' in repo_url or 'gitlab.com' in repo_url:
|
||||
if path_match := re.search(
|
||||
r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url
|
||||
):
|
||||
if "github.com" in repo_url or "gitlab.com" in repo_url:
|
||||
if path_match := re.search(r"/([^/]+)/([^/]+?)(?:\.git)?/?$", repo_url):
|
||||
org_name = path_match[1]
|
||||
repo_name = path_match[2]
|
||||
elif path_match := re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url):
|
||||
elif path_match := re.search(r"/([^/]+?)(?:\.git)?/?$", repo_url):
|
||||
repo_name = path_match[1]
|
||||
|
||||
return {
|
||||
'repository_name': repo_name or 'unknown',
|
||||
'organization': org_name or 'unknown',
|
||||
"repository_name": repo_name or "unknown",
|
||||
"organization": org_name or "unknown",
|
||||
}
|
||||
|
||||
async def _get_git_metadata(self, repo_path: Path) -> dict[str, str | None]:
|
||||
@@ -418,31 +423,35 @@ class RepomixIngestor(BaseIngestor):
|
||||
try:
|
||||
# Get current branch
|
||||
branch_result = await self._run_command(
|
||||
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
|
||||
cwd=str(repo_path),
|
||||
timeout=5
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=str(repo_path), timeout=5
|
||||
)
|
||||
branch_name = (
|
||||
branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None
|
||||
)
|
||||
branch_name = branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None
|
||||
|
||||
# Get current commit hash
|
||||
commit_result = await self._run_command(
|
||||
['git', 'rev-parse', 'HEAD'],
|
||||
cwd=str(repo_path),
|
||||
timeout=5
|
||||
["git", "rev-parse", "HEAD"], cwd=str(repo_path), timeout=5
|
||||
)
|
||||
commit_hash = (
|
||||
commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None
|
||||
)
|
||||
commit_hash = commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None
|
||||
|
||||
return {
|
||||
'branch_name': branch_name,
|
||||
'commit_hash': commit_hash[:8] if commit_hash else None, # Short hash
|
||||
"branch_name": branch_name,
|
||||
"commit_hash": commit_hash[:8] if commit_hash else None, # Short hash
|
||||
}
|
||||
except Exception:
|
||||
return {'branch_name': None, 'commit_hash': None}
|
||||
return {"branch_name": None, "commit_hash": None}
|
||||
|
||||
def _create_document(
|
||||
self, file_path: str, content: str, job: IngestionJob, chunk_index: int = 0,
|
||||
self,
|
||||
file_path: str,
|
||||
content: str,
|
||||
job: IngestionJob,
|
||||
chunk_index: int = 0,
|
||||
git_metadata: dict[str, str | None] | None = None,
|
||||
repo_info: dict[str, str] | None = None
|
||||
repo_info: dict[str, str] | None = None,
|
||||
) -> Document:
|
||||
"""
|
||||
Create a Document from repository content with enriched metadata.
|
||||
@@ -466,19 +475,19 @@ class RepomixIngestor(BaseIngestor):
|
||||
|
||||
# Determine content type based on language
|
||||
content_type_map = {
|
||||
'python': 'text/x-python',
|
||||
'javascript': 'text/javascript',
|
||||
'typescript': 'text/typescript',
|
||||
'java': 'text/x-java-source',
|
||||
'html': 'text/html',
|
||||
'css': 'text/css',
|
||||
'json': 'application/json',
|
||||
'yaml': 'text/yaml',
|
||||
'markdown': 'text/markdown',
|
||||
'shell': 'text/x-shellscript',
|
||||
'sql': 'text/x-sql',
|
||||
"python": "text/x-python",
|
||||
"javascript": "text/javascript",
|
||||
"typescript": "text/typescript",
|
||||
"java": "text/x-java-source",
|
||||
"html": "text/html",
|
||||
"css": "text/css",
|
||||
"json": "application/json",
|
||||
"yaml": "text/yaml",
|
||||
"markdown": "text/markdown",
|
||||
"shell": "text/x-shellscript",
|
||||
"sql": "text/x-sql",
|
||||
}
|
||||
content_type = content_type_map.get(programming_language or 'text', 'text/plain')
|
||||
content_type = content_type_map.get(programming_language or "text", "text/plain")
|
||||
|
||||
# Build rich metadata
|
||||
metadata: DocumentMetadata = {
|
||||
@@ -487,8 +496,7 @@ class RepomixIngestor(BaseIngestor):
|
||||
"content_type": content_type,
|
||||
"word_count": len(content.split()),
|
||||
"char_count": len(content),
|
||||
"title": f"{file_path}"
|
||||
+ (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
|
||||
"title": f"{file_path}" + (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
|
||||
"description": f"Repository file: {file_path}",
|
||||
"category": "source_code" if programming_language else "documentation",
|
||||
"language": programming_language or "text",
|
||||
@@ -500,19 +508,23 @@ class RepomixIngestor(BaseIngestor):
|
||||
|
||||
# Add repository info if available
|
||||
if repo_info:
|
||||
metadata["repository_name"] = repo_info.get('repository_name')
|
||||
metadata["repository_name"] = repo_info.get("repository_name")
|
||||
|
||||
# Add git metadata if available
|
||||
if git_metadata:
|
||||
metadata["branch_name"] = git_metadata.get('branch_name')
|
||||
metadata["commit_hash"] = git_metadata.get('commit_hash')
|
||||
metadata["branch_name"] = git_metadata.get("branch_name")
|
||||
metadata["commit_hash"] = git_metadata.get("commit_hash")
|
||||
|
||||
# Add code-specific metadata for programming files
|
||||
if programming_language and structure_info:
|
||||
# Calculate code quality score
|
||||
total_lines = structure_info.get('line_count', 1)
|
||||
non_empty_lines = structure_info.get('non_empty_lines', 0)
|
||||
if isinstance(total_lines, int) and isinstance(non_empty_lines, int) and total_lines > 0:
|
||||
total_lines = structure_info.get("line_count", 1)
|
||||
non_empty_lines = structure_info.get("non_empty_lines", 0)
|
||||
if (
|
||||
isinstance(total_lines, int)
|
||||
and isinstance(non_empty_lines, int)
|
||||
and total_lines > 0
|
||||
):
|
||||
completeness_score = (non_empty_lines / total_lines) * 100
|
||||
metadata["completeness_score"] = completeness_score
|
||||
|
||||
@@ -566,5 +578,6 @@ class RepomixIngestor(BaseIngestor):
|
||||
await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||
except TimeoutError:
|
||||
import logging
|
||||
|
||||
logging.warning(f"Process {proc.pid} did not terminate cleanly")
|
||||
raise IngestionError(f"Command timed out: {' '.join(cmd)}") from e
|
||||
|
||||
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
|
||||
|
||||
try:
|
||||
from .r2r.storage import R2RStorage as _RuntimeR2RStorage
|
||||
|
||||
R2RStorage: type[BaseStorage] | None = _RuntimeR2RStorage
|
||||
except ImportError:
|
||||
R2RStorage = None
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,10 +1,12 @@
|
||||
"""Base storage interface."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Final
|
||||
from types import TracebackType
|
||||
from typing import Final
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
@@ -37,6 +39,8 @@ class TypedHttpClient:
|
||||
api_key: SecretStr | None = None,
|
||||
timeout: float = 30.0,
|
||||
headers: dict[str, str] | None = None,
|
||||
max_connections: int = 100,
|
||||
max_keepalive_connections: int = 20,
|
||||
):
|
||||
"""
|
||||
Initialize the typed HTTP client.
|
||||
@@ -46,6 +50,8 @@ class TypedHttpClient:
|
||||
api_key: Optional API key for authentication
|
||||
timeout: Request timeout in seconds
|
||||
headers: Additional headers to include with requests
|
||||
max_connections: Maximum total connections in pool
|
||||
max_keepalive_connections: Maximum keepalive connections
|
||||
"""
|
||||
self._base_url = base_url
|
||||
|
||||
@@ -54,14 +60,14 @@ class TypedHttpClient:
|
||||
if api_key:
|
||||
client_headers["Authorization"] = f"Bearer {api_key.get_secret_value()}"
|
||||
|
||||
# Note: Pylance incorrectly reports "No parameter named 'base_url'"
|
||||
# but base_url is a valid AsyncClient parameter (see HTTPX docs)
|
||||
client_kwargs: dict[str, str | dict[str, str] | float] = {
|
||||
"base_url": base_url,
|
||||
"headers": client_headers,
|
||||
"timeout": timeout,
|
||||
}
|
||||
self.client = httpx.AsyncClient(**client_kwargs) # type: ignore
|
||||
# Create typed client configuration with connection pooling
|
||||
limits = httpx.Limits(
|
||||
max_connections=max_connections, max_keepalive_connections=max_keepalive_connections
|
||||
)
|
||||
timeout_config = httpx.Timeout(connect=5.0, read=timeout, write=30.0, pool=10.0)
|
||||
self.client = httpx.AsyncClient(
|
||||
base_url=base_url, headers=client_headers, timeout=timeout_config, limits=limits
|
||||
)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
@@ -73,44 +79,92 @@ class TypedHttpClient:
|
||||
data: dict[str, object] | None = None,
|
||||
files: dict[str, tuple[str, bytes, str]] | None = None,
|
||||
params: dict[str, str | bool] | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
) -> httpx.Response | None:
|
||||
"""
|
||||
Perform an HTTP request with consistent error handling.
|
||||
Perform an HTTP request with consistent error handling and retries.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, DELETE, etc.)
|
||||
path: URL path relative to base_url
|
||||
allow_404: If True, return None for 404 responses instead of raising
|
||||
**kwargs: Arguments passed to httpx request
|
||||
json: JSON data to send
|
||||
data: Form data to send
|
||||
files: Files to upload
|
||||
params: Query parameters
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_delay: Base delay between retries in seconds
|
||||
|
||||
Returns:
|
||||
HTTP response object, or None if allow_404=True and status is 404
|
||||
|
||||
Raises:
|
||||
StorageError: If request fails
|
||||
StorageError: If request fails after retries
|
||||
"""
|
||||
try:
|
||||
response = await self.client.request( # type: ignore
|
||||
method, path, json=json, data=data, files=files, params=params
|
||||
)
|
||||
response.raise_for_status() # type: ignore
|
||||
return response # type: ignore
|
||||
except Exception as e:
|
||||
# Handle 404 as special case if requested
|
||||
if allow_404 and hasattr(e, 'response') and getattr(e.response, 'status_code', None) == 404: # type: ignore
|
||||
LOGGER.debug("Resource not found (404): %s %s", method, path)
|
||||
return None
|
||||
last_exception: Exception | None = None
|
||||
|
||||
# Convert all HTTP-related exceptions to StorageError
|
||||
error_name = e.__class__.__name__
|
||||
if 'HTTP' in error_name or 'Connect' in error_name or 'Request' in error_name:
|
||||
if hasattr(e, 'response') and hasattr(e.response, 'status_code'): # type: ignore
|
||||
status_code = getattr(e.response, 'status_code', 'unknown') # type: ignore
|
||||
raise StorageError(f"HTTP {status_code} error from {self._base_url}: {e}") from e
|
||||
else:
|
||||
raise StorageError(f"Request failed to {self._base_url}: {e}") from e
|
||||
# Re-raise non-HTTP exceptions
|
||||
raise
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = await self.client.request(
|
||||
method, path, json=json, data=data, files=files, params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Handle 404 as special case if requested
|
||||
if allow_404 and e.response.status_code == 404:
|
||||
LOGGER.debug("Resource not found (404): %s %s", method, path)
|
||||
return None
|
||||
|
||||
# Don't retry client errors (4xx except for specific cases)
|
||||
if 400 <= e.response.status_code < 500 and e.response.status_code not in [429, 408]:
|
||||
raise StorageError(
|
||||
f"HTTP {e.response.status_code} error from {self._base_url}: {e}"
|
||||
) from e
|
||||
|
||||
last_exception = e
|
||||
if attempt < max_retries:
|
||||
# Exponential backoff with jitter for retryable errors
|
||||
delay = retry_delay * (2**attempt) + random.uniform(0, 1)
|
||||
LOGGER.warning(
|
||||
"HTTP %d error on attempt %d/%d, retrying in %.2fs: %s",
|
||||
e.response.status_code,
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
delay,
|
||||
e,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries:
|
||||
# Retry transport errors with backoff
|
||||
delay = retry_delay * (2**attempt) + random.uniform(0, 1)
|
||||
LOGGER.warning(
|
||||
"HTTP transport error on attempt %d/%d, retrying in %.2fs: %s",
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
delay,
|
||||
e,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All retries exhausted - last_exception should always be set if we reach here
|
||||
if last_exception is None:
|
||||
raise StorageError(
|
||||
f"Request to {self._base_url} failed after {max_retries + 1} attempts with unknown error"
|
||||
)
|
||||
|
||||
if isinstance(last_exception, httpx.HTTPStatusError):
|
||||
raise StorageError(
|
||||
f"HTTP {last_exception.response.status_code} error from {self._base_url} after {max_retries + 1} attempts: {last_exception}"
|
||||
) from last_exception
|
||||
else:
|
||||
raise StorageError(
|
||||
f"HTTP transport error to {self._base_url} after {max_retries + 1} attempts: {last_exception}"
|
||||
) from last_exception
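
For reference, the retry policy this hunk introduces (retry 5xx, 429 and 408 plus transport errors with exponential backoff and jitter, fail fast on other 4xx) reduces to a standalone sketch. Names here are illustrative; `StorageError` stands in for the pipeline's exception type and the client is a plain `httpx.AsyncClient`:

```python
import asyncio
import random

import httpx


class StorageError(Exception):
    """Stand-in for the pipeline's storage error type."""


async def request_with_retries(
    client: httpx.AsyncClient,
    method: str,
    path: str,
    max_retries: int = 3,
    retry_delay: float = 1.0,
) -> httpx.Response:
    """Retry retryable failures with exponential backoff plus jitter."""
    last_exc: Exception | None = None
    for attempt in range(max_retries + 1):
        try:
            response = await client.request(method, path)
            response.raise_for_status()
            return response
        except httpx.HTTPStatusError as exc:
            code = exc.response.status_code
            if 400 <= code < 500 and code not in (429, 408):
                # Client errors other than rate limiting / timeout are not retryable.
                raise StorageError(f"HTTP {code} error: {exc}") from exc
            last_exc = exc
        except httpx.HTTPError as exc:  # transport-level failures (connect, read, ...)
            last_exc = exc
        if attempt < max_retries:
            await asyncio.sleep(retry_delay * (2**attempt) + random.uniform(0, 1))
    raise StorageError(
        f"Request failed after {max_retries + 1} attempts: {last_exc}"
    ) from last_exc
```

Jitter keeps concurrent workers from retrying in lockstep against an already overloaded backend.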
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client and cleanup resources."""
|
||||
@@ -127,7 +181,7 @@ class TypedHttpClient:
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Async context manager exit."""
|
||||
await self.close()
|
||||
@@ -224,11 +278,30 @@ class BaseStorage(ABC):
|
||||
# Check staleness if timestamp is available
|
||||
if "timestamp" in document.metadata:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
timestamp_obj = document.metadata["timestamp"]
|
||||
|
||||
# Handle both datetime objects and ISO strings
|
||||
if isinstance(timestamp_obj, datetime):
|
||||
timestamp = timestamp_obj
|
||||
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
|
||||
return timestamp >= cutoff
|
||||
# Ensure timezone awareness
|
||||
if timestamp.tzinfo is None:
|
||||
timestamp = timestamp.replace(tzinfo=UTC)
|
||||
elif isinstance(timestamp_obj, str):
|
||||
try:
|
||||
timestamp = datetime.fromisoformat(timestamp_obj)
|
||||
# Ensure timezone awareness
|
||||
if timestamp.tzinfo is None:
|
||||
timestamp = timestamp.replace(tzinfo=UTC)
|
||||
except ValueError:
|
||||
# If parsing fails, assume document is stale
|
||||
return False
|
||||
else:
|
||||
# Unknown timestamp format, assume stale
|
||||
return False
|
||||
|
||||
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
|
||||
return timestamp >= cutoff
|
||||
|
||||
# If no timestamp, assume it exists and is valid
|
||||
return True
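
Condensed, the freshness rule in this hunk is: normalize the stored timestamp to a timezone-aware UTC datetime (accepting both `datetime` objects and ISO strings, and treating unparsable or unknown formats as stale), then compare it to a cutoff. A minimal sketch of that rule on its own:

```python
from datetime import UTC, datetime, timedelta


def is_fresh(timestamp_obj: object, stale_after_days: int = 30) -> bool:
    """Return True when the timestamp falls inside the staleness window."""
    if isinstance(timestamp_obj, datetime):
        timestamp = timestamp_obj
    elif isinstance(timestamp_obj, str):
        try:
            timestamp = datetime.fromisoformat(timestamp_obj)
        except ValueError:
            return False  # unparsable value: treat the document as stale
    else:
        return False  # unknown format: treat the document as stale

    if timestamp.tzinfo is None:
        timestamp = timestamp.replace(tzinfo=UTC)  # assume naive values are UTC

    cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
    return timestamp >= cutoff
```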
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Open WebUI storage adapter."""
|
||||
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Final, TypedDict, cast
|
||||
import time
|
||||
from typing import Final, NamedTuple, TypedDict
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -18,6 +18,7 @@ LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
class OpenWebUIFileResponse(TypedDict, total=False):
|
||||
"""OpenWebUI API file response structure."""
|
||||
|
||||
id: str
|
||||
filename: str
|
||||
name: str
|
||||
@@ -29,6 +30,7 @@ class OpenWebUIFileResponse(TypedDict, total=False):
|
||||
|
||||
class OpenWebUIKnowledgeBase(TypedDict, total=False):
|
||||
"""OpenWebUI knowledge base response structure."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
@@ -38,13 +40,19 @@ class OpenWebUIKnowledgeBase(TypedDict, total=False):
|
||||
updated_at: str
|
||||
|
||||
|
||||
class CacheEntry(NamedTuple):
|
||||
"""Cache entry with value and expiration time."""
|
||||
|
||||
value: str
|
||||
expires_at: float
|
||||
|
||||
|
||||
class OpenWebUIStorage(BaseStorage):
|
||||
"""Storage adapter for Open WebUI knowledge endpoints."""
|
||||
|
||||
http_client: TypedHttpClient
|
||||
_knowledge_cache: dict[str, str]
|
||||
_knowledge_cache: dict[str, CacheEntry]
|
||||
_cache_ttl: float
|
||||
|
||||
def __init__(self, config: StorageConfig):
|
||||
"""
|
||||
@@ -61,6 +69,7 @@ class OpenWebUIStorage(BaseStorage):
|
||||
timeout=30.0,
|
||||
)
|
||||
self._knowledge_cache = {}
|
||||
self._cache_ttl = 300.0 # 5 minutes TTL
|
||||
|
||||
@override
|
||||
async def initialize(self) -> None:
|
||||
@@ -106,9 +115,25 @@ class OpenWebUIStorage(BaseStorage):
|
||||
return []
|
||||
normalized: list[OpenWebUIKnowledgeBase] = []
|
||||
for item in data:
|
||||
if isinstance(item, dict):
|
||||
# Cast to our expected structure
|
||||
kb_item = cast(OpenWebUIKnowledgeBase, item)
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and "id" in item
|
||||
and "name" in item
|
||||
and isinstance(item["id"], str)
|
||||
and isinstance(item["name"], str)
|
||||
):
|
||||
# Create a new dict with known structure
|
||||
kb_item: OpenWebUIKnowledgeBase = {
|
||||
"id": item["id"],
|
||||
"name": item["name"],
|
||||
"description": item.get("description", ""),
|
||||
"created_at": item.get("created_at", ""),
|
||||
"updated_at": item.get("updated_at", ""),
|
||||
}
|
||||
if "files" in item and isinstance(item["files"], list):
|
||||
kb_item["files"] = item["files"]
|
||||
if "data" in item and isinstance(item["data"], dict):
|
||||
kb_item["data"] = item["data"]
|
||||
normalized.append(kb_item)
|
||||
return normalized
|
||||
|
||||
@@ -124,22 +149,29 @@ class OpenWebUIStorage(BaseStorage):
|
||||
if not target:
|
||||
raise StorageError("Knowledge base name is required")
|
||||
|
||||
if cached := self._knowledge_cache.get(target):
|
||||
return cached
|
||||
# Check cache with TTL
|
||||
if cached_entry := self._knowledge_cache.get(target):
|
||||
if time.time() < cached_entry.expires_at:
|
||||
return cached_entry.value
|
||||
else:
|
||||
# Entry expired, remove it
|
||||
del self._knowledge_cache[target]
|
||||
|
||||
knowledge_bases = await self._fetch_knowledge_bases()
|
||||
for kb in knowledge_bases:
|
||||
if kb.get("name") == target:
|
||||
kb_id = kb.get("id")
|
||||
if isinstance(kb_id, str):
|
||||
self._knowledge_cache[target] = kb_id
|
||||
expires_at = time.time() + self._cache_ttl
|
||||
self._knowledge_cache[target] = CacheEntry(kb_id, expires_at)
|
||||
return kb_id
|
||||
|
||||
if not create:
|
||||
return None
|
||||
|
||||
knowledge_id = await self._create_collection(target)
|
||||
self._knowledge_cache[target] = knowledge_id
|
||||
expires_at = time.time() + self._cache_ttl
|
||||
self._knowledge_cache[target] = CacheEntry(knowledge_id, expires_at)
|
||||
return knowledge_id
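
The TTL cache added here is small enough to lift into a shared utility if other adapters need it; a generic sketch of the same behaviour (class and method names are illustrative, not part of the adapter):

```python
import time
from typing import NamedTuple


class CacheEntry(NamedTuple):
    value: str
    expires_at: float


class TTLCache:
    """Minimal string cache with per-entry expiry, mirroring the adapter's lookup logic."""

    def __init__(self, ttl: float = 300.0) -> None:
        self._ttl = ttl
        self._entries: dict[str, CacheEntry] = {}

    def get(self, key: str) -> str | None:
        entry = self._entries.get(key)
        if entry is None:
            return None
        if time.time() >= entry.expires_at:
            del self._entries[key]  # expired: evict and report a miss
            return None
        return entry.value

    def set(self, key: str, value: str) -> None:
        self._entries[key] = CacheEntry(value, time.time() + self._ttl)
```

Expiry keeps renamed or deleted knowledge bases from being served from cache indefinitely, at the cost of one extra lookup every five minutes.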
|
||||
|
||||
@override
|
||||
@@ -165,7 +197,7 @@ class OpenWebUIStorage(BaseStorage):
|
||||
# Use document title from metadata if available, otherwise fall back to ID
|
||||
filename = document.metadata.get("title") or f"doc_{document.id}"
|
||||
# Ensure filename has proper extension
|
||||
if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
|
||||
if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
|
||||
filename = f"{filename}.txt"
|
||||
files = {"file": (filename, document.content.encode(), "text/plain")}
|
||||
response = await self.http_client.request(
|
||||
@@ -185,11 +217,9 @@ class OpenWebUIStorage(BaseStorage):
|
||||
|
||||
# Step 2: Add file to knowledge base
|
||||
response = await self.http_client.request(
|
||||
"POST",
|
||||
f"/api/v1/knowledge/{knowledge_id}/file/add",
|
||||
json={"file_id": file_id}
|
||||
"POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
|
||||
)
|
||||
|
||||
|
||||
return str(file_id)
|
||||
|
||||
except Exception as e:
|
||||
@@ -220,7 +250,7 @@ class OpenWebUIStorage(BaseStorage):
|
||||
# Use document title from metadata if available, otherwise fall back to ID
|
||||
filename = doc.metadata.get("title") or f"doc_{doc.id}"
|
||||
# Ensure filename has proper extension
|
||||
if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
|
||||
if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
|
||||
filename = f"{filename}.txt"
|
||||
files = {"file": (filename, doc.content.encode(), "text/plain")}
|
||||
upload_response = await self.http_client.request(
|
||||
@@ -230,7 +260,9 @@ class OpenWebUIStorage(BaseStorage):
|
||||
params={"process": True, "process_in_background": False},
|
||||
)
|
||||
if upload_response is None:
|
||||
raise StorageError(f"Unexpected None response from file upload for document {doc.id}")
|
||||
raise StorageError(
|
||||
f"Unexpected None response from file upload for document {doc.id}"
|
||||
)
|
||||
|
||||
file_data = upload_response.json()
|
||||
file_id = file_data.get("id")
|
||||
@@ -241,9 +273,7 @@ class OpenWebUIStorage(BaseStorage):
|
||||
)
|
||||
|
||||
await self.http_client.request(
|
||||
"POST",
|
||||
f"/api/v1/knowledge/{knowledge_id}/file/add",
|
||||
json={"file_id": file_id}
|
||||
"POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
|
||||
)
|
||||
|
||||
return str(file_id)
|
||||
@@ -259,7 +289,8 @@ class OpenWebUIStorage(BaseStorage):
|
||||
if isinstance(result, Exception):
|
||||
failures.append(f"{doc.id}: {result}")
|
||||
else:
|
||||
file_ids.append(cast(str, result))
|
||||
if isinstance(result, str):
|
||||
file_ids.append(result)
|
||||
|
||||
if failures:
|
||||
LOGGER.warning(
|
||||
@@ -289,10 +320,86 @@ class OpenWebUIStorage(BaseStorage):
|
||||
"""
|
||||
_ = document_id, collection_name # Mark as used
|
||||
# OpenWebUI uses file-based storage without direct document retrieval
|
||||
# This will cause the base check_exists method to return False,
|
||||
# which means documents will always be re-scraped for OpenWebUI
|
||||
raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID")
|
||||
|
||||
@override
|
||||
async def check_exists(
|
||||
self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a document exists in OpenWebUI knowledge base by searching files.
|
||||
|
||||
Args:
|
||||
document_id: Document ID to check (usually based on source URL)
|
||||
collection_name: Knowledge base name
|
||||
stale_after_days: Consider document stale after this many days
|
||||
|
||||
Returns:
|
||||
True if document exists and is not stale, False otherwise
|
||||
"""
|
||||
try:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
# Get knowledge base
|
||||
knowledge_id = await self._get_knowledge_id(collection_name, create=False)
|
||||
if not knowledge_id:
|
||||
return False
|
||||
|
||||
# Get detailed knowledge base info to access files
|
||||
response = await self.http_client.request("GET", f"/api/v1/knowledge/{knowledge_id}")
|
||||
if response is None:
|
||||
return False
|
||||
|
||||
kb_data = response.json()
|
||||
files = kb_data.get("files", [])
|
||||
|
||||
# Look for file with matching document ID or source URL in metadata
|
||||
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
|
||||
|
||||
def _parse_openwebui_timestamp(timestamp_str: str) -> datetime | None:
|
||||
"""Parse OpenWebUI timestamp with proper timezone handling."""
|
||||
try:
|
||||
# Handle both 'Z' suffix and explicit timezone
|
||||
normalized = timestamp_str.replace("Z", "+00:00")
|
||||
parsed = datetime.fromisoformat(normalized)
|
||||
# Ensure timezone awareness
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=UTC)
|
||||
return parsed
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
def _check_file_freshness(file_info: dict[str, object]) -> bool:
|
||||
"""Check if file is fresh enough based on creation date."""
|
||||
created_at = file_info.get("created_at")
|
||||
if not isinstance(created_at, str):
|
||||
# No date info available, consider stale to be safe
|
||||
return False
|
||||
|
||||
file_date = _parse_openwebui_timestamp(created_at)
|
||||
return file_date is not None and file_date >= cutoff
|
||||
|
||||
for file_info in files:
|
||||
if not isinstance(file_info, dict):
|
||||
continue
|
||||
|
||||
file_id = file_info.get("id")
|
||||
if str(file_id) == document_id:
|
||||
return _check_file_freshness(file_info)
|
||||
|
||||
# Also check meta.source_url if available for URL-based document IDs
|
||||
meta = file_info.get("meta", {})
|
||||
if isinstance(meta, dict):
|
||||
source_url = meta.get("source_url")
|
||||
if source_url and document_id in str(source_url):
|
||||
return _check_file_freshness(file_info)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.debug("Error checking document existence in OpenWebUI: %s", e)
|
||||
return False
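
A typical call site for this method, sketched with an untyped `storage` parameter because the surrounding flow code is not shown in this hunk:

```python
async def should_ingest(storage, doc_id: str, collection: str) -> bool:
    """Return True when the document should be (re)scraped and stored."""
    # check_exists returns False for missing, stale, or unparsable entries,
    # so a False result is the signal to re-ingest the document.
    exists = await storage.check_exists(
        doc_id, collection_name=collection, stale_after_days=30
    )
    return not exists
```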
|
||||
|
||||
@override
|
||||
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
|
||||
"""
|
||||
@@ -316,14 +423,10 @@ class OpenWebUIStorage(BaseStorage):
|
||||
await self.http_client.request(
|
||||
"POST",
|
||||
f"/api/v1/knowledge/{knowledge_id}/file/remove",
|
||||
json={"file_id": document_id}
|
||||
json={"file_id": document_id},
|
||||
)
|
||||
|
||||
await self.http_client.request(
|
||||
"DELETE",
|
||||
f"/api/v1/files/{document_id}",
|
||||
allow_404=True
|
||||
)
|
||||
await self.http_client.request("DELETE", f"/api/v1/files/{document_id}", allow_404=True)
|
||||
return True
|
||||
except Exception as exc:
|
||||
LOGGER.error("Error deleting file %s from OpenWebUI", document_id, exc_info=exc)
|
||||
@@ -366,9 +469,7 @@ class OpenWebUIStorage(BaseStorage):
|
||||
|
||||
# Delete the knowledge base using the OpenWebUI API
|
||||
await self.http_client.request(
|
||||
"DELETE",
|
||||
f"/api/v1/knowledge/{knowledge_id}/delete",
|
||||
allow_404=True
|
||||
"DELETE", f"/api/v1/knowledge/{knowledge_id}/delete", allow_404=True
|
||||
)
|
||||
|
||||
# Remove from cache if it exists
|
||||
@@ -379,13 +480,16 @@ class OpenWebUIStorage(BaseStorage):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if hasattr(e, 'response'):
|
||||
response_attr = getattr(e, 'response', None)
|
||||
if response_attr is not None and hasattr(response_attr, 'status_code'):
|
||||
if hasattr(e, "response"):
|
||||
response_attr = getattr(e, "response", None)
|
||||
if response_attr is not None and hasattr(response_attr, "status_code"):
|
||||
with contextlib.suppress(Exception):
|
||||
status_code = response_attr.status_code # type: ignore[attr-defined]
|
||||
status_code = response_attr.status_code
|
||||
if status_code == 404:
|
||||
LOGGER.info("Knowledge base %s was already deleted or not found", collection_name)
|
||||
LOGGER.info(
|
||||
"Knowledge base %s was already deleted or not found",
|
||||
collection_name,
|
||||
)
|
||||
return True
|
||||
LOGGER.error(
|
||||
"Error deleting knowledge base %s from OpenWebUI",
|
||||
@@ -394,8 +498,6 @@ class OpenWebUIStorage(BaseStorage):
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
async def _get_knowledge_base_count(self, kb: OpenWebUIKnowledgeBase) -> int:
|
||||
"""Get the file count for a knowledge base."""
|
||||
kb_id = kb.get("id")
|
||||
@@ -411,14 +513,13 @@ class OpenWebUIStorage(BaseStorage):
|
||||
files = kb.get("files", [])
|
||||
return len(files) if isinstance(files, list) and files is not None else 0
|
||||
|
||||
async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase) -> int:
|
||||
async def _count_files_from_detailed_info(
|
||||
self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase
|
||||
) -> int:
|
||||
"""Count files by fetching detailed knowledge base info."""
|
||||
try:
|
||||
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
|
||||
detail_response = await self.http_client.request(
|
||||
"GET",
|
||||
f"/api/v1/knowledge/{kb_id}"
|
||||
)
|
||||
detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
|
||||
if detail_response is None:
|
||||
LOGGER.warning(f"Knowledge base '{name}' (ID: {kb_id}) not found")
|
||||
return self._count_files_from_basic_info(kb)
|
||||
@@ -489,10 +590,7 @@ class OpenWebUIStorage(BaseStorage):
|
||||
return 0
|
||||
|
||||
# Get detailed knowledge base information to get accurate file count
|
||||
detail_response = await self.http_client.request(
|
||||
"GET",
|
||||
f"/api/v1/knowledge/{kb_id}"
|
||||
)
|
||||
detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
|
||||
if detail_response is None:
|
||||
LOGGER.warning(f"Knowledge base '{collection_name}' (ID: {kb_id}) not found")
|
||||
return self._count_files_from_basic_info(kb)
|
||||
@@ -524,14 +622,28 @@ class OpenWebUIStorage(BaseStorage):
|
||||
return None
|
||||
knowledge_bases = response.json()
|
||||
|
||||
return next(
|
||||
(
|
||||
cast(OpenWebUIKnowledgeBase, kb)
|
||||
for kb in knowledge_bases
|
||||
if isinstance(kb, dict) and kb.get("name") == name
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Find and properly type the knowledge base
|
||||
for kb in knowledge_bases:
|
||||
if (
|
||||
isinstance(kb, dict)
|
||||
and kb.get("name") == name
|
||||
and "id" in kb
|
||||
and isinstance(kb["id"], str)
|
||||
):
|
||||
# Create properly typed response
|
||||
result: OpenWebUIKnowledgeBase = {
|
||||
"id": kb["id"],
|
||||
"name": str(kb["name"]),
|
||||
"description": kb.get("description", ""),
|
||||
"created_at": kb.get("created_at", ""),
|
||||
"updated_at": kb.get("updated_at", ""),
|
||||
}
|
||||
if "files" in kb and isinstance(kb["files"], list):
|
||||
result["files"] = kb["files"]
|
||||
if "data" in kb and isinstance(kb["data"], dict):
|
||||
result["data"] = kb["data"]
|
||||
return result
|
||||
return None
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to get knowledge base by name: {e}") from e
|
||||
|
||||
@@ -623,7 +735,9 @@ class OpenWebUIStorage(BaseStorage):
|
||||
elif isinstance(file_info.get("meta"), dict):
|
||||
meta = file_info.get("meta")
|
||||
if isinstance(meta, dict):
|
||||
filename = meta.get("name")
|
||||
filename_value = meta.get("name")
|
||||
if isinstance(filename_value, str):
|
||||
filename = filename_value
|
||||
|
||||
# Final fallback
|
||||
if not filename:
|
||||
@@ -635,9 +749,11 @@ class OpenWebUIStorage(BaseStorage):
|
||||
size = 0
|
||||
meta = file_info.get("meta")
|
||||
if isinstance(meta, dict):
|
||||
size = meta.get("size", 0)
|
||||
size_value = meta.get("size", 0)
|
||||
size = int(size_value) if isinstance(size_value, (int, float)) else 0
|
||||
else:
|
||||
size = file_info.get("size", 0)
|
||||
size_value = file_info.get("size", 0)
|
||||
size = int(size_value) if isinstance(size_value, (int, float)) else 0
|
||||
|
||||
# Estimate word count from file size (very rough approximation)
|
||||
word_count = max(1, int(size / 6)) if isinstance(size, (int, float)) else 0
|
||||
@@ -650,9 +766,7 @@ class OpenWebUIStorage(BaseStorage):
|
||||
"content_type": str(file_info.get("content_type", "text/plain")),
|
||||
"content_preview": f"File uploaded to OpenWebUI: {filename}",
|
||||
"word_count": word_count,
|
||||
"timestamp": str(
|
||||
file_info.get("created_at") or file_info.get("timestamp", "")
|
||||
),
|
||||
"timestamp": str(file_info.get("created_at") or file_info.get("timestamp", "")),
|
||||
}
|
||||
documents.append(doc_info)
|
||||
|
||||
|
||||
@@ -4,9 +4,10 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Self, TypeVar, cast
|
||||
from typing import Final, Self, TypeVar, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
# Direct imports for runtime and type checking
|
||||
@@ -19,6 +20,8 @@ from ...core.models import Document, DocumentMetadata, IngestionSource, StorageC
|
||||
from ..base import BaseStorage
|
||||
from ..types import DocumentInfo
|
||||
|
||||
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -183,8 +186,10 @@ class R2RStorage(BaseStorage):
|
||||
) -> list[str]:
|
||||
"""Store multiple documents efficiently with connection reuse."""
|
||||
collection_id = await self._resolve_collection_id(collection_name)
|
||||
print(
|
||||
f"Using collection ID: {collection_id} for collection: {collection_name or self.config.collection_name}"
|
||||
LOGGER.info(
|
||||
"Using collection ID: %s for collection: %s",
|
||||
collection_id,
|
||||
collection_name or self.config.collection_name,
|
||||
)
|
||||
|
||||
# Filter valid documents upfront
|
||||
@@ -218,7 +223,7 @@ class R2RStorage(BaseStorage):
|
||||
if isinstance(result, str):
|
||||
stored_ids.append(result)
|
||||
elif isinstance(result, Exception):
|
||||
print(f"Document upload failed: {result}")
|
||||
LOGGER.error("Document upload failed: %s", result)
|
||||
|
||||
return stored_ids
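
The aggregation pattern `store_batch` relies on (run per-document uploads concurrently, collect successes, log failures without aborting the batch) is shown below in isolation; `upload` is a placeholder coroutine, not the R2R call itself:

```python
import asyncio
import logging

LOGGER = logging.getLogger(__name__)


async def upload(doc_id: str) -> str:
    """Placeholder for a per-document upload coroutine."""
    await asyncio.sleep(0)  # pretend to do network I/O
    return doc_id


async def store_batch(doc_ids: list[str]) -> list[str]:
    results = await asyncio.gather(*(upload(d) for d in doc_ids), return_exceptions=True)
    stored: list[str] = []
    for result in results:
        if isinstance(result, str):
            stored.append(result)
        elif isinstance(result, Exception):
            LOGGER.error("Document upload failed: %s", result)  # continue on partial failure
    return stored
```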
|
||||
|
||||
@@ -239,12 +244,14 @@ class R2RStorage(BaseStorage):
|
||||
requested_id = str(document.id)
|
||||
|
||||
if not document.content or not document.content.strip():
|
||||
print(f"Skipping document {requested_id}: empty content")
|
||||
LOGGER.warning("Skipping document %s: empty content", requested_id)
|
||||
return False
|
||||
|
||||
if len(document.content) > 1_000_000: # 1MB limit
|
||||
print(
|
||||
f"Skipping document {requested_id}: content too large ({len(document.content)} chars)"
|
||||
LOGGER.warning(
|
||||
"Skipping document %s: content too large (%d chars)",
|
||||
requested_id,
|
||||
len(document.content),
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -263,7 +270,7 @@ class R2RStorage(BaseStorage):
|
||||
) -> str | None:
|
||||
"""Store a single document with retry logic using provided HTTP client."""
|
||||
requested_id = str(document.id)
|
||||
print(f"Creating document with ID: {requested_id}")
|
||||
LOGGER.debug("Creating document with ID: %s", requested_id)
|
||||
|
||||
max_retries = 3
|
||||
retry_delay = 1.0
|
||||
@@ -313,7 +320,7 @@ class R2RStorage(BaseStorage):
|
||||
|
||||
requested_id = str(document.id)
|
||||
metadata = self._build_metadata(document)
|
||||
print(f"Built metadata for document {requested_id}: {metadata}")
|
||||
LOGGER.debug("Built metadata for document %s: %s", requested_id, metadata)
|
||||
|
||||
files = {
|
||||
"raw_text": (None, document.content),
|
||||
@@ -324,10 +331,12 @@ class R2RStorage(BaseStorage):
|
||||
|
||||
if collection_id:
|
||||
files["collection_ids"] = (None, json.dumps([collection_id]))
|
||||
print(f"Creating document {requested_id} with collection_ids: [{collection_id}]")
|
||||
LOGGER.debug(
|
||||
"Creating document %s with collection_ids: [%s]", requested_id, collection_id
|
||||
)
|
||||
|
||||
print(f"Sending to R2R - files keys: {list(files.keys())}")
|
||||
print(f"Metadata JSON: {files['metadata'][1]}")
|
||||
LOGGER.debug("Sending to R2R - files keys: %s", list(files.keys()))
|
||||
LOGGER.debug("Metadata JSON: %s", files["metadata"][1])
|
||||
|
||||
response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) # type: ignore[call-arg]
|
||||
|
||||
@@ -346,15 +355,17 @@ class R2RStorage(BaseStorage):
|
||||
error_detail = (
|
||||
getattr(response, "json", lambda: {})() if hasattr(response, "json") else {}
|
||||
)
|
||||
print(f"R2R validation error for document {requested_id}: {error_detail}")
|
||||
print(f"Document metadata sent: {metadata}")
|
||||
print(f"Response status: {getattr(response, 'status_code', 'unknown')}")
|
||||
print(f"Response headers: {dict(getattr(response, 'headers', {}))}")
|
||||
LOGGER.error("R2R validation error for document %s: %s", requested_id, error_detail)
|
||||
LOGGER.error("Document metadata sent: %s", metadata)
|
||||
LOGGER.error("Response status: %s", getattr(response, "status_code", "unknown"))
|
||||
LOGGER.error("Response headers: %s", dict(getattr(response, "headers", {})))
|
||||
except Exception:
|
||||
print(
|
||||
f"R2R validation error for document {requested_id}: {getattr(response, 'text', 'unknown error')}"
|
||||
LOGGER.error(
|
||||
"R2R validation error for document %s: %s",
|
||||
requested_id,
|
||||
getattr(response, "text", "unknown error"),
|
||||
)
|
||||
print(f"Document metadata sent: {metadata}")
|
||||
LOGGER.error("Document metadata sent: %s", metadata)
|
||||
|
||||
def _process_document_response(
|
||||
self, doc_response: dict[str, object], requested_id: str, collection_id: str
|
||||
@@ -363,14 +374,16 @@ class R2RStorage(BaseStorage):
|
||||
response_payload = doc_response.get("results", doc_response)
|
||||
doc_id = _extract_id(response_payload, requested_id)
|
||||
|
||||
print(f"R2R returned document ID: {doc_id}")
|
||||
LOGGER.info("R2R returned document ID: %s", doc_id)
|
||||
|
||||
if doc_id != requested_id:
|
||||
print(f"Warning: Requested ID {requested_id} but got {doc_id}")
|
||||
LOGGER.warning("Requested ID %s but got %s", requested_id, doc_id)
|
||||
|
||||
if collection_id:
|
||||
print(
|
||||
f"Document {doc_id} should be assigned to collection {collection_id} via creation API"
|
||||
LOGGER.info(
|
||||
"Document %s should be assigned to collection %s via creation API",
|
||||
doc_id,
|
||||
collection_id,
|
||||
)
|
||||
|
||||
return doc_id
|
||||
@@ -387,7 +400,7 @@ class R2RStorage(BaseStorage):
|
||||
if attempt >= max_retries - 1:
|
||||
return False
|
||||
|
||||
print(f"Timeout for document {requested_id}, retrying in {retry_delay}s...")
|
||||
LOGGER.warning("Timeout for document %s, retrying in %ss...", requested_id, retry_delay)
|
||||
await asyncio.sleep(retry_delay)
|
||||
return True
|
||||
|
||||
@@ -404,23 +417,26 @@ class R2RStorage(BaseStorage):
|
||||
if status_code < 500 or attempt >= max_retries - 1:
|
||||
return False
|
||||
|
||||
print(
|
||||
f"Server error {status_code} for document {requested_id}, retrying in {retry_delay}s..."
|
||||
LOGGER.warning(
|
||||
"Server error %s for document %s, retrying in %ss...",
|
||||
status_code,
|
||||
requested_id,
|
||||
retry_delay,
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
return True
|
||||
|
||||
def _log_document_error(self, document_id: object, exc: Exception) -> None:
|
||||
"""Log document storage errors with specific categorization."""
|
||||
print(f"Failed to store document {document_id}: {exc}")
|
||||
LOGGER.error("Failed to store document %s: %s", document_id, exc)
|
||||
|
||||
exc_str = str(exc)
|
||||
if "422" in exc_str:
|
||||
print(" → Data validation issue - check document content and metadata format")
|
||||
LOGGER.error(" → Data validation issue - check document content and metadata format")
|
||||
elif "timeout" in exc_str.lower():
|
||||
print(" → Network timeout - R2R may be overloaded")
|
||||
LOGGER.error(" → Network timeout - R2R may be overloaded")
|
||||
elif "500" in exc_str:
|
||||
print(" → Server error - R2R internal issue")
|
||||
LOGGER.error(" → Server error - R2R internal issue")
|
||||
else:
|
||||
import traceback
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TypedDict
|
||||
|
||||
class CollectionSummary(TypedDict):
|
||||
"""Collection metadata for describe_collections."""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
size_mb: float
|
||||
@@ -12,6 +13,7 @@ class CollectionSummary(TypedDict):
|
||||
|
||||
class DocumentInfo(TypedDict):
|
||||
"""Document information for list_documents."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
source_url: str
|
||||
@@ -19,4 +21,4 @@ class DocumentInfo(TypedDict):
|
||||
content_type: str
|
||||
content_preview: str
|
||||
word_count: int
|
||||
timestamp: str
|
||||
timestamp: str
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Weaviate storage adapter."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Mapping, Sequence
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal, Self, TypeAlias, cast, overload
|
||||
from typing import Literal, Self, TypeAlias, TypeVar, cast, overload
|
||||
from uuid import UUID
|
||||
|
||||
import weaviate
|
||||
@@ -24,6 +25,7 @@ from .base import BaseStorage
|
||||
from .types import CollectionSummary, DocumentInfo
|
||||
|
||||
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class WeaviateStorage(BaseStorage):
|
||||
@@ -45,6 +47,28 @@ class WeaviateStorage(BaseStorage):
|
||||
self.vectorizer = Vectorizer(config)
|
||||
self._default_collection = self._normalize_collection_name(config.collection_name)
|
||||
|
||||
async def _run_sync(self, func: Callable[..., T], *args: object, **kwargs: object) -> T:
|
||||
"""
|
||||
Run synchronous Weaviate operations in thread pool to avoid blocking event loop.
|
||||
|
||||
Args:
|
||||
func: Synchronous function to run
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
Result of the function call
|
||||
|
||||
Raises:
|
||||
StorageError: If the operation fails
|
||||
"""
|
||||
try:
|
||||
return await asyncio.to_thread(func, *args, **kwargs)
|
||||
except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
|
||||
raise StorageError(f"Weaviate operation failed: {e}") from e
|
||||
except Exception as e:
|
||||
raise StorageError(f"Unexpected error in Weaviate operation: {e}") from e
|
||||
|
||||
@override
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize Weaviate client and create collection if needed."""
|
||||
@@ -53,7 +77,7 @@ class WeaviateStorage(BaseStorage):
|
||||
self.client = weaviate.WeaviateClient(
|
||||
connection_params=weaviate.connect.ConnectionParams.from_url(
|
||||
url=str(self.config.endpoint),
|
||||
grpc_port=50051, # Default gRPC port
|
||||
grpc_port=self.config.grpc_port or 50051,
|
||||
),
|
||||
additional_config=weaviate.classes.init.AdditionalConfig(
|
||||
timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120),
|
||||
@@ -61,7 +85,7 @@ class WeaviateStorage(BaseStorage):
|
||||
)
|
||||
|
||||
# Connect to the client
|
||||
self.client.connect()
|
||||
await self._run_sync(self.client.connect)
|
||||
|
||||
# Ensure the default collection exists
|
||||
await self._ensure_collection(self._default_collection)
|
||||
@@ -76,8 +100,8 @@ class WeaviateStorage(BaseStorage):
|
||||
if not self.client:
|
||||
raise StorageError("Weaviate client not initialized")
|
||||
try:
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client.collections.create(
|
||||
await self._run_sync(
|
||||
self.client.collections.create,
|
||||
name=collection_name,
|
||||
properties=[
|
||||
Property(
|
||||
@@ -106,7 +130,7 @@ class WeaviateStorage(BaseStorage):
|
||||
],
|
||||
vectorizer_config=Configure.Vectorizer.none(),
|
||||
)
|
||||
except Exception as e:
|
||||
except (WeaviateConnectionError, WeaviateBatchError) as e:
|
||||
raise StorageError(f"Failed to create collection: {e}") from e
|
||||
|
||||
@staticmethod
|
||||
@@ -114,13 +138,9 @@ class WeaviateStorage(BaseStorage):
|
||||
"""Normalize vector payloads returned by Weaviate into a float list."""
|
||||
if isinstance(vector_raw, Mapping):
|
||||
default_vector = vector_raw.get("default")
|
||||
return WeaviateStorage._extract_vector(
|
||||
cast(VectorContainer, default_vector)
|
||||
)
|
||||
return WeaviateStorage._extract_vector(cast(VectorContainer, default_vector))
|
||||
|
||||
if not isinstance(vector_raw, Sequence) or isinstance(
|
||||
vector_raw, (str, bytes, bytearray)
|
||||
):
|
||||
if not isinstance(vector_raw, Sequence) or isinstance(vector_raw, (str, bytes, bytearray)):
|
||||
return None
|
||||
|
||||
items = list(vector_raw)
|
||||
@@ -135,9 +155,7 @@ class WeaviateStorage(BaseStorage):
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if isinstance(first_item, Sequence) and not isinstance(
|
||||
first_item, (str, bytes, bytearray)
|
||||
):
|
||||
if isinstance(first_item, Sequence) and not isinstance(first_item, (str, bytes, bytearray)):
|
||||
inner_items = list(first_item)
|
||||
if all(isinstance(item, (int, float)) for item in inner_items):
|
||||
try:
|
||||
@@ -168,8 +186,7 @@ class WeaviateStorage(BaseStorage):
|
||||
properties: object,
|
||||
*,
|
||||
context: str,
|
||||
) -> Mapping[str, object]:
|
||||
...
|
||||
) -> Mapping[str, object]: ...
|
||||
|
||||
@staticmethod
|
||||
@overload
|
||||
@@ -178,8 +195,7 @@ class WeaviateStorage(BaseStorage):
|
||||
*,
|
||||
context: str,
|
||||
allow_missing: Literal[False],
|
||||
) -> Mapping[str, object]:
|
||||
...
|
||||
) -> Mapping[str, object]: ...
|
||||
|
||||
@staticmethod
|
||||
@overload
|
||||
@@ -188,8 +204,7 @@ class WeaviateStorage(BaseStorage):
|
||||
*,
|
||||
context: str,
|
||||
allow_missing: Literal[True],
|
||||
) -> Mapping[str, object] | None:
|
||||
...
|
||||
) -> Mapping[str, object] | None: ...
|
||||
|
||||
@staticmethod
|
||||
def _coerce_properties(
|
||||
@@ -211,6 +226,29 @@ class WeaviateStorage(BaseStorage):
|
||||
|
||||
return cast(Mapping[str, object], properties)
|
||||
|
||||
@staticmethod
|
||||
def _build_document_properties(doc: Document) -> dict[str, object]:
|
||||
"""
|
||||
Build Weaviate properties dict from document.
|
||||
|
||||
Args:
|
||||
doc: Document to build properties for
|
||||
|
||||
Returns:
|
||||
Properties dict suitable for Weaviate
|
||||
"""
|
||||
return {
|
||||
"content": doc.content,
|
||||
"source_url": doc.metadata["source_url"],
|
||||
"title": doc.metadata.get("title", ""),
|
||||
"description": doc.metadata.get("description", ""),
|
||||
"timestamp": doc.metadata["timestamp"].isoformat(),
|
||||
"content_type": doc.metadata["content_type"],
|
||||
"word_count": doc.metadata["word_count"],
|
||||
"char_count": doc.metadata["char_count"],
|
||||
"source": doc.source.value,
|
||||
}
|
||||
|
||||
def _normalize_collection_name(self, collection_name: str | None) -> str:
|
||||
"""Return a canonicalized collection name, defaulting to configured value."""
|
||||
candidate = collection_name or self.config.collection_name
|
||||
@@ -227,7 +265,7 @@ class WeaviateStorage(BaseStorage):
|
||||
if not self.client:
|
||||
raise StorageError("Weaviate client not initialized")
|
||||
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client = self.client
|
||||
existing = client.collections.list_all()
|
||||
if collection_name not in existing:
|
||||
await self._create_collection(collection_name)
|
||||
@@ -247,7 +285,7 @@ class WeaviateStorage(BaseStorage):
|
||||
if ensure_exists:
|
||||
await self._ensure_collection(normalized)
|
||||
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client = self.client
|
||||
return client.collections.get(normalized), normalized
|
||||
|
||||
@override
|
||||
@@ -271,26 +309,19 @@ class WeaviateStorage(BaseStorage):
|
||||
)
|
||||
|
||||
# Prepare properties
|
||||
properties = {
|
||||
"content": document.content,
|
||||
"source_url": document.metadata["source_url"],
|
||||
"title": document.metadata.get("title", ""),
|
||||
"description": document.metadata.get("description", ""),
|
||||
"timestamp": document.metadata["timestamp"].isoformat(),
|
||||
"content_type": document.metadata["content_type"],
|
||||
"word_count": document.metadata["word_count"],
|
||||
"char_count": document.metadata["char_count"],
|
||||
"source": document.source.value,
|
||||
}
|
||||
properties = self._build_document_properties(document)
|
||||
|
||||
# Insert with vector
|
||||
result = collection.data.insert(
|
||||
properties=properties, vector=document.vector, uuid=str(document.id)
|
||||
result = await self._run_sync(
|
||||
collection.data.insert,
|
||||
properties=properties,
|
||||
vector=document.vector,
|
||||
uuid=str(document.id),
|
||||
)
|
||||
|
||||
return str(result)
|
||||
|
||||
except Exception as e:
|
||||
except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
|
||||
raise StorageError(f"Failed to store document: {e}") from e
|
||||
|
||||
@override
|
||||
@@ -311,32 +342,24 @@ class WeaviateStorage(BaseStorage):
|
||||
collection_name, ensure_exists=True
|
||||
)
|
||||
|
||||
# Vectorize documents without vectors
|
||||
for doc in documents:
|
||||
if doc.vector is None:
|
||||
doc.vector = await self.vectorizer.vectorize(doc.content)
|
||||
# Vectorize documents without vectors using batch processing
|
||||
to_vectorize = [(i, doc) for i, doc in enumerate(documents) if doc.vector is None]
|
||||
if to_vectorize:
|
||||
contents = [doc.content for _, doc in to_vectorize]
|
||||
vectors = await self.vectorizer.vectorize_batch(contents)
|
||||
for (idx, _), vector in zip(to_vectorize, vectors, strict=False):
|
||||
documents[idx].vector = vector
|
||||
|
||||
# Prepare batch data for insert_many
|
||||
batch_objects = []
|
||||
for doc in documents:
|
||||
properties = {
|
||||
"content": doc.content,
|
||||
"source_url": doc.metadata["source_url"],
|
||||
"title": doc.metadata.get("title", ""),
|
||||
"description": doc.metadata.get("description", ""),
|
||||
"timestamp": doc.metadata["timestamp"].isoformat(),
|
||||
"content_type": doc.metadata["content_type"],
|
||||
"word_count": doc.metadata["word_count"],
|
||||
"char_count": doc.metadata["char_count"],
|
||||
"source": doc.source.value,
|
||||
}
|
||||
|
||||
properties = self._build_document_properties(doc)
|
||||
batch_objects.append(
|
||||
DataObject(properties=properties, vector=doc.vector, uuid=str(doc.id))
|
||||
)
|
||||
|
||||
# Insert batch using insert_many
|
||||
response = collection.data.insert_many(batch_objects)
|
||||
response = await self._run_sync(collection.data.insert_many, batch_objects)
|
||||
|
||||
successful_ids: list[str] = []
|
||||
error_indices = set(response.errors.keys()) if response else set()
|
||||
@@ -361,11 +384,7 @@ class WeaviateStorage(BaseStorage):
|
||||
|
||||
return successful_ids
|
||||
|
||||
except WeaviateBatchError as e:
|
||||
raise StorageError(f"Batch operation failed: {e}") from e
|
||||
except WeaviateConnectionError as e:
|
||||
raise StorageError(f"Connection to Weaviate failed: {e}") from e
|
||||
except Exception as e:
|
||||
except (WeaviateBatchError, WeaviateConnectionError, WeaviateQueryError) as e:
|
||||
raise StorageError(f"Failed to store batch: {e}") from e
|
||||
|
||||
@override
|
||||
@@ -385,7 +404,7 @@ class WeaviateStorage(BaseStorage):
|
||||
collection, resolved_name = await self._prepare_collection(
|
||||
collection_name, ensure_exists=False
|
||||
)
|
||||
result = collection.query.fetch_object_by_id(document_id)
|
||||
result = await self._run_sync(collection.query.fetch_object_by_id, document_id)
|
||||
|
||||
if not result:
|
||||
return None
|
||||
@@ -395,13 +414,30 @@ class WeaviateStorage(BaseStorage):
|
||||
result.properties,
|
||||
context="fetch_object_by_id",
|
||||
)
|
||||
# Parse timestamp to datetime for consistent metadata format
|
||||
from datetime import UTC, datetime
|
||||
|
||||
timestamp_raw = props.get("timestamp")
|
||||
timestamp_parsed: datetime
|
||||
try:
|
||||
if isinstance(timestamp_raw, str):
|
||||
timestamp_parsed = datetime.fromisoformat(timestamp_raw)
|
||||
if timestamp_parsed.tzinfo is None:
|
||||
timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
|
||||
elif isinstance(timestamp_raw, datetime):
|
||||
timestamp_parsed = timestamp_raw
|
||||
if timestamp_parsed.tzinfo is None:
|
||||
timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
|
||||
else:
|
||||
timestamp_parsed = datetime.now(UTC)
|
||||
except (ValueError, TypeError):
|
||||
timestamp_parsed = datetime.now(UTC)
|
||||
|
||||
metadata_dict = {
|
||||
"source_url": str(props["source_url"]),
|
||||
"title": str(props.get("title")) if props.get("title") else None,
|
||||
"description": str(props.get("description"))
|
||||
if props.get("description")
|
||||
else None,
|
||||
"timestamp": str(props["timestamp"]),
|
||||
"description": str(props.get("description")) if props.get("description") else None,
|
||||
"timestamp": timestamp_parsed,
|
||||
"content_type": str(props["content_type"]),
|
||||
"word_count": int(str(props["word_count"])),
|
||||
"char_count": int(str(props["char_count"])),
|
||||
@@ -424,11 +460,13 @@ class WeaviateStorage(BaseStorage):
|
||||
except WeaviateConnectionError as e:
|
||||
# Connection issues should be logged and return None
|
||||
import logging
|
||||
|
||||
logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
# Log unexpected errors for debugging
|
||||
import logging
|
||||
|
||||
logging.warning(f"Unexpected error retrieving document {document_id}: {e}")
|
||||
return None
|
||||
|
||||
@@ -437,9 +475,7 @@ class WeaviateStorage(BaseStorage):
|
||||
metadata_dict = {
|
||||
"source_url": str(props["source_url"]),
|
||||
"title": str(props.get("title")) if props.get("title") else None,
|
||||
"description": str(props.get("description"))
|
||||
if props.get("description")
|
||||
else None,
|
||||
"description": str(props.get("description")) if props.get("description") else None,
|
||||
"timestamp": str(props["timestamp"]),
|
||||
"content_type": str(props["content_type"]),
|
||||
"word_count": int(str(props["word_count"])),
|
||||
@@ -462,6 +498,7 @@ class WeaviateStorage(BaseStorage):
|
||||
return max(0.0, 1.0 - distance_value)
|
||||
except (TypeError, ValueError) as e:
|
||||
import logging
|
||||
|
||||
logging.debug(f"Invalid distance value {raw_distance}: {e}")
|
||||
return None
|
||||
|
||||
@@ -506,37 +543,39 @@ class WeaviateStorage(BaseStorage):
|
||||
collection_name: str | None = None,
|
||||
) -> AsyncGenerator[Document, None]:
|
||||
"""
|
||||
Search for documents in Weaviate.
|
||||
Search for documents in Weaviate using hybrid search.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
threshold: Similarity threshold
|
||||
threshold: Similarity threshold (not used in hybrid search)
|
||||
|
||||
Yields:
|
||||
Matching documents
|
||||
"""
|
||||
try:
|
||||
query_vector = await self.vectorizer.vectorize(query)
|
||||
if not self.client:
|
||||
raise StorageError("Weaviate client not initialized")
|
||||
|
||||
collection, resolved_name = await self._prepare_collection(
|
||||
collection_name, ensure_exists=False
|
||||
)
|
||||
|
||||
results = collection.query.near_vector(
|
||||
near_vector=query_vector,
|
||||
limit=limit,
|
||||
distance=1 - threshold,
|
||||
return_metadata=["distance"],
|
||||
)
|
||||
# Try hybrid search first, fall back to BM25 keyword search
|
||||
try:
|
||||
response = await self._run_sync(
|
||||
collection.query.hybrid, query=query, limit=limit, return_metadata=["score"]
|
||||
)
|
||||
except (WeaviateQueryError, StorageError):
|
||||
# Fall back to BM25 if hybrid search is not supported or fails
|
||||
response = await self._run_sync(
|
||||
collection.query.bm25, query=query, limit=limit, return_metadata=["score"]
|
||||
)
|
||||
|
||||
for result in results.objects:
|
||||
yield self._build_search_document(result, resolved_name)
|
||||
for obj in response.objects:
|
||||
yield self._build_document_from_search(obj, resolved_name)
|
||||
|
||||
except WeaviateQueryError as e:
|
||||
raise StorageError(f"Search query failed: {e}") from e
|
||||
except WeaviateConnectionError as e:
|
||||
raise StorageError(f"Connection to Weaviate failed during search: {e}") from e
|
||||
except Exception as e:
|
||||
except (WeaviateQueryError, WeaviateConnectionError) as e:
|
||||
raise StorageError(f"Search failed: {e}") from e
|
||||
|
||||
@override
|
||||
@@ -552,7 +591,7 @@ class WeaviateStorage(BaseStorage):
|
||||
"""
|
||||
try:
|
||||
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
|
||||
collection.data.delete_by_id(document_id)
|
||||
await self._run_sync(collection.data.delete_by_id, document_id)
|
||||
return True
|
||||
except WeaviateQueryError as e:
|
||||
raise StorageError(f"Delete operation failed: {e}") from e
|
||||
@@ -589,10 +628,10 @@ class WeaviateStorage(BaseStorage):
|
||||
if not self.client:
|
||||
raise StorageError("Weaviate client not initialized")
|
||||
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client = self.client
|
||||
return list(client.collections.list_all())
|
||||
|
||||
except Exception as e:
|
||||
except WeaviateConnectionError as e:
|
||||
raise StorageError(f"Failed to list collections: {e}") from e
|
||||
|
||||
async def describe_collections(self) -> list[CollectionSummary]:
|
||||
@@ -601,7 +640,7 @@ class WeaviateStorage(BaseStorage):
|
||||
raise StorageError("Weaviate client not initialized")
|
||||
|
||||
try:
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client = self.client
|
||||
collections: list[CollectionSummary] = []
|
||||
for name in client.collections.list_all():
|
||||
collection_obj = client.collections.get(name)
|
||||
@@ -639,7 +678,7 @@ class WeaviateStorage(BaseStorage):
|
||||
)
|
||||
|
||||
# Query for sample documents
|
||||
response = collection.query.fetch_objects(limit=limit)
|
||||
response = await self._run_sync(collection.query.fetch_objects, limit=limit)
|
||||
|
||||
documents = []
|
||||
for obj in response.objects:
|
||||
@@ -712,9 +751,7 @@ class WeaviateStorage(BaseStorage):
|
||||
return {
|
||||
"source_url": str(props.get("source_url", "")),
|
||||
"title": str(props.get("title", "")) if props.get("title") else None,
|
||||
"description": str(props.get("description", ""))
|
||||
if props.get("description")
|
||||
else None,
|
||||
"description": str(props.get("description", "")) if props.get("description") else None,
|
||||
"timestamp": datetime.fromisoformat(
|
||||
str(props.get("timestamp", datetime.now(UTC).isoformat()))
|
||||
),
|
||||
@@ -737,6 +774,7 @@ class WeaviateStorage(BaseStorage):
|
||||
return float(raw_score)
|
||||
except (TypeError, ValueError) as e:
|
||||
import logging
|
||||
|
||||
logging.debug(f"Invalid score value {raw_score}: {e}")
|
||||
return None
|
||||
|
||||
@@ -780,31 +818,11 @@ class WeaviateStorage(BaseStorage):
|
||||
Returns:
|
||||
List of matching documents
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
raise StorageError("Weaviate client not initialized")
|
||||
|
||||
collection, resolved_name = await self._prepare_collection(
|
||||
collection_name, ensure_exists=False
|
||||
)
|
||||
|
||||
# Try hybrid search first, fall back to BM25 keyword search
|
||||
try:
|
||||
response = collection.query.hybrid(
|
||||
query=query, limit=limit, return_metadata=["score"]
|
||||
)
|
||||
except Exception:
|
||||
response = collection.query.bm25(
|
||||
query=query, limit=limit, return_metadata=["score"]
|
||||
)
|
||||
|
||||
return [
|
||||
self._build_document_from_search(obj, resolved_name)
|
||||
for obj in response.objects
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to search documents: {e}") from e
|
||||
# Delegate to the unified search method
|
||||
results = []
|
||||
async for document in self.search(query, limit=limit, collection_name=collection_name):
|
||||
results.append(document)
|
||||
return results
|
||||
|
||||
async def list_documents(
|
||||
self,
|
||||
@@ -830,8 +848,11 @@ class WeaviateStorage(BaseStorage):
|
||||
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
|
||||
|
||||
# Query documents with pagination
|
||||
response = collection.query.fetch_objects(
|
||||
limit=limit, offset=offset, return_metadata=["creation_time"]
|
||||
response = await self._run_sync(
|
||||
collection.query.fetch_objects,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
return_metadata=["creation_time"],
|
||||
)
|
||||
|
||||
documents: list[DocumentInfo] = []
|
||||
@@ -896,7 +917,9 @@ class WeaviateStorage(BaseStorage):
|
||||
)
|
||||
|
||||
delete_filter = Filter.by_id().contains_any(document_ids)
|
||||
response = collection.data.delete_many(where=delete_filter, verbose=True)
|
||||
response = await self._run_sync(
|
||||
collection.data.delete_many, where=delete_filter, verbose=True
|
||||
)
|
||||
|
||||
if objects := getattr(response, "objects", None):
|
||||
for result_obj in objects:
|
||||
@@ -938,20 +961,22 @@ class WeaviateStorage(BaseStorage):
|
||||
|
||||
# Get documents matching filter
|
||||
if where_filter:
|
||||
response = collection.query.fetch_objects(
|
||||
response = await self._run_sync(
|
||||
collection.query.fetch_objects,
|
||||
filters=where_filter,
|
||||
limit=1000, # Max batch size
|
||||
)
|
||||
else:
|
||||
response = collection.query.fetch_objects(
|
||||
limit=1000 # Max batch size
|
||||
response = await self._run_sync(
|
||||
collection.query.fetch_objects,
|
||||
limit=1000, # Max batch size
|
||||
)
|
||||
|
||||
# Delete matching documents
|
||||
deleted_count = 0
|
||||
for obj in response.objects:
|
||||
try:
|
||||
collection.data.delete_by_id(obj.uuid)
|
||||
await self._run_sync(collection.data.delete_by_id, obj.uuid)
|
||||
deleted_count += 1
|
||||
except Exception:
|
||||
continue
|
||||
@@ -975,7 +1000,7 @@ class WeaviateStorage(BaseStorage):
|
||||
target = self._normalize_collection_name(collection_name)
|
||||
|
||||
# Delete the collection using the client's collections API
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client = self.client
|
||||
client.collections.delete(target)
|
||||
|
||||
return True
|
||||
@@ -997,20 +1022,29 @@ class WeaviateStorage(BaseStorage):
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close client connection."""
|
||||
"""Close client connection and vectorizer HTTP client."""
|
||||
if self.client:
|
||||
try:
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client = self.client
|
||||
client.close()
|
||||
except Exception as e:
|
||||
except (WeaviateConnectionError, AttributeError) as e:
|
||||
import logging
|
||||
|
||||
logging.warning(f"Error closing Weaviate client: {e}")
|
||||
|
||||
# Close vectorizer HTTP client to prevent resource leaks
|
||||
try:
|
||||
await self.vectorizer.close()
|
||||
except (AttributeError, OSError) as e:
|
||||
import logging
|
||||
|
||||
logging.warning(f"Error closing vectorizer client: {e}")
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Clean up client connection as fallback."""
|
||||
if self.client:
|
||||
try:
|
||||
client = cast(weaviate.WeaviateClient, self.client)
|
||||
client = self.client
|
||||
client.close()
|
||||
except Exception:
|
||||
pass # Ignore errors in destructor
|
||||
|
||||
117 ingest_pipeline/utils/async_helpers.py Normal file
@@ -0,0 +1,117 @@
|
||||
"""Async utilities for task management and backpressure control."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Final, TypeVar
|
||||
|
||||
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncTaskManager:
|
||||
"""Manages concurrent tasks with backpressure control."""
|
||||
|
||||
def __init__(self, max_concurrent: int = 10):
|
||||
"""
|
||||
Initialize task manager.
|
||||
|
||||
Args:
|
||||
max_concurrent: Maximum number of concurrent tasks
|
||||
"""
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||
self.max_concurrent = max_concurrent
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self) -> AsyncGenerator[None, None]:
|
||||
"""Acquire a slot for task execution."""
|
||||
async with self.semaphore:
|
||||
yield
|
||||
|
||||
async def run_tasks(
|
||||
self, tasks: Iterable[Awaitable[T]], return_exceptions: bool = False
|
||||
) -> list[T | BaseException]:
|
||||
"""
|
||||
Run multiple tasks with backpressure control.
|
||||
|
||||
Args:
|
||||
tasks: Iterable of awaitable tasks
|
||||
return_exceptions: Whether to return exceptions or raise them
|
||||
|
||||
Returns:
|
||||
List of task results or exceptions
|
||||
"""
|
||||
|
||||
async def _controlled_task(task: Awaitable[T]) -> T:
|
||||
async with self.acquire():
|
||||
return await task
|
||||
|
||||
controlled_tasks = [_controlled_task(task) for task in tasks]
|
||||
|
||||
if return_exceptions:
|
||||
results = await asyncio.gather(*controlled_tasks, return_exceptions=True)
|
||||
return list(results)
|
||||
else:
|
||||
results = await asyncio.gather(*controlled_tasks)
|
||||
return list(results)
|
||||
|
||||
async def map_async(
|
||||
self, func: Callable[[T], Awaitable[T]], items: Iterable[T], return_exceptions: bool = False
|
||||
) -> list[T | BaseException]:
|
||||
"""
|
||||
Apply async function to items with backpressure control.
|
||||
|
||||
Args:
|
||||
func: Async function to apply
|
||||
items: Items to process
|
||||
return_exceptions: Whether to return exceptions or raise them
|
||||
|
||||
Returns:
|
||||
List of processed results or exceptions
|
||||
"""
|
||||
tasks = [func(item) for item in items]
|
||||
return await self.run_tasks(tasks, return_exceptions=return_exceptions)
|
||||
|
||||
|
||||
async def run_with_semaphore(semaphore: asyncio.Semaphore, coro: Awaitable[T]) -> T:
|
||||
"""Run coroutine with semaphore-controlled concurrency."""
|
||||
async with semaphore:
|
||||
return await coro
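
A short usage sketch for the new helpers: cap concurrency with `AsyncTaskManager` so a large ingestion run cannot open an unbounded number of connections. `process_one` is a placeholder coroutine; the import path matches the file added in this commit:

```python
import asyncio

from ingest_pipeline.utils.async_helpers import AsyncTaskManager


async def process_one(item: str) -> str:
    await asyncio.sleep(0)  # pretend to scrape, vectorize, and store
    return item.upper()


async def main() -> None:
    manager = AsyncTaskManager(max_concurrent=5)
    items = [f"doc-{i}" for i in range(20)]
    results = await manager.map_async(process_one, items, return_exceptions=True)
    successes = [r for r in results if not isinstance(r, BaseException)]
    print(f"{len(successes)}/{len(items)} items processed")


if __name__ == "__main__":
    asyncio.run(main())
```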
|
||||
|
||||
|
||||
async def batch_process(
|
||||
items: list[T],
|
||||
processor: Callable[[T], Awaitable[T]],
|
||||
batch_size: int = 50,
|
||||
max_concurrent: int = 5,
|
||||
) -> list[T]:
|
||||
"""
|
||||
Process items in batches with controlled concurrency.
|
||||
|
||||
Args:
|
||||
items: Items to process
|
||||
processor: Async function to process each item
|
||||
batch_size: Number of items per batch
|
||||
max_concurrent: Maximum concurrent tasks per batch
|
||||
|
||||
Returns:
|
||||
List of processed results
|
||||
"""
|
||||
task_manager = AsyncTaskManager(max_concurrent)
|
||||
results: list[T] = []
|
||||
|
||||
for i in range(0, len(items), batch_size):
|
||||
batch = items[i : i + batch_size]
|
||||
LOGGER.debug(
|
||||
"Processing batch %d-%d of %d items", i, min(i + batch_size, len(items)), len(items)
|
||||
)
|
||||
|
||||
batch_results = await task_manager.map_async(processor, batch, return_exceptions=False)
|
||||
# If return_exceptions=False, exceptions would have been raised, so all results are successful
|
||||
# Type checker doesn't know this, so filter out BaseException to narrow the list type
|
||||
successful_results: list[T] = [r for r in batch_results if not isinstance(r, BaseException)]
|
||||
results.extend(successful_results)
|
||||
|
||||
return results
|
||||
@@ -6,12 +6,12 @@ from typing import Final, Protocol, TypedDict, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from ..config import get_settings
|
||||
from ..core.exceptions import IngestionError
|
||||
from ..core.models import Document
|
||||
|
||||
JSON_CONTENT_TYPE: Final[str] = "application/json"
|
||||
AUTHORIZATION_HEADER: Final[str] = "Authorization"
|
||||
from ..config import get_settings
|
||||
|
||||
|
||||
class HttpResponse(Protocol):
|
||||
@@ -24,12 +24,7 @@ class HttpResponse(Protocol):
|
||||
class AsyncHttpClient(Protocol):
|
||||
"""Protocol for async HTTP client."""
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
json: dict[str, object] | None = None
|
||||
) -> HttpResponse: ...
|
||||
async def post(self, url: str, *, json: dict[str, object] | None = None) -> HttpResponse: ...
|
||||
|
||||
async def aclose(self) -> None: ...
|
||||
|
||||
@@ -45,16 +40,19 @@ class AsyncHttpClient(Protocol):
|
||||
|
||||
class LlmResponse(TypedDict):
|
||||
"""Type for LLM API response structure."""
|
||||
|
||||
choices: list[dict[str, object]]
|
||||
|
||||
|
||||
class LlmChoice(TypedDict):
|
||||
"""Type for individual choice in LLM response."""
|
||||
|
||||
message: dict[str, object]
|
||||
|
||||
|
||||
class LlmMessage(TypedDict):
|
||||
"""Type for message in LLM choice."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
@@ -96,7 +94,7 @@ class MetadataTagger:
|
||||
"""
|
||||
settings = get_settings()
|
||||
endpoint_value = llm_endpoint or str(settings.llm_endpoint)
|
||||
self.endpoint = endpoint_value.rstrip('/')
|
||||
self.endpoint = endpoint_value.rstrip("/")
|
||||
self.model = model or settings.metadata_model
|
||||
|
||||
resolved_timeout = timeout if timeout is not None else float(settings.request_timeout)
|
||||
@@ -161,12 +159,16 @@ class MetadataTagger:
|
||||
# Build a proper DocumentMetadata instance with only valid keys
|
||||
new_metadata: CoreDocumentMetadata = {
|
||||
"source_url": str(updated_metadata.get("source_url", "")),
|
||||
"timestamp": (
|
||||
lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC)
|
||||
)(updated_metadata.get("timestamp", datetime.now(UTC))),
|
||||
"timestamp": (lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC))(
|
||||
updated_metadata.get("timestamp", datetime.now(UTC))
|
||||
),
|
||||
"content_type": str(updated_metadata.get("content_type", "text/plain")),
|
||||
"word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)),
|
||||
"char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)),
|
||||
"word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(
|
||||
updated_metadata.get("word_count", 0)
|
||||
),
|
||||
"char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(
|
||||
updated_metadata.get("char_count", 0)
|
||||
),
|
||||
}
|
||||
|
||||
# Add optional fields if they exist
|
||||
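The inline lambdas above coerce loosely typed metadata values in place. A hedged alternative sketch (not part of this diff) pulls the coercions into small named helpers; note that the string branch below is deliberately non-raising, which is slightly more defensive than the lambda, which would raise on a non-numeric string.

```python
from datetime import UTC, datetime

def _as_int(value: object, default: int = 0) -> int:
    """Coerce int/str values to int, falling back to a default."""
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return default
    return default

def _as_timestamp(value: object) -> datetime:
    """Keep datetimes as-is, otherwise use the current UTC time."""
    return value if isinstance(value, datetime) else datetime.now(UTC)
```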
|
||||
@@ -1,20 +1,79 @@
|
||||
"""Vectorizer utility for generating embeddings."""
|
||||
|
||||
import asyncio
|
||||
from types import TracebackType
|
||||
from typing import Final, Self, cast
|
||||
from typing import Final, NotRequired, Self, TypedDict
|
||||
|
||||
import httpx
|
||||
|
||||
from typings import EmbeddingResponse
|
||||
|
||||
from ..config import get_settings
|
||||
from ..core.exceptions import VectorizationError
|
||||
from ..core.models import StorageConfig, VectorConfig
|
||||
from ..config import get_settings
|
||||
|
||||
JSON_CONTENT_TYPE: Final[str] = "application/json"
|
||||
AUTHORIZATION_HEADER: Final[str] = "Authorization"
|
||||
|
||||
|
||||
class EmbeddingData(TypedDict):
|
||||
"""Structure for embedding data from providers."""
|
||||
|
||||
embedding: list[float]
|
||||
index: NotRequired[int]
|
||||
object: NotRequired[str]
|
||||
|
||||
|
||||
class EmbeddingResponse(TypedDict):
|
||||
"""Embedding response format for multiple providers."""
|
||||
|
||||
data: list[EmbeddingData]
|
||||
model: NotRequired[str]
|
||||
object: NotRequired[str]
|
||||
usage: NotRequired[dict[str, int]]
|
||||
# Alternative formats
|
||||
embedding: NotRequired[list[float]]
|
||||
vector: NotRequired[list[float]]
|
||||
embeddings: NotRequired[list[list[float]]]
|
||||
|
||||
|
||||
def _extract_embedding_from_response(response_data: dict[str, object]) -> list[float]:
|
||||
"""Extract embedding vector from provider response."""
|
||||
# OpenAI/Ollama format: {"data": [{"embedding": [...]}]}
|
||||
if "data" in response_data:
|
||||
data_list = response_data["data"]
|
||||
if isinstance(data_list, list) and data_list:
|
||||
first_item = data_list[0]
|
||||
if isinstance(first_item, dict) and "embedding" in first_item:
|
||||
embedding = first_item["embedding"]
|
||||
if isinstance(embedding, list) and all(
|
||||
isinstance(x, (int, float)) for x in embedding
|
||||
):
|
||||
return [float(x) for x in embedding]
|
||||
|
||||
# Direct embedding format: {"embedding": [...]}
|
||||
if "embedding" in response_data:
|
||||
embedding = response_data["embedding"]
|
||||
if isinstance(embedding, list) and all(isinstance(x, (int, float)) for x in embedding):
|
||||
return [float(x) for x in embedding]
|
||||
|
||||
# Vector format: {"vector": [...]}
|
||||
if "vector" in response_data:
|
||||
vector = response_data["vector"]
|
||||
if isinstance(vector, list) and all(isinstance(x, (int, float)) for x in vector):
|
||||
return [float(x) for x in vector]
|
||||
|
||||
# Embeddings array format: {"embeddings": [[...]]}
|
||||
if "embeddings" in response_data:
|
||||
embeddings = response_data["embeddings"]
|
||||
if isinstance(embeddings, list) and embeddings:
|
||||
first_embedding = embeddings[0]
|
||||
if isinstance(first_embedding, list) and all(
|
||||
isinstance(x, (int, float)) for x in first_embedding
|
||||
):
|
||||
return [float(x) for x in first_embedding]
|
||||
|
||||
raise VectorizationError("Unrecognized embedding response format")
|
||||
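A small sketch exercising the four response shapes the helper above recognises; the payloads are illustrative only, and the helper is assumed to be in scope (same module).

```python
samples: list[dict[str, object]] = [
    {"data": [{"embedding": [0.1, 0.2]}]},  # OpenAI / Ollama style
    {"embedding": [0.1, 0.2]},              # direct embedding key
    {"vector": [0.1, 0.2]},                 # vector key
    {"embeddings": [[0.1, 0.2]]},           # nested embeddings array
]

for payload in samples:
    assert _extract_embedding_from_response(payload) == [0.1, 0.2]

# Any other shape raises VectorizationError("Unrecognized embedding response format").
```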
|
||||
|
||||
class Vectorizer:
|
||||
"""Handles text vectorization using LLM endpoints."""
|
||||
|
||||
@@ -72,21 +131,34 @@ class Vectorizer:
|
||||
|
||||
async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for multiple texts.
|
||||
Generate embeddings for multiple texts in parallel.
|
||||
|
||||
Args:
|
||||
texts: List of texts to vectorize
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
|
||||
Raises:
|
||||
VectorizationError: If any vectorization fails
|
||||
"""
|
||||
vectors: list[list[float]] = []
|
||||
|
||||
for text in texts:
|
||||
vector = await self.vectorize(text)
|
||||
vectors.append(vector)
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
return vectors
|
||||
# Use semaphore to limit concurrent requests and prevent overwhelming the endpoint
|
||||
semaphore = asyncio.Semaphore(20)
|
||||
|
||||
async def vectorize_with_semaphore(text: str) -> list[float]:
|
||||
async with semaphore:
|
||||
return await self.vectorize(text)
|
||||
|
||||
try:
|
||||
# Execute all vectorization requests concurrently
|
||||
vectors = await asyncio.gather(*[vectorize_with_semaphore(text) for text in texts])
|
||||
return list(vectors)
|
||||
except Exception as e:
|
||||
raise VectorizationError(f"Batch vectorization failed: {e}") from e
|
||||
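The gather-under-a-semaphore idiom used above generalises beyond vectorization; a minimal standalone sketch follows (the limit of 20 mirrors this diff and is otherwise an assumption about the embedding endpoint's capacity).

```python
import asyncio
from collections.abc import Awaitable
from typing import TypeVar

T = TypeVar("T")

async def bounded_gather(awaitables: list[Awaitable[T]], limit: int = 20) -> list[T]:
    """Await everything concurrently, but never more than `limit` at a time."""
    semaphore = asyncio.Semaphore(limit)

    async def _run(aw: Awaitable[T]) -> T:
        async with semaphore:
            return await aw

    return list(await asyncio.gather(*(_run(aw) for aw in awaitables)))
```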
|
||||
async def _ollama_embed(self, text: str) -> list[float]:
|
||||
"""
|
||||
@@ -112,23 +184,12 @@ class Vectorizer:
|
||||
_ = response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
# Response is expected to be dict[str, object] from our type stub
|
||||
if not isinstance(response_json, dict):
|
||||
raise VectorizationError("Invalid JSON response format")
|
||||
|
||||
response_data = cast(EmbeddingResponse, cast(object, response_json))
|
||||
# Extract embedding using type-safe helper
|
||||
embedding = _extract_embedding_from_response(response_json)
|
||||
|
||||
# Parse OpenAI-compatible response format
|
||||
embeddings_list = response_data.get("data", [])
|
||||
if not embeddings_list:
|
||||
raise VectorizationError("No embeddings returned")
|
||||
|
||||
first_embedding = embeddings_list[0]
|
||||
embedding_raw = first_embedding.get("embedding")
|
||||
if not embedding_raw:
|
||||
raise VectorizationError("Invalid embedding format")
|
||||
|
||||
# Convert to float list and validate
|
||||
embedding: list[float] = []
|
||||
embedding.extend(float(item) for item in embedding_raw)
|
||||
# Ensure correct dimension
|
||||
if len(embedding) != self.dimension:
|
||||
raise VectorizationError(
|
||||
@@ -157,22 +218,12 @@ class Vectorizer:
|
||||
_ = response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
# Response is expected to be dict[str, object] from our type stub
|
||||
if not isinstance(response_json, dict):
|
||||
raise VectorizationError("Invalid JSON response format")
|
||||
|
||||
response_data = cast(EmbeddingResponse, cast(object, response_json))
|
||||
# Extract embedding using type-safe helper
|
||||
embedding = _extract_embedding_from_response(response_json)
|
||||
|
||||
embeddings_list = response_data.get("data", [])
|
||||
if not embeddings_list:
|
||||
raise VectorizationError("No embeddings returned")
|
||||
|
||||
first_embedding = embeddings_list[0]
|
||||
embedding_raw = first_embedding.get("embedding")
|
||||
if not embedding_raw:
|
||||
raise VectorizationError("Invalid embedding format")
|
||||
|
||||
# Convert to float list and validate
|
||||
embedding: list[float] = []
|
||||
embedding.extend(float(item) for item in embedding_raw)
|
||||
# Ensure correct dimension
|
||||
if len(embedding) != self.dimension:
|
||||
raise VectorizationError(
|
||||
@@ -185,6 +236,14 @@ class Vectorizer:
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client connection."""
|
||||
try:
|
||||
await self.client.aclose()
|
||||
except Exception:
|
||||
# Already closed or connection lost
|
||||
pass
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
@@ -192,4 +251,4 @@ class Vectorizer:
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Async context manager exit."""
|
||||
await self.client.aclose()
|
||||
await self.close()
|
||||
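Because __aexit__ now delegates to close(), callers can lean on the context manager to release the underlying httpx client even when vectorization fails. A hedged usage sketch (constructor arguments are omitted because they are not shown in this hunk):

```python
# Assumption: Vectorizer is importable and its defaults are acceptable here.

async def embed_texts(texts: list[str]) -> list[list[float]]:
    async with Vectorizer() as vectorizer:     # close() is awaited on exit
        return await vectorizer.vectorize_batch(texts)
```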
|
||||
27 notebooks/welcome.py (new file)
@@ -0,0 +1,27 @@
|
||||
import marimo
|
||||
|
||||
__generated_with = "0.16.0"
|
||||
app = marimo.App()
|
||||
|
||||
|
||||
@app.cell
|
||||
def __():
|
||||
import marimo as mo
|
||||
|
||||
return (mo,)
|
||||
|
||||
|
||||
@app.cell
|
||||
def __(mo):
|
||||
mo.md("# Welcome to Marimo!")
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def __(mo):
|
||||
mo.md("This is your interactive notebook environment.")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
7628 repomix-output.xml (diff suppressed because it is too large)
@@ -59,6 +59,7 @@ class StubbedResponse:
|
||||
def raise_for_status(self) -> None:
|
||||
if self.status_code < 400:
|
||||
return
|
||||
|
||||
# Create a minimal exception for testing - we don't need full httpx objects for tests
|
||||
class TestHTTPError(Exception):
|
||||
request: object | None
|
||||
@@ -232,7 +233,7 @@ class AsyncClientStub:
|
||||
# Convert params to the format expected by other methods
|
||||
converted_params: dict[str, object] | None = None
|
||||
if params:
|
||||
converted_params = {k: v for k, v in params.items()}
|
||||
converted_params = dict(params)
|
||||
|
||||
method_upper = method.upper()
|
||||
if method_upper == "GET":
|
||||
@@ -441,7 +442,6 @@ def r2r_spec() -> OpenAPISpec:
|
||||
return OpenAPISpec.from_file(PROJECT_ROOT / "r2r.json")
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def firecrawl_spec() -> OpenAPISpec:
|
||||
"""Load Firecrawl OpenAPI specification."""
|
||||
@@ -539,7 +539,9 @@ def firecrawl_client_stub(
|
||||
links = [SimpleNamespace(url=cast(str, item.get("url", ""))) for item in links_data]
|
||||
return SimpleNamespace(success=payload.get("success", True), links=links)
|
||||
|
||||
async def scrape(self, url: str, formats: list[str] | None = None, **_: object) -> SimpleNamespace:
|
||||
async def scrape(
|
||||
self, url: str, formats: list[str] | None = None, **_: object
|
||||
) -> SimpleNamespace:
|
||||
payload = cast(MockResponseData, self._service.scrape_response(url, formats))
|
||||
data = cast(MockResponseData, payload.get("data", {}))
|
||||
metadata_payload = cast(MockResponseData, data.get("metadata", {}))
|
||||
|
||||
@@ -158,7 +158,9 @@ class OpenAPISpec:
|
||||
return []
|
||||
return [segment for segment in path.strip("/").split("/") if segment]
|
||||
|
||||
def find_operation(self, method: str, path: str) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
|
||||
def find_operation(
|
||||
self, method: str, path: str
|
||||
) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
|
||||
method = method.lower()
|
||||
normalized = "/" if path in {"", "/"} else "/" + path.strip("/")
|
||||
actual_segments = self._split_path(normalized)
|
||||
@@ -188,9 +190,7 @@ class OpenAPISpec:
|
||||
status: str | None = None,
|
||||
) -> tuple[int, Any]:
|
||||
responses = operation.get("responses", {})
|
||||
target_status = status or (
|
||||
"200" if "200" in responses else next(iter(responses), "200")
|
||||
)
|
||||
target_status = status or ("200" if "200" in responses else next(iter(responses), "200"))
|
||||
status_code = 200
|
||||
try:
|
||||
status_code = int(target_status)
|
||||
@@ -222,7 +222,7 @@ class OpenAPISpec:
|
||||
base = self.generate({"$ref": ref})
|
||||
if overrides is None or not isinstance(base, Mapping):
|
||||
return copy.deepcopy(base)
|
||||
merged = copy.deepcopy(base)
|
||||
merged: dict[str, Any] = dict(base)
|
||||
for key, value in overrides.items():
|
||||
if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
|
||||
merged[key] = self.generate_from_mapping(
|
||||
@@ -237,8 +237,8 @@ class OpenAPISpec:
|
||||
self,
|
||||
base: Mapping[str, Any],
|
||||
overrides: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
result = copy.deepcopy(base)
|
||||
) -> dict[str, Any]:
|
||||
result: dict[str, Any] = dict(base)
|
||||
for key, value in overrides.items():
|
||||
if isinstance(value, Mapping) and isinstance(result.get(key), Mapping):
|
||||
result[key] = self.generate_from_mapping(result[key], value)
|
||||
@@ -380,18 +380,20 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
|
||||
def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
|
||||
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
|
||||
payload.update({
|
||||
"id": knowledge["id"],
|
||||
"user_id": knowledge["user_id"],
|
||||
"name": knowledge["name"],
|
||||
"description": knowledge.get("description", ""),
|
||||
"data": copy.deepcopy(knowledge.get("data", {})),
|
||||
"meta": copy.deepcopy(knowledge.get("meta", {})),
|
||||
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
|
||||
"created_at": knowledge.get("created_at", self._timestamp()),
|
||||
"updated_at": knowledge.get("updated_at", self._timestamp()),
|
||||
"files": copy.deepcopy(knowledge.get("files", [])),
|
||||
})
|
||||
payload.update(
|
||||
{
|
||||
"id": knowledge["id"],
|
||||
"user_id": knowledge["user_id"],
|
||||
"name": knowledge["name"],
|
||||
"description": knowledge.get("description", ""),
|
||||
"data": copy.deepcopy(knowledge.get("data", {})),
|
||||
"meta": copy.deepcopy(knowledge.get("meta", {})),
|
||||
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
|
||||
"created_at": knowledge.get("created_at", self._timestamp()),
|
||||
"updated_at": knowledge.get("updated_at", self._timestamp()),
|
||||
"files": copy.deepcopy(knowledge.get("files", [])),
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
|
||||
@@ -435,7 +437,9 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
# Delegate to parent
|
||||
return super().handle(method=method, path=path, json=json, params=params, files=files)
|
||||
|
||||
def _handle_knowledge_endpoints(self, method: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]:
|
||||
def _handle_knowledge_endpoints(
|
||||
self, method: str, segments: list[str], json: Mapping[str, Any] | None
|
||||
) -> tuple[int, Any]:
|
||||
"""Handle knowledge-related API endpoints."""
|
||||
method_upper = method.upper()
|
||||
segment_count = len(segments)
|
||||
@@ -471,7 +475,9 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
)
|
||||
return 200, entry
|
||||
|
||||
def _handle_specific_knowledge(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]:
|
||||
def _handle_specific_knowledge(
|
||||
self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None
|
||||
) -> tuple[int, Any]:
|
||||
"""Handle endpoints for specific knowledge entries."""
|
||||
knowledge_id = segments[3]
|
||||
knowledge = self._knowledge.get(knowledge_id)
|
||||
@@ -486,7 +492,9 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
|
||||
# File operations
|
||||
if segment_count == 6 and segments[4] == "file":
|
||||
return self._handle_knowledge_file_operations(method_upper, segments[5], knowledge, json or {})
|
||||
return self._handle_knowledge_file_operations(
|
||||
method_upper, segments[5], knowledge, json or {}
|
||||
)
|
||||
|
||||
# Delete knowledge
|
||||
if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete":
|
||||
@@ -495,7 +503,13 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
|
||||
return 404, {"detail": "Knowledge operation not found"}
|
||||
|
||||
def _handle_knowledge_file_operations(self, method_upper: str, operation: str, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
|
||||
def _handle_knowledge_file_operations(
|
||||
self,
|
||||
method_upper: str,
|
||||
operation: str,
|
||||
knowledge: dict[str, Any],
|
||||
payload: Mapping[str, Any],
|
||||
) -> tuple[int, Any]:
|
||||
"""Handle file operations on knowledge entries."""
|
||||
if method_upper != "POST":
|
||||
return 405, {"detail": "Method not allowed"}
|
||||
@@ -507,7 +521,9 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
|
||||
return 404, {"detail": "File operation not found"}
|
||||
|
||||
def _add_file_to_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
|
||||
def _add_file_to_knowledge(
|
||||
self, knowledge: dict[str, Any], payload: Mapping[str, Any]
|
||||
) -> tuple[int, Any]:
|
||||
"""Add a file to a knowledge entry."""
|
||||
file_id = str(payload.get("file_id", ""))
|
||||
if not file_id:
|
||||
@@ -524,15 +540,21 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
|
||||
return 200, self._knowledge_files_response(knowledge)
|
||||
|
||||
def _remove_file_from_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
|
||||
def _remove_file_from_knowledge(
|
||||
self, knowledge: dict[str, Any], payload: Mapping[str, Any]
|
||||
) -> tuple[int, Any]:
|
||||
"""Remove a file from a knowledge entry."""
|
||||
file_id = str(payload.get("file_id", ""))
|
||||
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
|
||||
knowledge["files"] = [
|
||||
item for item in knowledge.get("files", []) if item.get("id") != file_id
|
||||
]
|
||||
knowledge["updated_at"] = self._timestamp()
|
||||
|
||||
return 200, self._knowledge_files_response(knowledge)
|
||||
|
||||
def _handle_file_endpoints(self, method: str, segments: list[str], files: Any) -> tuple[int, Any]:
|
||||
def _handle_file_endpoints(
|
||||
self, method: str, segments: list[str], files: Any
|
||||
) -> tuple[int, Any]:
|
||||
"""Handle file-related API endpoints."""
|
||||
method_upper = method.upper()
|
||||
|
||||
@@ -561,7 +583,9 @@ class OpenWebUIMockService(OpenAPIMockService):
|
||||
"""Delete a file and remove from all knowledge entries."""
|
||||
self._files.pop(file_id, None)
|
||||
for knowledge in self._knowledge.values():
|
||||
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
|
||||
knowledge["files"] = [
|
||||
item for item in knowledge.get("files", []) if item.get("id") != file_id
|
||||
]
|
||||
|
||||
return 200, {"deleted": True}
|
||||
|
||||
@@ -620,10 +644,14 @@ class R2RMockService(OpenAPIMockService):
|
||||
None,
|
||||
)
|
||||
|
||||
def _set_collection_document_ids(self, collection_id: str, document_id: str, *, add: bool) -> None:
|
||||
def _set_collection_document_ids(
|
||||
self, collection_id: str, document_id: str, *, add: bool
|
||||
) -> None:
|
||||
collection = self._collections.get(collection_id)
|
||||
if collection is None:
|
||||
collection = self.create_collection(name=f"Collection {collection_id}", collection_id=collection_id)
|
||||
collection = self.create_collection(
|
||||
name=f"Collection {collection_id}", collection_id=collection_id
|
||||
)
|
||||
documents = collection.setdefault("documents", [])
|
||||
if add:
|
||||
if document_id not in documents:
|
||||
@@ -677,7 +705,9 @@ class R2RMockService(OpenAPIMockService):
|
||||
self._set_collection_document_ids(collection_id, document_id, add=False)
|
||||
return True
|
||||
|
||||
def append_document_metadata(self, document_id: str, metadata_list: list[dict[str, Any]]) -> dict[str, Any] | None:
|
||||
def append_document_metadata(
|
||||
self, document_id: str, metadata_list: list[dict[str, Any]]
|
||||
) -> dict[str, Any] | None:
|
||||
entry = self._documents.get(document_id)
|
||||
if entry is None:
|
||||
return None
|
||||
@@ -712,7 +742,9 @@ class R2RMockService(OpenAPIMockService):
|
||||
# Delegate to parent
|
||||
return super().handle(method=method, path=path, json=json, params=params, files=files)
|
||||
|
||||
def _handle_collections_endpoint(self, method_upper: str, json: Mapping[str, Any] | None) -> tuple[int, Any]:
|
||||
def _handle_collections_endpoint(
|
||||
self, method_upper: str, json: Mapping[str, Any] | None
|
||||
) -> tuple[int, Any]:
|
||||
"""Handle collection-related endpoints."""
|
||||
if method_upper == "GET":
|
||||
return self._list_collections()
|
||||
@@ -732,10 +764,12 @@ class R2RMockService(OpenAPIMockService):
|
||||
clone["document_count"] = len(clone.get("documents", []))
|
||||
results.append(clone)
|
||||
|
||||
payload.update({
|
||||
"results": results,
|
||||
"total_entries": len(results),
|
||||
})
|
||||
payload.update(
|
||||
{
|
||||
"results": results,
|
||||
"total_entries": len(results),
|
||||
}
|
||||
)
|
||||
return 200, payload
|
||||
|
||||
def _create_collection_endpoint(self, body: Mapping[str, Any]) -> tuple[int, Any]:
|
||||
@@ -748,7 +782,9 @@ class R2RMockService(OpenAPIMockService):
|
||||
payload.update({"results": entry})
|
||||
return 200, payload
|
||||
|
||||
def _handle_documents_endpoint(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any) -> tuple[int, Any]:
|
||||
def _handle_documents_endpoint(
|
||||
self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any
|
||||
) -> tuple[int, Any]:
|
||||
"""Handle document-related endpoints."""
|
||||
if method_upper == "POST" and len(segments) == 2:
|
||||
return self._create_document_endpoint(files)
|
||||
@@ -774,11 +810,13 @@ class R2RMockService(OpenAPIMockService):
|
||||
"#/components/schemas/R2RResults_IngestionResponse_"
|
||||
)
|
||||
ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse")
|
||||
ingestion.update({
|
||||
"message": ingestion.get("message") or "Ingestion task queued successfully.",
|
||||
"document_id": document["id"],
|
||||
"task_id": ingestion.get("task_id") or str(uuid4()),
|
||||
})
|
||||
ingestion.update(
|
||||
{
|
||||
"message": ingestion.get("message") or "Ingestion task queued successfully.",
|
||||
"document_id": document["id"],
|
||||
"task_id": ingestion.get("task_id") or str(uuid4()),
|
||||
}
|
||||
)
|
||||
response_payload["results"] = ingestion
|
||||
return 202, response_payload
|
||||
|
||||
@@ -835,7 +873,11 @@ class R2RMockService(OpenAPIMockService):
|
||||
|
||||
def _extract_content_from_files(self, raw_text_entry: Any) -> str:
|
||||
"""Extract content from files entry."""
|
||||
if isinstance(raw_text_entry, tuple) and len(raw_text_entry) >= 2 and raw_text_entry[1] is not None:
|
||||
if (
|
||||
isinstance(raw_text_entry, tuple)
|
||||
and len(raw_text_entry) >= 2
|
||||
and raw_text_entry[1] is not None
|
||||
):
|
||||
return str(raw_text_entry[1])
|
||||
return ""
|
||||
|
||||
@@ -886,7 +928,7 @@ class R2RMockService(OpenAPIMockService):
|
||||
|
||||
results = response_payload.get("results")
|
||||
if isinstance(results, Mapping):
|
||||
results = copy.deepcopy(results)
|
||||
results = dict(results)
|
||||
results.update({"success": success})
|
||||
else:
|
||||
results = {"success": success}
|
||||
@@ -894,7 +936,9 @@ class R2RMockService(OpenAPIMockService):
|
||||
response_payload["results"] = results
|
||||
return (200 if success else 404), response_payload
|
||||
|
||||
def _update_document_metadata(self, doc_id: str, json: Mapping[str, Any] | None) -> tuple[int, Any]:
|
||||
def _update_document_metadata(
|
||||
self, doc_id: str, json: Mapping[str, Any] | None
|
||||
) -> tuple[int, Any]:
|
||||
"""Update document metadata."""
|
||||
metadata_list = [dict(item) for item in json] if isinstance(json, list) else []
|
||||
document = self.append_document_metadata(doc_id, metadata_list)
|
||||
@@ -919,7 +963,15 @@ class FirecrawlMockService(OpenAPIMockService):
|
||||
def register_map_result(self, origin: str, links: list[str]) -> None:
|
||||
self._maps[origin] = list(links)
|
||||
|
||||
def register_page(self, url: str, *, markdown: str | None = None, html: str | None = None, metadata: Mapping[str, Any] | None = None, links: list[str] | None = None) -> None:
|
||||
def register_page(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
markdown: str | None = None,
|
||||
html: str | None = None,
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
links: list[str] | None = None,
|
||||
) -> None:
|
||||
self._pages[url] = {
|
||||
"markdown": markdown,
|
||||
"html": html,
|
||||
|
||||
15 tests/unit/automations/conftest.py (new file)
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from ingest_pipeline.automations import AUTOMATION_TEMPLATES, get_automation_yaml_templates
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def automation_templates_copy() -> dict[str, str]:
|
||||
return get_automation_yaml_templates()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def automation_templates_reference() -> dict[str, str]:
|
||||
return AUTOMATION_TEMPLATES
|
||||
28 tests/unit/automations/test_init.py (new file)
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("template_key", "expected_snippet"),
|
||||
[
|
||||
("cancel_long_running", "Cancel Long Running Ingestion Flows"),
|
||||
("retry_failed", "Retry Failed Ingestion Flows"),
|
||||
("resource_monitoring", "Manage Work Pool Based on Resources"),
|
||||
],
|
||||
)
|
||||
def test_automation_templates_include_expected_entries(
|
||||
automation_templates_copy,
|
||||
template_key: str,
|
||||
expected_snippet: str,
|
||||
) -> None:
|
||||
assert expected_snippet in automation_templates_copy[template_key]
|
||||
|
||||
|
||||
def test_get_automation_yaml_templates_returns_copy(
|
||||
automation_templates_copy,
|
||||
automation_templates_reference,
|
||||
) -> None:
|
||||
automation_templates_copy["cancel_long_running"] = "overridden"
|
||||
|
||||
assert "overridden" not in automation_templates_reference["cancel_long_running"]
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -22,7 +23,9 @@ from ingest_pipeline.core.models import (
|
||||
),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_ingestion_collection_resolution(monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str) -> None:
|
||||
async def test_run_ingestion_collection_resolution(
|
||||
monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str
|
||||
) -> None:
|
||||
recorded: dict[str, object] = {}
|
||||
|
||||
async def fake_flow(**kwargs: object) -> IngestionResult:
|
||||
@@ -128,7 +131,7 @@ async def test_run_search_collects_results(monkeypatch: pytest.MonkeyPatch) -> N
|
||||
threshold: float = 0.7,
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
) -> object:
|
||||
) -> AsyncGenerator[SimpleNamespace, None]:
|
||||
yield SimpleNamespace(title="Title", content="Body text", score=0.91)
|
||||
|
||||
dummy_settings = SimpleNamespace(
|
||||
|
||||
@@ -6,9 +6,11 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from ingest_pipeline.core.models import (
|
||||
Document,
|
||||
IngestionJob,
|
||||
IngestionSource,
|
||||
StorageBackend,
|
||||
StorageConfig,
|
||||
)
|
||||
from ingest_pipeline.flows import ingestion
|
||||
from ingest_pipeline.flows.ingestion import (
|
||||
@@ -18,6 +20,7 @@ from ingest_pipeline.flows.ingestion import (
|
||||
filter_existing_documents_task,
|
||||
ingest_documents_task,
|
||||
)
|
||||
from ingest_pipeline.storage.base import BaseStorage
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -52,15 +55,34 @@ async def test_filter_existing_documents_task_filters_known_urls(
|
||||
new_url = "https://new.example.com"
|
||||
existing_id = str(FirecrawlIngestor.compute_document_id(existing_url))
|
||||
|
||||
class StubStorage:
|
||||
display_name = "stub-storage"
|
||||
class StubStorage(BaseStorage):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(StorageConfig(backend=StorageBackend.WEAVIATE, endpoint="http://test.local"))
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return "stub-storage"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
|
||||
return "stub-id"
|
||||
|
||||
async def store_batch(
|
||||
self, documents: list[Document], *, collection_name: str | None = None
|
||||
) -> list[str]:
|
||||
return ["stub-id"] * len(documents)
|
||||
|
||||
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
|
||||
return True
|
||||
|
||||
async def check_exists(
|
||||
self,
|
||||
document_id: str,
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
stale_after_days: int,
|
||||
stale_after_days: int = 30,
|
||||
) -> bool:
|
||||
return document_id == existing_id
|
||||
|
||||
@@ -102,7 +124,9 @@ async def test_ingest_documents_task_invokes_helpers(
|
||||
) -> object:
|
||||
return storage_sentinel
|
||||
|
||||
async def fake_create_ingestor(job_arg: IngestionJob, config_block_name: str | None = None) -> object:
|
||||
async def fake_create_ingestor(
|
||||
job_arg: IngestionJob, config_block_name: str | None = None
|
||||
) -> object:
|
||||
return ingestor_sentinel
|
||||
|
||||
monkeypatch.setattr(ingestion, "_create_ingestor", fake_create_ingestor)
|
||||
|
||||
@@ -4,6 +4,8 @@ from datetime import timedelta
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from prefect.deployments.runner import RunnerDeployment
|
||||
from prefect.schedules import Cron, Interval
|
||||
|
||||
from ingest_pipeline.flows import scheduler
|
||||
|
||||
@@ -17,7 +19,6 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
captured |= kwargs
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
|
||||
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
|
||||
|
||||
deployment = scheduler.create_scheduled_deployment(
|
||||
@@ -28,8 +29,14 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
cron_expression="0 * * * *",
|
||||
)
|
||||
|
||||
assert captured["schedule"].cron == "0 * * * *"
|
||||
assert captured["parameters"]["source_type"] == "web"
|
||||
schedule = captured["schedule"]
|
||||
# Check that it's a cron schedule by verifying it has a cron attribute
|
||||
assert hasattr(schedule, "cron")
|
||||
assert schedule.cron == "0 * * * *"
|
||||
|
||||
parameters = captured["parameters"]
|
||||
assert isinstance(parameters, dict)
|
||||
assert parameters["source_type"] == "web"
|
||||
assert deployment.tags == ["web", "weaviate"]
|
||||
|
||||
|
||||
@@ -42,7 +49,6 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
|
||||
captured |= kwargs
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
|
||||
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
|
||||
|
||||
deployment = scheduler.create_scheduled_deployment(
|
||||
@@ -55,7 +61,11 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
|
||||
tags=["custom"],
|
||||
)
|
||||
|
||||
assert captured["schedule"].interval == timedelta(minutes=15)
|
||||
schedule = captured["schedule"]
|
||||
# Check that it's an interval schedule by verifying it has an interval attribute
|
||||
assert hasattr(schedule, "interval")
|
||||
assert schedule.interval == timedelta(minutes=15)
|
||||
|
||||
assert captured["tags"] == ["custom"]
|
||||
assert deployment.parameters["storage_backend"] == "open_webui"
|
||||
|
||||
@@ -63,12 +73,13 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
|
||||
def test_serve_deployments_invokes_prefect(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
called: dict[str, object] = {}
|
||||
|
||||
def fake_serve(*deployments: object, limit: int) -> None:
|
||||
def fake_serve(*deployments: RunnerDeployment, limit: int) -> None:
|
||||
called["deployments"] = deployments
|
||||
called["limit"] = limit
|
||||
|
||||
monkeypatch.setattr(scheduler, "prefect_serve", fake_serve)
|
||||
|
||||
# Create a mock deployment using SimpleNamespace to avoid Prefect complexity
|
||||
deployment = SimpleNamespace(name="only")
|
||||
scheduler.serve_deployments([deployment])
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from ingest_pipeline.ingestors.repomix import RepomixIngestor
|
||||
@pytest.mark.parametrize(
|
||||
("content", "expected_keys"),
|
||||
(
|
||||
("## File: src/app.py\nprint(\"hi\")", ["src/app.py"]),
|
||||
('## File: src/app.py\nprint("hi")', ["src/app.py"]),
|
||||
("plain content without markers", ["repository"]),
|
||||
),
|
||||
)
|
||||
@@ -44,7 +44,7 @@ def test_chunk_content_respects_max_size(
|
||||
("file_path", "content", "expected"),
|
||||
(
|
||||
("src/app.py", "def feature():\n return True", "python"),
|
||||
("scripts/run", "#!/usr/bin/env python\nprint(\"ok\")", "python"),
|
||||
("scripts/run", '#!/usr/bin/env python\nprint("ok")', "python"),
|
||||
("documentation.md", "# Title", "markdown"),
|
||||
("unknown.ext", "text", None),
|
||||
),
|
||||
@@ -81,5 +81,8 @@ def test_create_document_enriches_metadata() -> None:
|
||||
assert document.metadata["repository_name"] == "demo"
|
||||
assert document.metadata["branch_name"] == "main"
|
||||
assert document.metadata["commit_hash"] == "deadbeef"
|
||||
assert document.metadata["title"].endswith("(chunk 1)")
|
||||
|
||||
title = document.metadata["title"]
|
||||
assert title is not None
|
||||
assert title.endswith("(chunk 1)")
|
||||
assert document.collection == job.storage_backend.value
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.