This commit is contained in:
2025-09-21 01:38:47 +00:00
parent 43c69573f9
commit db854b8ec8
125 changed files with 6840 additions and 4319 deletions

View File

@@ -8,7 +8,8 @@
"Bash(cat:*)",
"mcp__firecrawl__firecrawl_search",
"mcp__firecrawl__firecrawl_scrape",
"Bash(python:*)"
"Bash(python:*)",
"mcp__coder__coder_report_task"
],
"deny": [],
"ask": []

View File

@@ -1,263 +1,255 @@
## Codebase Analysis Report: RAG Manager Ingestion Pipeline
TL;DR / Highest-impact fixes (do these first)
**Status:** Validated against current codebase implementation
**Target:** Enhanced implementation guidance for efficient agent execution
Event-loop blocking in the TUI (user-visible stutter):
time.sleep(2) is called at the end of an ingestion run on the IngestionScreen. Even though this runs in a thread worker, it blocks that worker and delays UI transition; prefer a scheduled UI callback instead.
This analysis has been validated against the actual codebase structure and provides implementation-specific details for executing recommended improvements. The codebase demonstrates solid architecture with clear separation of concerns between ingestion flows, storage adapters, and TUI components.
### Architecture Overview
- **Storage Backends**: Weaviate, OpenWebUI, R2R with unified `BaseStorage` interface
- **TUI Framework**: Textual-based with reactive components and async worker patterns
- **Orchestration**: Prefect flows with retry logic and progress callbacks
- **Configuration**: Pydantic-based settings with environment variable support
Blocking Weaviate client calls inside async methods:
Several async methods in WeaviateStorage call the synchronous Weaviate client directly (connect(), collections.create, queries, inserts, deletes). Wrap those in asyncio.to_thread(...) (or equivalent) to avoid freezing the loop when these calls take time.
### Validated Implementation Analysis
Embedding at scale: use batch vectorization:
store_batch vectorizes each document one by one; you already have Vectorizer.vectorize_batch. Switching to batch reduces HTTP round trips and improves throughput under backpressure.
### 1. Bug Fixes & Potential Issues
Connection lifecycle: close the vectorizer client:
WeaviateStorage.close() closes the Weaviate client but not the httpx.AsyncClient inside the Vectorizer. Add an await to close it to prevent leaked connections/sockets under heavy usage.
These are areas where the code may not function as intended or could lead to errors.
Broadened exception handling in UI utilities:
Multiple places catch Exception broadly, making failures ambiguous and harder to surface to users (e.g., storage manager and list builders). Narrow these to expected exceptions and fail fast with user-friendly messages where appropriate.
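A minimal sketch of the narrowing pattern, assuming StorageError lives in ingest_pipeline.core.exceptions (as imported elsewhere in this diff) and that the app exposes the safe_notify helper discussed below; the function name and signature are illustrative:
import asyncio

from ingest_pipeline.core.exceptions import StorageError


async def load_collections(storage_manager, app) -> list:
    """Narrow the catch: expected failures become actionable toasts, real bugs still surface."""
    try:
        return await storage_manager.get_all_collections()
    except asyncio.CancelledError:
        raise  # never swallow cancellation inside worker tasks
    except StorageError as exc:
        app.safe_notify(f"Could not load collections: {exc}", severity="error")
        return []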
* <details>
<summary>
<b>HIGH PRIORITY: `R2RStorage.store_batch` inefficient looping (Lines 161-179)</b>
</summary>
* **File:** `ingest_pipeline/storage/r2r/storage.py:161-179`
* **Issue:** CONFIRMED - Method loops through documents calling `_store_single_document` individually
* **Impact:** ~5-10x performance degradation for batch operations
* **Implementation:** Check R2R v3 API for bulk endpoints; current implementation uses `/v3/documents` per document
* **Effort:** Medium (API research + refactor)
* **Priority:** High - affects all R2R ingestion workflows
</details>
Correctness / Bugs
* <details>
<summary>
<b>MEDIUM PRIORITY: Mixed HTTP client usage in `R2RStorage` (Lines 80, 99, 258)</b>
</summary>
Blocking sleep in TUI:
time.sleep(2) after posting notifications and before app.pop_screen() blocks the worker thread; use Textual timers instead.
* **File:** `ingest_pipeline/storage/r2r/storage.py:80,99,258`
* **Issue:** VALIDATED - Mixes `R2RAsyncClient` (line 80) with direct `httpx.AsyncClient` (lines 99, 258)
* **Specific Methods:** `initialize()`, `_ensure_collection()`, `_attempt_document_creation()`
* **Impact:** Inconsistent auth/header handling, connection pooling inefficiency
* **Implementation:** Extend `R2RAsyncClient` or create adapter pattern for missing endpoints
* **Test Coverage:** Check if affected methods have unit tests before refactoring
* **Effort:** Medium (requires SDK analysis)
</details>
* <details>
<summary>
<b>MEDIUM PRIORITY: TUI blocking during storage init (Line 91)</b>
</summary>
Synchronous SDK in async contexts (Weaviate):
* **File:** `ingest_pipeline/cli/tui/utils/runners.py:91`
* **Issue:** CONFIRMED - `await storage_manager.initialize_all_backends()` blocks TUI startup
* **Current Implementation:** 30s timeout per backend in `StorageManager.initialize_all_backends()`
* **User Impact:** Frozen terminal for up to 90s if all backends timeout
* **Solution:** Move to `CollectionOverviewScreen.on_mount()` as `@work` task
* **Dependencies:** `dashboard.py:304` already has worker pattern for `refresh_collections`
* **Implementation:** Use existing loading indicators and status updates (lines 308-312)
* **Effort:** Low (pattern exists, needs relocation)
</details>
initialize() calls self.client.connect() (sync). Wrap with asyncio.to_thread(self.client.connect).
* <details>
<summary>
<b>LOW PRIORITY: Weak URL validation in `IngestionScreen` (Lines 240-260)</b>
</summary>
* **File:** `ingest_pipeline/cli/tui/screens/ingestion.py:240-260`
* **Issue:** CONFIRMED - Method accepts `foo/bar` as valid (line 258)
* **Security Risk:** Medium - malicious URLs could be passed to ingestors
* **Current Logic:** Basic prefix checks only (http/https/file://)
* **Enhancement:** Add `pathlib.Path.exists()` for file:// paths, `.git` directory check for repos
* **Dependencies:** Import `pathlib` and add proper regex validation
* **Alternative:** Use `validators` library (not currently imported)
* **Effort:** Low (validation logic only)
</details>
Object operations such as collection.data.insert(...), collection.query.fetch_objects(...), collection.data.delete_many(...), and client.collections.delete(...) are sync calls invoked from async methods. These can stall the event loop under latency; use asyncio.to_thread(...) (see snippet below).
### 2. Code Redundancy & Refactoring Opportunities
HTTP client lifecycle (Vectorizer):
Vectorizer owns an httpx.AsyncClient, but WeaviateStorage.close() doesn't close it; add a close call to avoid leaked connections/sockets.
These suggestions aim to make the code more concise, maintainable, and reusable (D.R.Y. - Don't Repeat Yourself).
Heuristic “word_count” for OpenWebUI file listing:
word_count is estimated via size/6. That's fine as a placeholder, but it can mislead paging and UI logic downstream; consider a sentinel or a clearer label (e.g., estimated_word_count).
* <details>
<summary>
<b>HIGH IMPACT: Redundant collection logic in dashboard (Lines 356-424)</b>
</summary>
* **File:** `ingest_pipeline/cli/tui/screens/dashboard.py:356-424`
* **Issue:** CONFIRMED - `list_weaviate_collections()` and `list_openwebui_collections()` duplicate `StorageManager.get_all_collections()`
* **Code Duplication:** ~70 lines of redundant collection listing logic
* **Architecture Violation:** UI layer coupled to specific storage implementations
* **Current Usage:** `refresh_collections()` calls `get_all_collections()` (line 327), making methods obsolete
* **Action:** DELETE methods `list_weaviate_collections` and `list_openwebui_collections`
* **Impact:** Code reduction ~70 lines, improved maintainability
* **Risk:** Low - methods appear unused in current flow
* **Effort:** Low (deletion only)
</details>
Wide except Exception:
Broad catches appear in places like the storage manager and TUI screen update closures; they hide actionable errors. Prefer catching StorageError, IngestionError, or specific SDK exceptions—and surface toasts with actionable details.
* <details>
<summary>
<b>MEDIUM IMPACT: Repetitive backend init pattern (Lines 255-291)</b>
</summary>
* **File:** `ingest_pipeline/cli/tui/utils/storage_manager.py:255-291`
* **Issue:** CONFIRMED - Pattern repeated 3x for each backend type
* **Code Structure:** Check settings → Create config → Add task (12 lines × 3 backends)
* **Current Backends:** Weaviate (258-267), OpenWebUI (270-279), R2R (282-291)
* **Refactor Pattern:** Create `BackendConfig` dataclass with `(backend_type, endpoint_setting, api_key_setting, storage_class)`
* **Implementation:** Loop over a config list, reducing ~36 lines to ~15 (see the sketch after this section)
* **Extensibility:** Adding new backend becomes one-line config addition
* **Testing:** Ensure `asyncio.gather()` behavior unchanged (line 296)
* **Effort:** Medium (requires dataclass design + testing)
</details>
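A minimal sketch of the config-driven loop referenced above; the spec field names and the R2R settings attribute names are assumptions, and init_backend is the existing per-backend coroutine:
from dataclasses import dataclass

from pydantic import SecretStr

from ingest_pipeline.core.models import StorageBackend, StorageConfig
from ingest_pipeline.storage.base import BaseStorage
from ingest_pipeline.storage.openwebui import OpenWebUIStorage
from ingest_pipeline.storage.r2r.storage import R2RStorage
from ingest_pipeline.storage.weaviate import WeaviateStorage


@dataclass(frozen=True)
class BackendSpec:
    backend: StorageBackend
    endpoint_attr: str
    api_key_attr: str
    storage_class: type[BaseStorage]


BACKEND_SPECS: tuple[BackendSpec, ...] = (
    BackendSpec(StorageBackend.WEAVIATE, "weaviate_endpoint", "weaviate_api_key", WeaviateStorage),
    BackendSpec(StorageBackend.OPEN_WEBUI, "openwebui_endpoint", "openwebui_api_key", OpenWebUIStorage),
    BackendSpec(StorageBackend.R2R, "r2r_endpoint", "r2r_api_key", R2RStorage),
)


def build_init_tasks(settings, init_backend):
    """One loop replaces the three copy-pasted check-config-schedule blocks."""
    tasks = []
    for spec in BACKEND_SPECS:
        endpoint = getattr(settings, spec.endpoint_attr, None)
        if not endpoint:
            continue  # backend not configured; caller marks it unavailable
        api_key = getattr(settings, spec.api_key_attr, None)
        config = StorageConfig(
            backend=spec.backend,
            endpoint=endpoint,
            api_key=SecretStr(api_key) if api_key else None,
            collection_name="default",
        )
        tasks.append((spec.backend, init_backend(spec.backend, config, spec.storage_class)))
    return tasks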
Performance & Scalability
* <details>
<summary>
<b>MEDIUM IMPACT: Repeated Prefect block loading pattern (Lines 266-311)</b>
</summary>
Batch embeddings:
In WeaviateStorage.store_batch you conditionally vectorize each doc inline. Use your existing vectorize_batch and map the results back to documents. This cuts request overhead and enables controlled concurrency (you already have AsyncTaskManager to help).
* **File:** `ingest_pipeline/flows/ingestion.py:266-311`
* **Issue:** CONFIRMED - Pattern in `_create_ingestor()` and `_create_storage()` methods
* **Duplication:** `Block.aload()` + fallback logic repeated 4x across both methods
* **Variable Resolution:** Batch size logic (lines 244-255) also needs abstraction
* **Helper Functions Needed** (sketched after this section):
- `load_block_with_fallback(block_slug: str, default_config: T) -> T`
- `resolve_prefect_variable(var_name: str, default: T, type_cast: Type[T]) -> T`
* **Impact:** Cleaner flow logic, better error handling, type safety
* **Lines Reduced:** ~20 lines of repetitive code
* **Effort:** Medium (requires generic typing)
</details>
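A minimal sketch of the helpers named above. Block.aload is already used by the flow; the variable fetch itself is left out because its sync/async shape varies across Prefect versions, so only the typed-fallback coercion is shown:
from typing import TypeVar

from prefect.blocks.core import Block

T = TypeVar("T")
B = TypeVar("B", bound=Block)


async def load_block_with_fallback(block_slug: str, default_config: B) -> B:
    """Load a configured Prefect Block, falling back to in-code defaults."""
    try:
        block = await Block.aload(block_slug)
    except ValueError:
        # Raised when the block document is missing (exact exception is version-dependent).
        return default_config
    return block if isinstance(block, type(default_config)) else default_config


def coerce_variable(raw: object, default: T, type_cast: type[T]) -> T:
    """Coerce a raw Prefect Variable value with a typed fallback."""
    if raw is None:
        return default
    try:
        return type_cast(raw)  # type: ignore[call-arg]
    except (TypeError, ValueError):
        return default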
Async-friendly Weaviate calls:
Offload sync SDK operations to a thread so your Textual UI remains responsive while bulk inserting/querying/deleting. (See patch below.)
### 3. User Experience (UX) Enhancements
Retry & backoff are solid in your HTTP base client:
Your TypedHttpClient.request adds exponential backoff + jitter; this is great and worth keeping consistent across adapters.
These are suggestions to make your TUI more powerful, intuitive, and enjoyable for the user.
* <details>
<summary>
<b>HIGH IMPACT: Document content viewer modal (Add to documents.py)</b>
</summary>
UX notes (Textual TUI)
* **Target File:** `ingest_pipeline/cli/tui/screens/documents.py`
* **Current State:** READY - `DocumentManagementScreen` has table selection (line 212)
* **Implementation:**
- Add `Binding("v", "view_document", "View")` to BINDINGS (line 27)
- Create `DocumentContentModal(ModalScreen)` with `ScrollableContainer` + `Markdown`
- Use existing `get_current_document()` method (line 212)
- Fetch full content via `storage.retrieve(document_id)`
* **Dependencies:** Import `ModalScreen`, `ScrollableContainer`, `Markdown` from textual
* **User Value:** HIGH - essential for content inspection workflow
* **Effort:** Low-Medium (~50 lines of modal code)
* **Pattern:** Follow existing modal patterns in codebase
</details>
Notifications are good—make them consistent:
app.safe_notify(...) exists—use that everywhere instead of self.notify(...) to normalize markup handling & safety.
* <details>
<summary>
<b>HIGH IMPACT: Analytics tab visualization (Lines 164-189)</b>
</summary>
* **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py:164-189`
* **Current State:** PLACEHOLDER - Static widgets with dummy content
* **Data Source:** Use existing `self.collections` (line 65) populated by `refresh_collections()`
* **Implementation Options:**
   1. **Simple Text Chart:** ASCII bar chart using existing collections data (see the sketch after this section)
2. **textual-plotext:** Add dependency + bar chart widget
3. **Custom Widget:** Simple bar visualization with Static widgets
* **Metrics to Show:**
- Documents per collection (data available)
- Storage usage per backend (calculated in `_calculate_metrics()`)
- Ingestion timeline (requires timestamp tracking)
* **Effort:** Low-Medium (depends on visualization complexity)
* **Dependencies:** Consider `textual-plotext` or pure ASCII approach
</details>
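For the pure-ASCII option, a minimal sketch that renders "documents per collection" from the already-loaded collections list; the name/count keys are assumptions and should be mapped to the real CollectionInfo fields:
def ascii_bar_chart(collections: list[dict[str, object]], width: int = 40) -> str:
    """Build a text bar chart of documents per collection for a Static widget."""
    counts: dict[str, int] = {}
    for item in collections:
        raw = item.get("count", 0)
        counts[str(item.get("name", "?"))] = raw if isinstance(raw, int) else 0
    if not counts:
        return "No collections loaded yet."
    peak = max(counts.values()) or 1
    label_width = max(len(name) for name in counts)
    lines = []
    for name, count in sorted(counts.items(), key=lambda kv: kv[1], reverse=True):
        bar = "█" * max(1, round(width * count / peak)) if count else ""
        lines.append(f"{name:<{label_width}} | {bar} {count:,}")
    return "\n".join(lines)

# Usage inside the Analytics tab (widget id is hypothetical):
# self.query_one("#analytics_chart", Static).update(ascii_bar_chart(self.collections))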
Avoid visible “freeze” scenarios:
Replace the final time.sleep(2) with a timer and transition immediately; users see their actions complete without lag. (Patch below.)
* <details>
<summary>
<b>MEDIUM IMPACT: Global search implementation (Button exists, needs screen)</b>
</summary>
* **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py`
* **Current State:** READY - "Search All" button exists (line 122), handler stubbed
* **Backend Support:** `StorageManager.search_across_backends()` method exists (line 413-441)
* **Implementation:**
- Create `GlobalSearchScreen(ModalScreen)` with search input + results table
- Use existing `search_across_backends()` method for data
- Add "Backend" column to results table showing data source
- Handle async search with loading indicators
* **Current Limitation:** Search only works for Weaviate (line 563), need to extend
* **Data Flow:** Input → `storage_manager.search_across_backends()` → Results display
* **Effort:** Medium (~100 lines for new screen + search logic)
</details>
Search table quick find:
You have EnhancedDataTable with quick-search messages; wire a shortcut (e.g., /) to focus a search input and filter rows live.
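A minimal sketch of the binding, assuming an Input with id search_input sits next to the table; the real screen would use EnhancedDataTable, swapped here for the stock DataTable to keep the example self-contained:
from textual.app import ComposeResult
from textual.binding import Binding
from textual.screen import Screen
from textual.widgets import DataTable, Input


class QuickSearchScreen(Screen[None]):
    """'/' or Ctrl+F jumps straight to the filter box above the results table."""

    BINDINGS = [
        Binding("/", "focus_search", "Search", show=False),
        Binding("ctrl+f", "focus_search", "Search", show=False),
    ]

    def compose(self) -> ComposeResult:
        yield Input(placeholder="Type to filter rows...", id="search_input")
        yield DataTable(id="results_table")

    def action_focus_search(self) -> None:
        self.query_one("#search_input", Input).focus()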
* <details>
<summary>
<b>MEDIUM IMPACT: R2R advanced features integration (Widgets ready)</b>
</summary>
* **Target File:** `ingest_pipeline/cli/tui/screens/documents.py`
* **Available Widgets:** CONFIRMED - `ChunkViewer`, `EntityGraph`, `CollectionStats`, `DocumentOverview` in `r2r_widgets.py`
* **Current Implementation:** Basic document table only, R2R-specific features unused
* **Integration Points:**
- Add "R2R Details" button when `collection["type"] == "r2r"` (conditional UI)
- Create `R2RDocumentDetailsScreen` using existing widgets
- Use `StorageManager.get_r2r_storage()` method (exists at line 442)
* **R2R Methods Available:**
- `get_document_chunks()`, `extract_entities()`, `get_document_overview()`
* **User Value:** Medium-High for R2R users, showcases advanced features
* **Effort:** Low-Medium (widgets exist, need screen integration)
</details>
Theme file size & maintainability:
The theming system is thorough; consider splitting styles.py into smaller modules or generating CSS at build time to keep the Python file lean. (The responsive CSS generators are consolidated here.)
* <details>
<summary>
<b>LOW IMPACT: Create collection dialog (Backend methods exist)</b>
</summary>
* **Target File:** `ingest_pipeline/cli/tui/screens/dashboard.py`
* **Backend Support:** CONFIRMED - `create_collection()` method exists for R2R storage (line 690)
* **Current State:** No "Create Collection" button in existing UI
* **Implementation:**
- Add "New Collection" button to dashboard action buttons
- Create `CreateCollectionModal` with name input + backend checkboxes
- Iterate over `storage_manager.get_available_backends()` for backend selection
- Call `storage.create_collection()` on selected backends
* **Backend Compatibility:** Check which storage backends support collection creation
* **User Value:** Low-Medium (manual workflow, not critical)
* **Effort:** Low-Medium (~75 lines for modal + integration)
</details>
Modularity / Redundancy
## Implementation Priority Matrix
Converge repeated property mapping:
WeaviateStorage.store and store_batch build the same properties dict; factor this into a small helper to keep schemas in one place (less drift, easier to extend).
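A minimal sketch of the shape (add to WeaviateStorage); the exact property keys should be copied verbatim from the current store() implementation rather than from this illustration:
def _build_properties(self, doc: Document) -> dict[str, object]:
    """Single place that maps a Document onto the Weaviate properties schema."""
    return {
        "content": doc.content,
        # ...remaining keys exactly as store() builds them today...
    }

# store() then becomes:
#     properties = self._build_properties(doc)
#     await asyncio.to_thread(collection.data.insert, properties=properties, vector=vector)
# and store_batch() builds each row through the same helper.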
### Quick Wins (High Impact, Low Effort)
1. **Delete redundant collection methods** (dashboard.py:356-424) - 5 min
2. **Fix TUI startup blocking** (runners.py:91) - 15 min
3. **Document content viewer modal** (documents.py) - 30 min
### High Impact Fixes (Medium Effort)
1. **R2R batch operation optimization** (storage.py:161-179) - Research R2R v3 API + implementation
2. **Analytics tab visualization** (dashboard.py:164-189) - Choose visualization approach + implement
3. **Backend initialization refactoring** (storage_manager.py:255-291) - Dataclass design + testing
Common “describe/list/count” patterns across storages:
R2RStorage, OpenWebUIStorage, and WeaviateStorage present similar collection/document listing and counting methods. Consider a small “collection view” mixin with shared helpers; each backend implements only the API-specific steps.
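A minimal sketch of the mixin; the method names are assumptions based on the describe/list/count description, and each adapter overrides only the two primitives:
class CollectionIntrospectionMixin:
    """Shared collection-introspection surface for the storage adapters."""

    async def list_collections(self) -> list[str]:
        """Backend-specific: return the collection names known to this backend."""
        raise NotImplementedError

    async def count_documents(self, collection_name: str) -> int:
        """Backend-specific: return the number of documents in a collection."""
        raise NotImplementedError

    async def describe_collections(self) -> dict[str, int]:
        """Shared default the TUI's StorageManager can consume directly."""
        names = await self.list_collections()
        return {name: await self.count_documents(name) for name in names}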
### Technical Debt (Long-term)
1. **R2R client consistency** (storage.py) - SDK analysis + refactoring
2. **Prefect block loading helpers** (ingestion.py:266-311) - Generic typing + testing
3. **URL validation enhancement** (ingestion.py:240-260) - Security + validation logic
Security & Reliability
### Feature Enhancements (User Value)
1. **Global search implementation** - Medium effort, requires search backend extension
2. **R2R advanced features integration** - Showcase existing widget capabilities
3. **Create collection dialog** - Nice-to-have administrative feature
Input sanitization for LLM metadata:
Your MetadataTagger sanitizes and bounds fields (e.g., max lengths, language whitelist). This is a strong pattern—keep it.
## Agent Execution Notes
**Context Efficiency Tips:**
- Focus on one priority tier at a time
- Read specific file ranges mentioned in line numbers
- Use existing patterns (worker decorators, modal screens, async methods)
- Test changes incrementally, especially async operations
- Verify import dependencies before implementation
Timeouts and typed HTTP clients:
You standardize HTTP clients, headers, and timeouts. Good foundation for consistent behavior & observability.
**Architecture Constraints:**
- Maintain async/await patterns throughout
- Follow Textual reactive widget patterns
- Preserve Prefect flow structure for orchestration
- Keep storage backend abstraction intact
The codebase demonstrates excellent architectural foundations - these enhancements build upon existing strengths rather than requiring structural changes.
Suggested patches (drop-in)
1) Don't block the UI when closing the ingestion screen
Current (blocking):
import time
time.sleep(2)
cast("CollectionManagementApp", self.app).pop_screen()
Safer (schedule via the app's timer):
def _pop() -> None:
    try:
        self.app.pop_screen()
    except Exception:
        pass

# schedule from the worker thread
cast("CollectionManagementApp", self.app).call_from_thread(
    lambda: self.app.set_timer(2.0, _pop)
)
2) Offload Weaviate sync calls from async methods
Example insert (from WeaviateStorage.store)
# before (sync in async method)
collection.data.insert(properties=properties, vector=vector)
# after
await asyncio.to_thread(collection.data.insert, properties=properties, vector=vector)
Example fetch/delete by filter (from delete_by_filter)
response = await asyncio.to_thread(
    collection.query.fetch_objects, filters=where_filter, limit=1000
)
for obj in response.objects:
    await asyncio.to_thread(collection.data.delete_by_id, obj.uuid)
Example connect during initialize
await asyncio.to_thread(self.client.connect)
3) Batch embeddings in store_batch
Current (per-doc):
for doc in documents:
    if doc.vector is None:
        doc.vector = await self.vectorizer.vectorize(doc.content)
Proposed (batch):
# collect contents needing vectors
to_embed_idxs = [i for i, d in enumerate(documents) if d.vector is None]
if to_embed_idxs:
    contents = [documents[i].content for i in to_embed_idxs]
    vectors = await self.vectorizer.vectorize_batch(contents)
    for j, idx in enumerate(to_embed_idxs):
        documents[idx].vector = vectors[j]
4) Close the Vectorizer HTTP client when storage closes
Current: WeaviateStorage.close() only closes the Weaviate client.
Add:
async def close(self) -> None:
    if self.client:
        try:
            cast(weaviate.WeaviateClient, self.client).close()
        except Exception as e:
            import logging
            logging.warning("Error closing Weaviate client: %s", e)
    # NEW: close vectorizer HTTP client too
    try:
        await self.vectorizer.client.aclose()
    except Exception:
        pass
(Your Vectorizer owns an httpx.AsyncClient with headers/timeouts set.)
5) Prefer safe_notify consistently
Replace direct self.notify(...) calls inside TUI screens with cast(AppType, self.app).safe_notify(...) (same severity, markup=False by default), centralizing styling/sanitization.
Smaller improvements (quality of life)
Quick search focus: Wire the EnhancedDataTable quick-search to a keyboard binding (e.g., /) so users can filter rows without reaching for the mouse.
Refactor repeated properties dict: Extract the property construction in WeaviateStorage.store/store_batch into a helper to avoid drift and reduce line count.
Styles organization: styles.py is hefty by design. Consider splitting “theme palette”, “components”, and “responsive” generators into separate modules to keep diffs small and reviews easier.
Architecture & Modularity
Storage adapters: The three adapters share "describe/list/count" concepts. Introduce a tiny shared "CollectionIntrospection" mixin (interfaces + default fallbacks), and keep only API specifics in each adapter. This will simplify the TUI's StorageManager as well.
Ingestion flows: Good use of Prefect tasks/flows with retries, variables, and tagging. The Firecrawl→R2R specialization cleanly reuses common steps. The batch boundaries are clear and progress updates are threaded back to the UI.
DX / Testing
Unit tests: In this export I don't see tests. Add lightweight tests for:
Vector extraction/format parsing in Vectorizer (covers multiple providers).
Weaviate adapters: property building, name normalization, vector extraction.
MetadataTagger sanitization rules.
TUI: use Textual's pilot to test that notifications and timers trigger and that screens transition without blocking (verifies the sleep fix).
Static analysis: You already have excellent typing. Add ruff + mypy --strict and a pre-commit config to keep it consistent across contributors.
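A minimal pilot-style test for the non-blocking transition; a toy app stands in for CollectionManagementApp, and pytest with pytest-asyncio is assumed:
import pytest
from textual.app import App, ComposeResult
from textual.screen import Screen
from textual.widgets import Static


class _DoneScreen(Screen[None]):
    def compose(self) -> ComposeResult:
        yield Static("ingestion complete")

    def on_mount(self) -> None:
        # Mirror the fix: schedule the pop instead of sleeping in a worker.
        self.set_timer(0.1, self.app.pop_screen)


class _DemoApp(App[None]):
    def on_mount(self) -> None:
        self.push_screen(_DoneScreen())


@pytest.mark.asyncio
async def test_timer_pops_screen_without_blocking() -> None:
    app = _DemoApp()
    async with app.run_test() as pilot:
        await pilot.pause()  # let on_mount push the screen
        assert isinstance(app.screen, _DoneScreen)
        await pilot.pause(0.3)  # give the 0.1s timer time to fire
        assert not isinstance(app.screen, _DoneScreen)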

View File

@@ -19,7 +19,6 @@ actions:
source: inferred
enabled: true
""",
"retry_failed": """
name: Retry Failed Ingestion Flows
description: Retries failed ingestion flows with original parameters
@@ -39,7 +38,6 @@ actions:
validate_first: false
enabled: true
""",
"resource_monitoring": """
name: Manage Work Pool Based on Resources
description: Pauses work pool when system resources are constrained

View File

@@ -139,6 +139,7 @@ def ingest(
# If we're already in an event loop (e.g., in Jupyter), use nest_asyncio
try:
import nest_asyncio
nest_asyncio.apply()
result = asyncio.run(run_with_progress())
except ImportError:
@@ -449,7 +450,9 @@ def search(
def blocks_command() -> None:
"""🧩 List and manage Prefect Blocks."""
console.print("[bold cyan]📦 Prefect Blocks Management[/bold cyan]")
console.print("Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks")
console.print(
"Use 'prefect block register --module ingest_pipeline.core.models' to register custom blocks"
)
console.print("Use 'prefect block ls' to list available blocks")
@@ -507,7 +510,9 @@ async def run_list_collections() -> None:
weaviate_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=settings.weaviate_endpoint,
api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
api_key=SecretStr(settings.weaviate_api_key)
if settings.weaviate_api_key is not None
else None,
collection_name="default",
)
weaviate = WeaviateStorage(weaviate_config)
@@ -528,7 +533,9 @@ async def run_list_collections() -> None:
openwebui_config = StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint=settings.openwebui_endpoint,
api_key=SecretStr(settings.openwebui_api_key) if settings.openwebui_api_key is not None else None,
api_key=SecretStr(settings.openwebui_api_key)
if settings.openwebui_api_key is not None
else None,
collection_name="default",
)
openwebui = OpenWebUIStorage(openwebui_config)
@@ -593,7 +600,9 @@ async def run_search(query: str, collection: str | None, backend: str, limit: in
weaviate_config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=settings.weaviate_endpoint,
api_key=SecretStr(settings.weaviate_api_key) if settings.weaviate_api_key is not None else None,
api_key=SecretStr(settings.weaviate_api_key)
if settings.weaviate_api_key is not None
else None,
collection_name=collection or "default",
)
weaviate = WeaviateStorage(weaviate_config)

View File

@@ -31,7 +31,6 @@ else: # pragma: no cover - optional dependency fallback
R2RStorage = BaseStorage
class CollectionManagementApp(App[None]):
"""Enhanced modern Textual application with comprehensive keyboard navigation."""

View File

@@ -57,7 +57,9 @@ class ResponsiveGrid(Container):
markup: bool = True,
) -> None:
"""Initialize responsive grid."""
super().__init__(*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
super().__init__(
*children, name=name, id=id, classes=classes, disabled=disabled, markup=markup
)
self._columns: int = columns
self._auto_fit: bool = auto_fit
self._compact: bool = compact
@@ -327,7 +329,9 @@ class CardLayout(ResponsiveGrid):
) -> None:
"""Initialize card layout with default settings for cards."""
# Default to auto-fit cards with minimum width
super().__init__(auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup)
super().__init__(
auto_fit=True, name=name, id=id, classes=classes, disabled=disabled, markup=markup
)
class SplitPane(Container):
@@ -396,7 +400,9 @@ class SplitPane(Container):
if self._vertical:
cast(Widget, self).add_class("vertical")
pane_classes = ("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
pane_classes = (
("top-pane", "bottom-pane") if self._vertical else ("left-pane", "right-pane")
)
yield Container(self._left_content, classes=pane_classes[0])
yield Static("", classes="splitter")

View File

@@ -1,7 +1,7 @@
"""Data models and TypedDict definitions for the TUI."""
from enum import IntEnum
from typing import Any, TypedDict
from typing import TypedDict
class StorageCapabilities(IntEnum):
@@ -47,7 +47,7 @@ class ChunkInfo(TypedDict):
content: str
start_index: int
end_index: int
metadata: dict[str, Any]
metadata: dict[str, object]
class EntityInfo(TypedDict):
@@ -57,7 +57,7 @@ class EntityInfo(TypedDict):
name: str
type: str
confidence: float
metadata: dict[str, Any]
metadata: dict[str, object]
class FirecrawlOptions(TypedDict, total=False):
@@ -77,7 +77,7 @@ class FirecrawlOptions(TypedDict, total=False):
max_depth: int
# Extraction options
extract_schema: dict[str, Any] | None
extract_schema: dict[str, object] | None
extract_prompt: str | None

View File

@@ -29,7 +29,7 @@ class BaseScreen(Screen[object]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
**kwargs: object
**kwargs: object,
) -> None:
"""Initialize base screen."""
super().__init__(name=name, id=id, classes=classes)
@@ -55,7 +55,7 @@ class CRUDScreen(BaseScreen, Generic[T]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
**kwargs: object
**kwargs: object,
) -> None:
"""Initialize CRUD screen."""
super().__init__(storage_manager, name=name, id=id, classes=classes)
@@ -311,7 +311,7 @@ class FormScreen(ModalScreen[T], Generic[T]):
name: str | None = None,
id: str | None = None,
classes: str | None = None,
**kwargs: object
**kwargs: object,
) -> None:
"""Initialize form screen."""
super().__init__(name=name, id=id, classes=classes)

View File

@@ -1,7 +1,6 @@
"""Main dashboard screen with collections overview."""
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Final
from textual import work
@@ -357,7 +356,6 @@ class CollectionOverviewScreen(Screen[None]):
self.is_loading = False
loading_indicator.display = False
async def update_collections_table(self) -> None:
"""Update the collections table with enhanced formatting."""
table = self.query_one("#collections_table", EnhancedDataTable)

View File

@@ -82,7 +82,10 @@ class ConfirmDeleteScreen(Screen[None]):
try:
if self.collection["type"] == "weaviate" and self.parent_screen.weaviate:
# Delete Weaviate collection
if self.parent_screen.weaviate.client and self.parent_screen.weaviate.client.collections:
if (
self.parent_screen.weaviate.client
and self.parent_screen.weaviate.client.collections
):
self.parent_screen.weaviate.client.collections.delete(self.collection["name"])
self.notify(
f"Deleted Weaviate collection: {self.collection['name']}",
@@ -100,7 +103,7 @@ class ConfirmDeleteScreen(Screen[None]):
return
# Check if the storage backend supports collection deletion
if not hasattr(storage_backend, 'delete_collection'):
if not hasattr(storage_backend, "delete_collection"):
self.notify(
f"❌ Collection deletion not supported for {self.collection['type']} backend",
severity="error",
@@ -113,10 +116,13 @@ class ConfirmDeleteScreen(Screen[None]):
collection_name = str(self.collection["name"])
collection_type = str(self.collection["type"])
self.notify(f"Deleting {collection_type} collection: {collection_name}...", severity="information")
self.notify(
f"Deleting {collection_type} collection: {collection_name}...",
severity="information",
)
# Use the standard delete_collection method for all backends
if hasattr(storage_backend, 'delete_collection'):
if hasattr(storage_backend, "delete_collection"):
success = await storage_backend.delete_collection(collection_name)
else:
self.notify("❌ Backend does not support collection deletion", severity="error")
@@ -149,7 +155,6 @@ class ConfirmDeleteScreen(Screen[None]):
self.parent_screen.refresh_collections()
class ConfirmDocumentDeleteScreen(Screen[None]):
"""Screen for confirming document deletion."""
@@ -223,11 +228,11 @@ class ConfirmDocumentDeleteScreen(Screen[None]):
try:
results: dict[str, bool] = {}
if hasattr(self.parent_screen, 'storage') and self.parent_screen.storage:
if hasattr(self.parent_screen, "storage") and self.parent_screen.storage:
# Delete documents via storage
# The storage should have delete_documents method for weaviate
storage = self.parent_screen.storage
if hasattr(storage, 'delete_documents'):
if hasattr(storage, "delete_documents"):
results = await storage.delete_documents(
self.doc_ids,
collection_name=self.collection["name"],
@@ -280,7 +285,9 @@ class LogViewerScreen(ModalScreen[None]):
yield Header(show_clock=True)
yield Container(
Static("📜 Live Application Logs", classes="title"),
Static("Logs update in real time. Press S to reveal the log file path.", classes="subtitle"),
Static(
"Logs update in real time. Press S to reveal the log file path.", classes="subtitle"
),
RichLog(id="log_stream", classes="log-stream", wrap=True, highlight=False),
Static("", id="log_file_path", classes="subtitle"),
classes="main_container log-viewer-container",
@@ -291,13 +298,13 @@ class LogViewerScreen(ModalScreen[None]):
"""Attach this viewer to the parent application once mounted."""
self._log_widget = self.query_one(RichLog)
if hasattr(self.app, 'attach_log_viewer'):
if hasattr(self.app, "attach_log_viewer"):
self.app.attach_log_viewer(self) # type: ignore[arg-type]
def on_unmount(self) -> None:
"""Detach from the parent application when closed."""
if hasattr(self.app, 'detach_log_viewer'):
if hasattr(self.app, "detach_log_viewer"):
self.app.detach_log_viewer(self) # type: ignore[arg-type]
def _get_log_widget(self) -> RichLog:
@@ -340,4 +347,6 @@ class LogViewerScreen(ModalScreen[None]):
if self._log_file is None:
self.notify("File logging is disabled for this session.", severity="warning")
else:
self.notify(f"Log file available at: {self._log_file}", severity="information", markup=False)
self.notify(
f"Log file available at: {self._log_file}", severity="information", markup=False
)

View File

@@ -116,7 +116,9 @@ class DocumentManagementScreen(Screen[None]):
content_type=str(doc.get("content_type", "text/plain")),
content_preview=str(doc.get("content_preview", "")),
word_count=(
lambda wc_val: int(wc_val) if isinstance(wc_val, (int, str)) and str(wc_val).isdigit() else 0
lambda wc_val: int(wc_val)
if isinstance(wc_val, (int, str)) and str(wc_val).isdigit()
else 0
)(doc.get("word_count", 0)),
timestamp=str(doc.get("timestamp", "")),
)
@@ -126,7 +128,7 @@ class DocumentManagementScreen(Screen[None]):
# For storage backends that don't support document listing, show a message
self.notify(
f"Document listing not supported for {self.storage.__class__.__name__}",
severity="information"
severity="information",
)
self.documents = []
@@ -330,7 +332,9 @@ class DocumentManagementScreen(Screen[None]):
"""View the content of the currently selected document."""
if doc := self.get_current_document():
if self.storage:
self.app.push_screen(DocumentContentModal(doc, self.storage, self.collection["name"]))
self.app.push_screen(
DocumentContentModal(doc, self.storage, self.collection["name"])
)
else:
self.notify("No storage backend available", severity="error")
else:
@@ -381,13 +385,13 @@ class DocumentContentModal(ModalScreen[None]):
yield Container(
Static(
f"📄 Document: {self.document['title'][:60]}{'...' if len(self.document['title']) > 60 else ''}",
classes="modal-header"
classes="modal-header",
),
ScrollableContainer(
Markdown("Loading document content...", id="document_content"),
LoadingIndicator(id="content_loading"),
classes="modal-content"
)
classes="modal-content",
),
)
async def on_mount(self) -> None:
@@ -398,30 +402,29 @@ class DocumentContentModal(ModalScreen[None]):
try:
# Get full document content
doc_content = await self.storage.retrieve(
self.document["id"],
collection_name=self.collection_name
self.document["id"], collection_name=self.collection_name
)
# Format content for display
if isinstance(doc_content, str):
formatted_content = f"""# {self.document['title']}
formatted_content = f"""# {self.document["title"]}
**Source:** {self.document.get('source_url', 'N/A')}
**Type:** {self.document.get('content_type', 'text/plain')}
**Words:** {self.document.get('word_count', 0):,}
**Timestamp:** {self.document.get('timestamp', 'N/A')}
**Source:** {self.document.get("source_url", "N/A")}
**Type:** {self.document.get("content_type", "text/plain")}
**Words:** {self.document.get("word_count", 0):,}
**Timestamp:** {self.document.get("timestamp", "N/A")}
---
{doc_content}
"""
else:
formatted_content = f"""# {self.document['title']}
formatted_content = f"""# {self.document["title"]}
**Source:** {self.document.get('source_url', 'N/A')}
**Type:** {self.document.get('content_type', 'text/plain')}
**Words:** {self.document.get('word_count', 0):,}
**Timestamp:** {self.document.get('timestamp', 'N/A')}
**Source:** {self.document.get("source_url", "N/A")}
**Type:** {self.document.get("content_type", "text/plain")}
**Words:** {self.document.get("word_count", 0):,}
**Timestamp:** {self.document.get("timestamp", "N/A")}
---
@@ -431,6 +434,8 @@ class DocumentContentModal(ModalScreen[None]):
content_widget.update(formatted_content)
except Exception as e:
content_widget.update(f"# Error Loading Document\n\nFailed to load document content: {e}")
content_widget.update(
f"# Error Loading Document\n\nFailed to load document content: {e}"
)
finally:
loading.display = False

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from textual import work
from textual.app import ComposeResult
@@ -105,12 +105,23 @@ class IngestionScreen(ModalScreen[None]):
Label("📋 Source Type (Press 1/2/3):", classes="input-label"),
Horizontal(
Button("🌐 Web (1)", id="web_btn", variant="primary", classes="type-button"),
Button("📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"),
Button("📖 Documentation (3)", id="docs_btn", variant="default", classes="type-button"),
Button(
"📦 Repository (2)", id="repo_btn", variant="default", classes="type-button"
),
Button(
"📖 Documentation (3)",
id="docs_btn",
variant="default",
classes="type-button",
),
classes="type_buttons",
),
Rule(line_style="dashed"),
Label(f"🗄️ Target Storages ({len(self.available_backends)} available):", classes="input-label", id="backend_label"),
Label(
f"🗄️ Target Storages ({len(self.available_backends)} available):",
classes="input-label",
id="backend_label",
),
Container(
*self._create_backend_checkbox_widgets(),
classes="backend-selection",
@@ -139,7 +150,6 @@ class IngestionScreen(ModalScreen[None]):
yield LoadingIndicator(id="loading", classes="pulse")
def _create_backend_checkbox_widgets(self) -> list[Checkbox]:
"""Create checkbox widgets for each available backend."""
checkboxes: list[Checkbox] = [
@@ -219,19 +229,26 @@ class IngestionScreen(ModalScreen[None]):
collection_name = collection_input.value.strip()
if not source_url:
self.notify("🔍 Please enter a source URL", severity="error")
cast("CollectionManagementApp", self.app).safe_notify(
"🔍 Please enter a source URL", severity="error"
)
url_input.focus()
return
# Validate URL format
if not self._validate_url(source_url):
self.notify("❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path", severity="error")
cast("CollectionManagementApp", self.app).safe_notify(
"❌ Invalid URL format. Please enter a valid HTTP/HTTPS URL or file:// path",
severity="error",
)
url_input.focus()
return
resolved_backends = self._resolve_selected_backends()
if not resolved_backends:
self.notify("⚠️ Select at least one storage backend", severity="warning")
cast("CollectionManagementApp", self.app).safe_notify(
"⚠️ Select at least one storage backend", severity="warning"
)
return
self.selected_backends = resolved_backends
@@ -246,18 +263,16 @@ class IngestionScreen(ModalScreen[None]):
url_lower = url.lower()
# Allow HTTP/HTTPS URLs
if url_lower.startswith(('http://', 'https://')):
if url_lower.startswith(("http://", "https://")):
# Additional validation could be added here
return True
# Allow file:// URLs for repository paths
if url_lower.startswith('file://'):
if url_lower.startswith("file://"):
return True
# Allow local file paths that look like repositories
return '/' in url and not url_lower.startswith(
('javascript:', 'data:', 'vbscript:')
)
return "/" in url and not url_lower.startswith(("javascript:", "data:", "vbscript:"))
def _resolve_selected_backends(self) -> list[StorageBackend]:
selected: list[StorageBackend] = []
@@ -288,13 +303,14 @@ class IngestionScreen(ModalScreen[None]):
if not self.selected_backends:
status_widget.update("📋 Selected: None")
elif len(self.selected_backends) == 1:
backend_name = BACKEND_LABELS.get(self.selected_backends[0], self.selected_backends[0].value)
backend_name = BACKEND_LABELS.get(
self.selected_backends[0], self.selected_backends[0].value
)
status_widget.update(f"📋 Selected: {backend_name}")
else:
# Multiple backends selected
backend_names = [
BACKEND_LABELS.get(backend, backend.value)
for backend in self.selected_backends
BACKEND_LABELS.get(backend, backend.value) for backend in self.selected_backends
]
if len(backend_names) <= 3:
# Show all names if 3 or fewer
@@ -319,7 +335,10 @@ class IngestionScreen(ModalScreen[None]):
for backend in BACKEND_ORDER:
if backend not in self.available_backends:
continue
if backend.value.lower() == backend_name_lower or backend.name.lower() == backend_name_lower:
if (
backend.value.lower() == backend_name_lower
or backend.name.lower() == backend_name_lower
):
matched_backends.append(backend)
break
return matched_backends or [self.available_backends[0]]
@@ -351,6 +370,7 @@ class IngestionScreen(ModalScreen[None]):
loading.display = False
except Exception:
pass
cast("CollectionManagementApp", self.app).call_from_thread(_update)
def progress_reporter(percent: int, message: str) -> None:
@@ -362,6 +382,7 @@ class IngestionScreen(ModalScreen[None]):
progress_text.update(message)
except Exception:
pass
cast("CollectionManagementApp", self.app).call_from_thread(_update_progress)
try:
@@ -377,13 +398,18 @@ class IngestionScreen(ModalScreen[None]):
for i, backend in enumerate(backends):
progress_percent = 20 + (60 * i) // len(backends)
progress_reporter(progress_percent, f"🔗 Processing {backend.value} backend ({i+1}/{len(backends)})...")
progress_reporter(
progress_percent,
f"🔗 Processing {backend.value} backend ({i + 1}/{len(backends)})...",
)
try:
# Run the Prefect flow for this backend using asyncio.run with timeout
import asyncio
async def run_flow_with_timeout(current_backend: StorageBackend = backend) -> IngestionResult:
async def run_flow_with_timeout(
current_backend: StorageBackend = backend,
) -> IngestionResult:
return await asyncio.wait_for(
create_ingestion_flow(
source_url=source_url,
@@ -392,7 +418,7 @@ class IngestionScreen(ModalScreen[None]):
collection_name=final_collection_name,
progress_callback=progress_reporter,
),
timeout=600.0 # 10 minute timeout
timeout=600.0, # 10 minute timeout
)
result = asyncio.run(run_flow_with_timeout())
@@ -401,25 +427,33 @@ class IngestionScreen(ModalScreen[None]):
total_failed += result.documents_failed
if result.error_messages:
flow_errors.extend([f"{backend.value}: {err}" for err in result.error_messages])
flow_errors.extend(
[f"{backend.value}: {err}" for err in result.error_messages]
)
except TimeoutError:
error_msg = f"{backend.value}: Timeout after 10 minutes"
flow_errors.append(error_msg)
progress_reporter(0, f"{backend.value} timed out")
def notify_timeout(msg: str = f"{backend.value} flow timed out after 10 minutes") -> None:
def notify_timeout(
msg: str = f"{backend.value} flow timed out after 10 minutes",
) -> None:
try:
self.notify(msg, severity="error", markup=False)
except Exception:
pass
cast("CollectionManagementApp", self.app).call_from_thread(notify_timeout)
except Exception as exc:
flow_errors.append(f"{backend.value}: {exc}")
def notify_error(msg: str = f"{backend.value} flow failed: {exc}") -> None:
try:
self.notify(msg, severity="error", markup=False)
except Exception:
pass
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
successful = total_successful
@@ -445,20 +479,37 @@ class IngestionScreen(ModalScreen[None]):
cast("CollectionManagementApp", self.app).call_from_thread(notify_results)
import time
time.sleep(2)
cast("CollectionManagementApp", self.app).pop_screen()
def _pop() -> None:
try:
self.app.pop_screen()
except Exception:
pass
# Schedule screen pop via timer instead of blocking
cast("CollectionManagementApp", self.app).call_from_thread(
lambda: self.app.set_timer(2.0, _pop)
)
except Exception as exc: # pragma: no cover - defensive
progress_reporter(0, f"❌ Prefect flows error: {exc}")
def notify_error(msg: str = f"❌ Prefect flows failed: {exc}") -> None:
try:
self.notify(msg, severity="error")
except Exception:
pass
cast("CollectionManagementApp", self.app).call_from_thread(notify_error)
import time
time.sleep(2)
def _pop_on_error() -> None:
try:
self.app.pop_screen()
except Exception:
pass
# Schedule screen pop via timer for error case too
cast("CollectionManagementApp", self.app).call_from_thread(
lambda: self.app.set_timer(2.0, _pop_on_error)
)
finally:
update_ui("hide_loading")

View File

@@ -43,12 +43,24 @@ class SearchScreen(Screen[None]):
@override
def compose(self) -> ComposeResult:
yield Header()
# Check if search is supported for this backend
backends = self.collection["backend"]
if isinstance(backends, str):
backends = [backends]
search_supported = "weaviate" in backends
search_indicator = "✅ Search supported" if search_supported else "❌ Search not supported"
yield Container(
Static(
f"🔍 Search in: {self.collection['name']} ({self.collection['backend']})",
f"🔍 Search in: {self.collection['name']} ({', '.join(backends)}) - {search_indicator}",
classes="title",
),
Static("Press / or Ctrl+F to focus search, Enter to search", classes="subtitle"),
Static(
"Press / or Ctrl+F to focus search, Enter to search"
if search_supported
else "Search functionality not available for this backend",
classes="subtitle",
),
Input(placeholder="Enter search query... (press Enter to search)", id="search_input"),
Button("🔍 Search", id="search_btn", variant="primary"),
Button("🗑️ Clear Results", id="clear_btn", variant="default"),
@@ -127,7 +139,9 @@ class SearchScreen(Screen[None]):
finally:
loading.display = False
def _setup_search_ui(self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str) -> None:
def _setup_search_ui(
self, loading: LoadingIndicator, table: EnhancedDataTable, status: Static, query: str
) -> None:
"""Setup the search UI elements."""
loading.display = True
status.update(f"🔍 Searching for '{query}'...")
@@ -139,10 +153,18 @@ class SearchScreen(Screen[None]):
if self.collection["type"] == "weaviate" and self.weaviate:
return await self.search_weaviate(query)
elif self.collection["type"] == "openwebui" and self.openwebui:
return await self.search_openwebui(query)
# OpenWebUI search is not yet implemented
self.notify("Search not supported for OpenWebUI collections", severity="warning")
return []
elif self.collection["type"] == "r2r":
# R2R search would go here when implemented
self.notify("Search not supported for R2R collections", severity="warning")
return []
return []
def _populate_results_table(self, table: EnhancedDataTable, results: list[dict[str, str | float]]) -> None:
def _populate_results_table(
self, table: EnhancedDataTable, results: list[dict[str, str | float]]
) -> None:
"""Populate the results table with search results."""
for result in results:
row_data = self._format_result_row(result)
@@ -193,7 +215,11 @@ class SearchScreen(Screen[None]):
return str(score)
def _update_search_status(
self, status: Static, query: str, results: list[dict[str, str | float]], table: EnhancedDataTable
self,
status: Static,
query: str,
results: list[dict[str, str | float]],
table: EnhancedDataTable,
) -> None:
"""Update search status and notifications based on results."""
if not results:

View File

@@ -13,6 +13,8 @@ TextualApp = App[object]
class AppProtocol(Protocol):
"""Protocol for apps that support CSS and refresh."""
CSS: str
def refresh(self) -> None:
"""Refresh the app."""
...
@@ -1122,16 +1124,16 @@ def get_css_for_theme(theme_type: ThemeType) -> str:
def apply_theme_to_app(app: TextualApp | AppProtocol, theme_type: ThemeType) -> None:
"""Apply a theme to a Textual app instance."""
try:
css = set_theme(theme_type)
# Set CSS using the standard Textual approach
if hasattr(app, "CSS") or isinstance(app, App):
setattr(app, "CSS", css)
# Refresh the app to apply new CSS
if hasattr(app, "refresh"):
app.refresh()
# Note: CSS class variable cannot be changed at runtime
# This function would need to be called during app initialization
# or implement a different approach for dynamic theming
_ = set_theme(theme_type) # Keep for future implementation
if hasattr(app, "refresh"):
app.refresh()
except Exception as e:
# Graceful fallback - log but don't crash the UI
import logging
logging.debug(f"Failed to apply theme to app: {e}")
@@ -1185,11 +1187,11 @@ class ThemeSwitcher:
# Responsive breakpoints for dynamic layout adaptation
RESPONSIVE_BREAKPOINTS = {
"xs": 40, # Extra small terminals
"sm": 60, # Small terminals
"md": 100, # Medium terminals
"lg": 140, # Large terminals
"xl": 180, # Extra large terminals
"xs": 40, # Extra small terminals
"sm": 60, # Small terminals
"md": 100, # Medium terminals
"lg": 140, # Large terminals
"xl": 180, # Extra large terminals
}

View File

@@ -10,8 +10,9 @@ from pathlib import Path
from queue import Queue
from typing import NamedTuple
import platformdirs
from ....config import configure_prefect, get_settings
from ....core.models import StorageBackend
from .storage_manager import StorageManager
@@ -53,21 +54,36 @@ def _configure_tui_logging(*, log_level: str) -> _TuiLoggingContext:
log_file: Path | None = None
try:
# Try current directory first for development
log_dir = Path.cwd() / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / "tui.log"
file_handler = RotatingFileHandler(
log_file,
maxBytes=2_000_000,
backupCount=5,
encoding="utf-8",
)
file_handler.setLevel(resolved_level)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
except OSError as exc: # pragma: no cover - filesystem specific
fallback = logging.getLogger(__name__)
fallback.warning("Failed to configure file logging for TUI: %s", exc)
except OSError:
# Fall back to user log directory
try:
log_dir = Path(platformdirs.user_log_dir("ingest-pipeline", "ingest-pipeline"))
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / "tui.log"
except OSError as exc:
fallback = logging.getLogger(__name__)
fallback.warning("Failed to create log directory, file logging disabled: %s", exc)
log_file = None
if log_file:
try:
file_handler = RotatingFileHandler(
log_file,
maxBytes=2_000_000,
backupCount=5,
encoding="utf-8",
)
file_handler.setLevel(resolved_level)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
except OSError as exc:
fallback = logging.getLogger(__name__)
fallback.warning("Failed to configure file logging for TUI: %s", exc)
log_file = None
_logging_context = _TuiLoggingContext(log_queue, formatter, log_file)
return _logging_context
@@ -93,6 +109,7 @@ async def run_textual_tui() -> None:
# Import here to avoid circular import
from ..app import CollectionManagementApp
app = CollectionManagementApp(
storage_manager,
None, # weaviate - will be available after initialization

View File

@@ -1,6 +1,5 @@
"""Storage management utilities for TUI applications."""
from __future__ import annotations
import asyncio
@@ -11,12 +10,11 @@ from pydantic import SecretStr
from ....core.exceptions import StorageError
from ....core.models import Document, StorageBackend, StorageConfig
from ..models import CollectionInfo, StorageCapabilities
from ....storage.base import BaseStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.weaviate import WeaviateStorage
from ..models import CollectionInfo, StorageCapabilities
if TYPE_CHECKING:
from ....config.settings import Settings
@@ -39,7 +37,6 @@ class StorageBackendProtocol(Protocol):
async def close(self) -> None: ...
class MultiStorageAdapter(BaseStorage):
"""Mirror writes to multiple storage backends."""
@@ -70,7 +67,10 @@ class MultiStorageAdapter(BaseStorage):
# Replicate to secondary backends concurrently
if len(self._storages) > 1:
async def replicate_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
async def replicate_to_backend(
storage: BaseStorage,
) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.store(document, collection_name=collection_name)
return storage, True, None
@@ -106,11 +106,16 @@ class MultiStorageAdapter(BaseStorage):
self, documents: list[Document], *, collection_name: str | None = None
) -> list[str]:
# Store in primary backend first
primary_ids: list[str] = await self._primary.store_batch(documents, collection_name=collection_name)
primary_ids: list[str] = await self._primary.store_batch(
documents, collection_name=collection_name
)
# Replicate to secondary backends concurrently
if len(self._storages) > 1:
async def replicate_batch_to_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
async def replicate_batch_to_backend(
storage: BaseStorage,
) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.store_batch(documents, collection_name=collection_name)
return storage, True, None
@@ -135,7 +140,9 @@ class MultiStorageAdapter(BaseStorage):
if failures:
backends = ", ".join(failures)
primary_error = errors[0] if errors else Exception("Unknown batch replication error")
primary_error = (
errors[0] if errors else Exception("Unknown batch replication error")
)
raise StorageError(
f"Batch stored in primary backend but replication failed for: {backends}"
) from primary_error
@@ -144,11 +151,16 @@ class MultiStorageAdapter(BaseStorage):
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
# Delete from primary backend first
primary_deleted: bool = await self._primary.delete(document_id, collection_name=collection_name)
primary_deleted: bool = await self._primary.delete(
document_id, collection_name=collection_name
)
# Delete from secondary backends concurrently
if len(self._storages) > 1:
async def delete_from_backend(storage: BaseStorage) -> tuple[BaseStorage, bool, Exception | None]:
async def delete_from_backend(
storage: BaseStorage,
) -> tuple[BaseStorage, bool, Exception | None]:
try:
await storage.delete(document_id, collection_name=collection_name)
return storage, True, None
@@ -222,7 +234,6 @@ class MultiStorageAdapter(BaseStorage):
return class_name
class StorageManager:
"""Centralized manager for all storage backend operations."""
@@ -237,7 +248,9 @@ class StorageManager:
"""Initialize all available storage backends with timeout protection."""
results: dict[StorageBackend, bool] = {}
async def init_backend(backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]) -> bool:
async def init_backend(
backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]
) -> bool:
"""Initialize a single backend with timeout."""
try:
storage = storage_class(config)
@@ -261,10 +274,17 @@ class StorageManager:
config = StorageConfig(
backend=StorageBackend.WEAVIATE,
endpoint=self.settings.weaviate_endpoint,
api_key=SecretStr(self.settings.weaviate_api_key) if self.settings.weaviate_api_key else None,
api_key=SecretStr(self.settings.weaviate_api_key)
if self.settings.weaviate_api_key
else None,
collection_name="default",
)
tasks.append((StorageBackend.WEAVIATE, init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage)))
tasks.append(
(
StorageBackend.WEAVIATE,
init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage),
)
)
else:
results[StorageBackend.WEAVIATE] = False
@@ -273,10 +293,17 @@ class StorageManager:
config = StorageConfig(
backend=StorageBackend.OPEN_WEBUI,
endpoint=self.settings.openwebui_endpoint,
api_key=SecretStr(self.settings.openwebui_api_key) if self.settings.openwebui_api_key else None,
api_key=SecretStr(self.settings.openwebui_api_key)
if self.settings.openwebui_api_key
else None,
collection_name="default",
)
tasks.append((StorageBackend.OPEN_WEBUI, init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage)))
tasks.append(
(
StorageBackend.OPEN_WEBUI,
init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage),
)
)
else:
results[StorageBackend.OPEN_WEBUI] = False
@@ -295,7 +322,9 @@ class StorageManager:
# Execute initialization tasks concurrently
if tasks:
backend_types, task_coroutines = zip(*tasks, strict=False)
task_results: Sequence[bool | BaseException] = await asyncio.gather(*task_coroutines, return_exceptions=True)
task_results: Sequence[bool | BaseException] = await asyncio.gather(
*task_coroutines, return_exceptions=True
)
for backend_type, task_result in zip(backend_types, task_results, strict=False):
results[backend_type] = task_result if isinstance(task_result, bool) else False
@@ -312,7 +341,9 @@ class StorageManager:
storages: list[BaseStorage] = []
seen: set[StorageBackend] = set()
for backend in backends:
backend_enum = backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
backend_enum = (
backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
)
if backend_enum in seen:
continue
seen.add(backend_enum)
@@ -350,12 +381,18 @@ class StorageManager:
except StorageError as e:
# Storage-specific errors - log and use 0 count
import logging
logging.warning(f"Failed to get count for {collection_name} on {backend_type.value}: {e}")
logging.warning(
f"Failed to get count for {collection_name} on {backend_type.value}: {e}"
)
count = 0
except Exception as e:
# Unexpected errors - log and skip this collection from this backend
import logging
logging.warning(f"Unexpected error counting {collection_name} on {backend_type.value}: {e}")
logging.warning(
f"Unexpected error counting {collection_name} on {backend_type.value}: {e}"
)
continue
size_mb = count * 0.01 # Rough estimate: 10KB per document
@@ -446,7 +483,9 @@ class StorageManager:
storage = self.backends.get(StorageBackend.R2R)
return storage if isinstance(storage, R2RStorage) else None
async def get_backend_status(self) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
async def get_backend_status(
self,
) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
"""Get comprehensive status for all backends."""
status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}

View File

@@ -158,9 +158,7 @@ class ScrapeOptionsForm(Widget):
formats.append("screenshot")
options: dict[str, object] = {
"formats": formats,
"only_main_content": self.query_one(
"#only_main_content", Switch
).value,
"only_main_content": self.query_one("#only_main_content", Switch).value,
}
include_tags_input = self.query_one("#include_tags", Input).value
if include_tags_input.strip():
@@ -195,11 +193,15 @@ class ScrapeOptionsForm(Widget):
if include_tags := options.get("include_tags", []):
include_list = include_tags if isinstance(include_tags, list) else []
self.query_one("#include_tags", Input).value = ", ".join(str(tag) for tag in include_list)
self.query_one("#include_tags", Input).value = ", ".join(
str(tag) for tag in include_list
)
if exclude_tags := options.get("exclude_tags", []):
exclude_list = exclude_tags if isinstance(exclude_tags, list) else []
self.query_one("#exclude_tags", Input).value = ", ".join(str(tag) for tag in exclude_list)
self.query_one("#exclude_tags", Input).value = ", ".join(
str(tag) for tag in exclude_list
)
# Set performance
wait_for = options.get("wait_for")
@@ -503,7 +505,9 @@ class ExtractOptionsForm(Widget):
"content": "string",
"tags": ["string"]
}"""
prompt_widget.text = "Extract article title, author, publication date, main content, and associated tags"
prompt_widget.text = (
"Extract article title, author, publication date, main content, and associated tags"
)
elif event.button.id == "preset_product":
schema_widget.text = """{
@@ -513,7 +517,9 @@ class ExtractOptionsForm(Widget):
"category": "string",
"availability": "string"
}"""
prompt_widget.text = "Extract product name, price, description, category, and availability status"
prompt_widget.text = (
"Extract product name, price, description, category, and availability status"
)
elif event.button.id == "preset_contact":
schema_widget.text = """{
@@ -523,7 +529,9 @@ class ExtractOptionsForm(Widget):
"company": "string",
"position": "string"
}"""
prompt_widget.text = "Extract contact information including name, email, phone, company, and position"
prompt_widget.text = (
"Extract contact information including name, email, phone, company, and position"
)
elif event.button.id == "preset_data":
schema_widget.text = """{

View File

@@ -26,8 +26,12 @@ class StatusIndicator(Static):
status_lower = status.lower()
if (status_lower in {"active", "online", "connected", "✓ active"} or
status_lower.endswith("active") or "" in status_lower and "active" in status_lower):
if (
status_lower in {"active", "online", "connected", "✓ active"}
or status_lower.endswith("active")
or "" in status_lower
and "active" in status_lower
):
self.add_class("status-active")
self.add_class("glow")
self.update(f"🟢 {status}")

View File

@@ -99,10 +99,17 @@ class ChunkViewer(Widget):
"id": str(chunk_data.get("id", "")),
"document_id": self.document_id,
"content": str(chunk_data.get("text", "")),
"start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(chunk_data.get("start_index", 0)),
"end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(chunk_data.get("end_index", 0)),
"start_index": (lambda si: int(si) if isinstance(si, (int, str)) else 0)(
chunk_data.get("start_index", 0)
),
"end_index": (lambda ei: int(ei) if isinstance(ei, (int, str)) else 0)(
chunk_data.get("end_index", 0)
),
"metadata": (
dict(metadata_val) if (metadata_val := chunk_data.get("metadata")) and isinstance(metadata_val, dict) else {}
dict(metadata_val)
if (metadata_val := chunk_data.get("metadata"))
and isinstance(metadata_val, dict)
else {}
),
}
self.chunks.append(chunk_info)
@@ -275,10 +282,10 @@ class EntityGraph(Widget):
"""Show detailed information about an entity."""
details_widget = self.query_one("#entity_details", Static)
details_text = f"""**Entity:** {entity['name']}
**Type:** {entity['type']}
**Confidence:** {entity['confidence']:.2%}
**ID:** {entity['id']}
details_text = f"""**Entity:** {entity["name"]}
**Type:** {entity["type"]}
**Confidence:** {entity["confidence"]:.2%}
**ID:** {entity["id"]}
**Metadata:**
"""

View File

@@ -26,9 +26,15 @@ def _setup_prefect_settings() -> tuple[object, object, object]:
if registry is not None:
Setting = getattr(ps, "Setting", None)
if Setting is not None:
api_key = registry.get("PREFECT_API_KEY") or Setting("PREFECT_API_KEY", type_=str, default=None)
api_url = registry.get("PREFECT_API_URL") or Setting("PREFECT_API_URL", type_=str, default=None)
work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting("PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None)
api_key = registry.get("PREFECT_API_KEY") or Setting(
"PREFECT_API_KEY", type_=str, default=None
)
api_url = registry.get("PREFECT_API_URL") or Setting(
"PREFECT_API_URL", type_=str, default=None
)
work_pool = registry.get("PREFECT_DEFAULT_WORK_POOL_NAME") or Setting(
"PREFECT_DEFAULT_WORK_POOL_NAME", type_=str, default=None
)
return api_key, api_url, work_pool
except ImportError:
@@ -37,6 +43,7 @@ def _setup_prefect_settings() -> tuple[object, object, object]:
# Ultimate fallback
return None, None, None
PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEFAULT_WORK_POOL_NAME = _setup_prefect_settings()
# Import after Prefect settings setup to avoid circular dependencies
@@ -53,20 +60,30 @@ def configure_prefect(settings: Settings) -> None:
overrides: dict[Setting, str] = {}
if settings.prefect_api_url is not None and PREFECT_API_URL is not None and isinstance(PREFECT_API_URL, Setting):
if (
settings.prefect_api_url is not None
and PREFECT_API_URL is not None
and isinstance(PREFECT_API_URL, Setting)
):
overrides[PREFECT_API_URL] = str(settings.prefect_api_url)
if settings.prefect_api_key and PREFECT_API_KEY is not None and isinstance(PREFECT_API_KEY, Setting):
if (
settings.prefect_api_key
and PREFECT_API_KEY is not None
and isinstance(PREFECT_API_KEY, Setting)
):
overrides[PREFECT_API_KEY] = settings.prefect_api_key
if settings.prefect_work_pool and PREFECT_DEFAULT_WORK_POOL_NAME is not None and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting):
if (
settings.prefect_work_pool
and PREFECT_DEFAULT_WORK_POOL_NAME is not None
and isinstance(PREFECT_DEFAULT_WORK_POOL_NAME, Setting)
):
overrides[PREFECT_DEFAULT_WORK_POOL_NAME] = settings.prefect_work_pool
if not overrides:
return
filtered_overrides = {
setting: value
for setting, value in overrides.items()
if setting.value() != value
setting: value for setting, value in overrides.items() if setting.value() != value
}
if not filtered_overrides:

View File

@@ -139,10 +139,11 @@ class Settings(BaseSettings):
key_name, key_value = required_keys[backend]
if not key_value:
import warnings
warnings.warn(
f"{key_name} not set - authentication may fail for {backend} backend",
UserWarning,
stacklevel=2
stacklevel=2,
)
return self
@@ -165,16 +166,24 @@ class PrefectVariableConfig:
def __init__(self) -> None:
self._settings: Settings = get_settings()
self._variable_names: list[str] = [
"default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages",
"default_storage_backend", "default_collection_prefix", "max_concurrent_tasks",
"request_timeout", "default_schedule_interval"
"default_batch_size",
"max_file_size",
"max_crawl_depth",
"max_crawl_pages",
"default_storage_backend",
"default_collection_prefix",
"max_concurrent_tasks",
"request_timeout",
"default_schedule_interval",
]
def _get_fallback_value(self, name: str, default_value: object = None) -> object:
"""Get fallback value from settings or default."""
return default_value or getattr(self._settings, name, default_value)
def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
def get_with_fallback(
self, name: str, default_value: str | int | float | None = None
) -> str | int | float | None:
"""Get variable value with fallback synchronously."""
fallback = self._get_fallback_value(name, default_value)
# Ensure fallback is a type that Variable expects
@@ -195,7 +204,9 @@ class PrefectVariableConfig:
return fallback
return str(fallback) if fallback is not None else None
async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
async def get_with_fallback_async(
self, name: str, default_value: str | int | float | None = None
) -> str | int | float | None:
"""Get variable value with fallback asynchronously."""
fallback = self._get_fallback_value(name, default_value)
variable_fallback = str(fallback) if fallback is not None else None

View File

@@ -12,35 +12,36 @@ from ..config import get_settings
def _default_embedding_model() -> str:
return get_settings().embedding_model
return str(get_settings().embedding_model)
def _default_embedding_endpoint() -> HttpUrl:
return get_settings().llm_endpoint
endpoint = get_settings().llm_endpoint
return endpoint if isinstance(endpoint, HttpUrl) else HttpUrl(str(endpoint))
def _default_embedding_dimension() -> int:
return get_settings().embedding_dimension
return int(get_settings().embedding_dimension)
def _default_batch_size() -> int:
return get_settings().default_batch_size
return int(get_settings().default_batch_size)
def _default_collection_name() -> str:
return get_settings().default_collection_prefix
return str(get_settings().default_collection_prefix)
def _default_max_crawl_depth() -> int:
return get_settings().max_crawl_depth
return int(get_settings().max_crawl_depth)
def _default_max_crawl_pages() -> int:
return get_settings().max_crawl_pages
return int(get_settings().max_crawl_pages)
def _default_max_file_size() -> int:
return get_settings().max_file_size
return int(get_settings().max_file_size)
class IngestionStatus(str, Enum):
@@ -84,13 +85,16 @@ class StorageConfig(Block):
_block_type_name: ClassVar[str | None] = "Storage Configuration"
_block_type_slug: ClassVar[str | None] = "storage-config"
_description: ClassVar[str | None] = "Configures storage backend connections and settings for document ingestion"
_description: ClassVar[str | None] = (
"Configures storage backend connections and settings for document ingestion"
)
backend: StorageBackend
endpoint: HttpUrl
api_key: SecretStr | None = Field(default=None)
collection_name: str = Field(default_factory=_default_collection_name)
batch_size: Annotated[int, Field(gt=0, le=1000)] = Field(default_factory=_default_batch_size)
grpc_port: int | None = Field(default=None, description="gRPC port for Weaviate connections")
class FirecrawlConfig(Block):
@@ -112,7 +116,9 @@ class RepomixConfig(Block):
_block_type_name: ClassVar[str | None] = "Repomix Configuration"
_block_type_slug: ClassVar[str | None] = "repomix-config"
_description: ClassVar[str | None] = "Configures repository ingestion patterns and file processing settings"
_description: ClassVar[str | None] = (
"Configures repository ingestion patterns and file processing settings"
)
include_patterns: list[str] = Field(
default_factory=lambda: ["*.py", "*.js", "*.ts", "*.md", "*.yaml", "*.json"]
@@ -129,7 +135,9 @@ class R2RConfig(Block):
_block_type_name: ClassVar[str | None] = "R2R Configuration"
_block_type_slug: ClassVar[str | None] = "r2r-config"
_description: ClassVar[str | None] = "Configures R2R-specific ingestion settings including chunking and graph enrichment"
_description: ClassVar[str | None] = (
"Configures R2R-specific ingestion settings including chunking and graph enrichment"
)
chunk_size: Annotated[int, Field(ge=100, le=8192)] = 1000
chunk_overlap: Annotated[int, Field(ge=0, le=1000)] = 200
@@ -139,6 +147,7 @@ class R2RConfig(Block):
class DocumentMetadataRequired(TypedDict):
"""Required metadata fields for a document."""
source_url: str
timestamp: datetime
content_type: str

View File

@@ -103,8 +103,13 @@ async def initialize_storage_task(config: StorageConfig | str) -> BaseStorage:
return storage
@task(name="map_firecrawl_site", retries=2, retry_delay_seconds=15, tags=["firecrawl", "map"],
cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"))
@task(
name="map_firecrawl_site",
retries=2,
retry_delay_seconds=15,
tags=["firecrawl", "map"],
cache_key_fn=lambda ctx, p: _safe_cache_key("firecrawl_map", p, "source_url"),
)
async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str) -> list[str]:
"""Map a site using Firecrawl and return discovered URLs."""
# Load block if string provided
@@ -118,8 +123,13 @@ async def map_firecrawl_site_task(source_url: str, config: FirecrawlConfig | str
return mapped or [source_url]
@task(name="filter_existing_documents", retries=1, retry_delay_seconds=5, tags=["dedup"],
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls")) # Cache based on URL list
@task(
name="filter_existing_documents",
retries=1,
retry_delay_seconds=5,
tags=["dedup"],
cache_key_fn=lambda ctx, p: _safe_cache_key("filter_docs", p, "urls"),
) # Cache based on URL list
async def filter_existing_documents_task(
urls: list[str],
storage_client: BaseStorage,
@@ -128,19 +138,40 @@ async def filter_existing_documents_task(
collection_name: str | None = None,
) -> list[str]:
"""Filter URLs to only those that need scraping (missing or stale in storage)."""
import asyncio
logger = get_run_logger()
eligible: list[str] = []
for url in urls:
document_id = str(FirecrawlIngestor.compute_document_id(url))
exists = await storage_client.check_exists(
document_id,
collection_name=collection_name,
stale_after_days=stale_after_days
)
# Use semaphore to limit concurrent existence checks
semaphore = asyncio.Semaphore(20)
if not exists:
eligible.append(url)
async def check_url_exists(url: str) -> tuple[str, bool]:
async with semaphore:
try:
document_id = str(FirecrawlIngestor.compute_document_id(url))
exists = await storage_client.check_exists(
document_id, collection_name=collection_name, stale_after_days=stale_after_days
)
return url, exists
except Exception as e:
logger.warning("Error checking existence for URL %s: %s", url, e)
# Assume doesn't exist on error to ensure we scrape it
return url, False
# Check all URLs in parallel - use return_exceptions=True for partial failure handling
results = await asyncio.gather(*[check_url_exists(url) for url in urls], return_exceptions=True)
# Collect URLs that need scraping, handling any exceptions
eligible = []
for result in results:
if isinstance(result, Exception):
logger.error("Unexpected error in parallel existence check: %s", result)
continue
# Type narrowing: result is now known to be tuple[str, bool]
if isinstance(result, tuple) and len(result) == 2:
url, exists = result
if not exists:
eligible.append(url)
skipped = len(urls) - len(eligible)
if skipped > 0:
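The hunk above bounds the parallel existence checks with a semaphore before gathering them. A minimal, self-contained sketch of that pattern follows; `check` is a stand-in for the real storage lookup (not the project's API), and the limit of 20 simply mirrors the value used in the hunk.

```python
import asyncio


async def filter_missing(urls: list[str], limit: int = 20) -> list[str]:
    """Return URLs whose checks reported 'not found', with at most `limit` checks in flight."""
    semaphore = asyncio.Semaphore(limit)

    async def check(url: str) -> tuple[str, bool]:
        async with semaphore:  # only `limit` coroutines run this body concurrently
            await asyncio.sleep(0.01)  # stand-in for an async existence lookup
            return url, False  # pretend nothing is stored yet

    results = await asyncio.gather(*(check(u) for u in urls), return_exceptions=True)

    eligible: list[str] = []
    for result in results:
        # Failures are skipped here; the flow above treats them as "needs scraping".
        if isinstance(result, tuple):
            url, exists = result
            if not exists:
                eligible.append(url)
    return eligible


if __name__ == "__main__":
    print(asyncio.run(filter_missing([f"https://example.com/{i}" for i in range(5)])))
```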
@@ -260,7 +291,9 @@ async def ingest_documents_task(
if progress_callback:
progress_callback(40, "Starting document processing...")
return await _process_documents(ingestor, storage, job, batch_size, collection_name, progress_callback)
return await _process_documents(
ingestor, storage, job, batch_size, collection_name, progress_callback
)
async def _create_ingestor(job: IngestionJob, config_block_name: str | None = None) -> BaseIngestor:
@@ -287,7 +320,9 @@ async def _create_ingestor(job: IngestionJob, config_block_name: str | None = No
raise ValueError(f"Unsupported source: {job.source_type}")
async def _create_storage(job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None) -> BaseStorage:
async def _create_storage(
job: IngestionJob, collection_name: str | None, storage_block_name: str | None = None
) -> BaseStorage:
"""Create and initialize storage client."""
if collection_name is None:
# Use variable for default collection prefix
@@ -303,6 +338,7 @@ async def _create_storage(job: IngestionJob, collection_name: str | None, storag
else:
# Fallback to building config from settings
from ..config import get_settings
settings = get_settings()
storage_config = _build_storage_config(job, settings, collection_name)
@@ -391,7 +427,7 @@ async def _process_documents(
progress_callback(45, "Ingesting documents from source...")
# Use smart ingestion with deduplication if storage supports it
if hasattr(storage, 'check_exists'):
if hasattr(storage, "check_exists"):
try:
# Try to use the smart ingestion method
document_generator = ingestor.ingest_with_dedup(
@@ -412,7 +448,7 @@ async def _process_documents(
if progress_callback:
progress_callback(
45 + min(35, (batch_count * 10)),
f"Processing batch {batch_count} ({total_documents} documents so far)..."
f"Processing batch {batch_count} ({total_documents} documents so far)...",
)
batch_processed, batch_failed = await _store_batch(storage, batch, collection_name)
@@ -482,7 +518,9 @@ async def _store_batch(
log_prints=True,
)
async def firecrawl_to_r2r_flow(
job: IngestionJob, collection_name: str | None = None, progress_callback: Callable[[int, str], None] | None = None
job: IngestionJob,
collection_name: str | None = None,
progress_callback: Callable[[int, str], None] | None = None,
) -> tuple[int, int]:
"""Specialized flow for Firecrawl ingestion into R2R."""
logger = get_run_logger()
@@ -549,10 +587,8 @@ async def firecrawl_to_r2r_flow(
# Use asyncio.gather for concurrent scraping
import asyncio
scrape_tasks = [
scrape_firecrawl_batch_task(batch, firecrawl_config)
for batch in url_batches
]
scrape_tasks = [scrape_firecrawl_batch_task(batch, firecrawl_config) for batch in url_batches]
batch_results = await asyncio.gather(*scrape_tasks)
scraped_pages: list[FirecrawlPage] = []
@@ -680,9 +716,13 @@ async def create_ingestion_flow(
progress_callback(30, "Starting document ingestion...")
print("Ingesting documents...")
if job.source_type == IngestionSource.WEB and job.storage_backend == StorageBackend.R2R:
processed, failed = await firecrawl_to_r2r_flow(job, collection_name, progress_callback=progress_callback)
processed, failed = await firecrawl_to_r2r_flow(
job, collection_name, progress_callback=progress_callback
)
else:
processed, failed = await ingest_documents_task(job, collection_name, progress_callback=progress_callback)
processed, failed = await ingest_documents_task(
job, collection_name, progress_callback=progress_callback
)
if progress_callback:
progress_callback(90, "Finalizing ingestion...")

View File

@@ -191,9 +191,10 @@ class FirecrawlIngestor(BaseIngestor):
"http://localhost"
):
# Self-hosted instance - try with api_url if supported
self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(
api_key=api_key, api_url=str(settings.firecrawl_endpoint)
))
self.client = cast(
AsyncFirecrawlClient,
AsyncFirecrawl(api_key=api_key, api_url=str(settings.firecrawl_endpoint)),
)
else:
# Cloud instance - use standard initialization
self.client = cast(AsyncFirecrawlClient, AsyncFirecrawl(api_key=api_key))
@@ -256,9 +257,7 @@ class FirecrawlIngestor(BaseIngestor):
for check_url in site_map:
document_id = str(self.compute_document_id(check_url))
exists = await storage_client.check_exists(
document_id,
collection_name=collection_name,
stale_after_days=stale_after_days
document_id, collection_name=collection_name, stale_after_days=stale_after_days
)
if not exists:
eligible_urls.append(check_url)
@@ -394,7 +393,9 @@ class FirecrawlIngestor(BaseIngestor):
# Extract basic metadata
title: str | None = getattr(metadata, "title", None) if metadata else None
description: str | None = getattr(metadata, "description", None) if metadata else None
description: str | None = (
getattr(metadata, "description", None) if metadata else None
)
# Extract enhanced metadata if available
author: str | None = getattr(metadata, "author", None) if metadata else None
@@ -402,26 +403,38 @@ class FirecrawlIngestor(BaseIngestor):
sitemap_last_modified: str | None = (
getattr(metadata, "sitemap_last_modified", None) if metadata else None
)
source_url: str | None = getattr(metadata, "sourceURL", None) if metadata else None
keywords: str | list[str] | None = getattr(metadata, "keywords", None) if metadata else None
source_url: str | None = (
getattr(metadata, "sourceURL", None) if metadata else None
)
keywords: str | list[str] | None = (
getattr(metadata, "keywords", None) if metadata else None
)
robots: str | None = getattr(metadata, "robots", None) if metadata else None
# Open Graph metadata
og_title: str | None = getattr(metadata, "ogTitle", None) if metadata else None
og_description: str | None = getattr(metadata, "ogDescription", None) if metadata else None
og_description: str | None = (
getattr(metadata, "ogDescription", None) if metadata else None
)
og_url: str | None = getattr(metadata, "ogUrl", None) if metadata else None
og_image: str | None = getattr(metadata, "ogImage", None) if metadata else None
# Twitter metadata
twitter_card: str | None = getattr(metadata, "twitterCard", None) if metadata else None
twitter_site: str | None = getattr(metadata, "twitterSite", None) if metadata else None
twitter_card: str | None = (
getattr(metadata, "twitterCard", None) if metadata else None
)
twitter_site: str | None = (
getattr(metadata, "twitterSite", None) if metadata else None
)
twitter_creator: str | None = (
getattr(metadata, "twitterCreator", None) if metadata else None
)
# Additional metadata
favicon: str | None = getattr(metadata, "favicon", None) if metadata else None
status_code: int | None = getattr(metadata, "statusCode", None) if metadata else None
status_code: int | None = (
getattr(metadata, "statusCode", None) if metadata else None
)
return FirecrawlPage(
url=url,
@@ -594,10 +607,16 @@ class FirecrawlIngestor(BaseIngestor):
"site_name": domain_info["site_name"],
# Document structure
"heading_hierarchy": (
list(hierarchy_val) if (hierarchy_val := structure_info.get("heading_hierarchy")) and isinstance(hierarchy_val, (list, tuple)) else []
list(hierarchy_val)
if (hierarchy_val := structure_info.get("heading_hierarchy"))
and isinstance(hierarchy_val, (list, tuple))
else []
),
"section_depth": (
int(depth_val) if (depth_val := structure_info.get("section_depth")) and isinstance(depth_val, (int, str)) else 0
int(depth_val)
if (depth_val := structure_info.get("section_depth"))
and isinstance(depth_val, (int, str))
else 0
),
"has_code_blocks": bool(structure_info.get("has_code_blocks", False)),
"has_images": bool(structure_info.get("has_images", False)),
@@ -632,7 +651,11 @@ class FirecrawlIngestor(BaseIngestor):
await self.client.close()
except Exception as e:
logging.debug(f"Error closing Firecrawl client: {e}")
elif hasattr(self.client, "_session") and self.client._session and hasattr(self.client._session, "close"):
elif (
hasattr(self.client, "_session")
and self.client._session
and hasattr(self.client._session, "close")
):
try:
await self.client._session.close()
except Exception as e:

View File

@@ -80,6 +80,7 @@ class RepomixIngestor(BaseIngestor):
return result.returncode == 0
except Exception as e:
import logging
logging.warning(f"Failed to validate repository {source_url}: {e}")
return False
@@ -97,9 +98,7 @@ class RepomixIngestor(BaseIngestor):
try:
with tempfile.TemporaryDirectory() as temp_dir:
# Shallow clone to get file count
repo_path = await self._clone_repository(
source_url, temp_dir, shallow=True
)
repo_path = await self._clone_repository(source_url, temp_dir, shallow=True)
# Count files matching patterns
file_count = 0
@@ -111,6 +110,7 @@ class RepomixIngestor(BaseIngestor):
return file_count
except Exception as e:
import logging
logging.warning(f"Failed to estimate size for repository {source_url}: {e}")
return 0
@@ -179,9 +179,7 @@ class RepomixIngestor(BaseIngestor):
return output_file
async def _parse_repomix_output(
self, output_file: Path, job: IngestionJob
) -> list[Document]:
async def _parse_repomix_output(self, output_file: Path, job: IngestionJob) -> list[Document]:
"""
Parse repomix output into documents.
@@ -210,14 +208,17 @@ class RepomixIngestor(BaseIngestor):
chunks = self._chunk_content(file_content)
for i, chunk in enumerate(chunks):
doc = self._create_document(
file_path, chunk, job, chunk_index=i,
git_metadata=git_metadata, repo_info=repo_info
file_path,
chunk,
job,
chunk_index=i,
git_metadata=git_metadata,
repo_info=repo_info,
)
documents.append(doc)
else:
doc = self._create_document(
file_path, file_content, job,
git_metadata=git_metadata, repo_info=repo_info
file_path, file_content, job, git_metadata=git_metadata, repo_info=repo_info
)
documents.append(doc)
@@ -300,64 +301,64 @@ class RepomixIngestor(BaseIngestor):
# Map common extensions to languages
ext_map = {
'.py': 'python',
'.js': 'javascript',
'.ts': 'typescript',
'.jsx': 'javascript',
'.tsx': 'typescript',
'.java': 'java',
'.c': 'c',
'.cpp': 'cpp',
'.cc': 'cpp',
'.cxx': 'cpp',
'.h': 'c',
'.hpp': 'cpp',
'.cs': 'csharp',
'.go': 'go',
'.rs': 'rust',
'.php': 'php',
'.rb': 'ruby',
'.swift': 'swift',
'.kt': 'kotlin',
'.scala': 'scala',
'.sh': 'shell',
'.bash': 'shell',
'.zsh': 'shell',
'.sql': 'sql',
'.html': 'html',
'.css': 'css',
'.scss': 'scss',
'.less': 'less',
'.yaml': 'yaml',
'.yml': 'yaml',
'.json': 'json',
'.xml': 'xml',
'.md': 'markdown',
'.txt': 'text',
'.cfg': 'config',
'.ini': 'config',
'.toml': 'toml',
".py": "python",
".js": "javascript",
".ts": "typescript",
".jsx": "javascript",
".tsx": "typescript",
".java": "java",
".c": "c",
".cpp": "cpp",
".cc": "cpp",
".cxx": "cpp",
".h": "c",
".hpp": "cpp",
".cs": "csharp",
".go": "go",
".rs": "rust",
".php": "php",
".rb": "ruby",
".swift": "swift",
".kt": "kotlin",
".scala": "scala",
".sh": "shell",
".bash": "shell",
".zsh": "shell",
".sql": "sql",
".html": "html",
".css": "css",
".scss": "scss",
".less": "less",
".yaml": "yaml",
".yml": "yaml",
".json": "json",
".xml": "xml",
".md": "markdown",
".txt": "text",
".cfg": "config",
".ini": "config",
".toml": "toml",
}
if extension in ext_map:
return ext_map[extension]
# Try to detect from shebang
if content.startswith('#!'):
first_line = content.split('\n')[0]
if 'python' in first_line:
return 'python'
elif 'node' in first_line or 'javascript' in first_line:
return 'javascript'
elif 'bash' in first_line or 'sh' in first_line:
return 'shell'
if content.startswith("#!"):
first_line = content.split("\n")[0]
if "python" in first_line:
return "python"
elif "node" in first_line or "javascript" in first_line:
return "javascript"
elif "bash" in first_line or "sh" in first_line:
return "shell"
return None
@staticmethod
def _analyze_code_structure(content: str, language: str | None) -> dict[str, object]:
"""Analyze code structure and extract metadata."""
lines = content.split('\n')
lines = content.split("\n")
# Basic metrics
has_functions = False
@@ -366,29 +367,35 @@ class RepomixIngestor(BaseIngestor):
has_comments = False
# Language-specific patterns
if language == 'python':
has_functions = bool(re.search(r'^\s*def\s+\w+', content, re.MULTILINE))
has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE))
has_imports = bool(re.search(r'^\s*(import|from)\s+', content, re.MULTILINE))
has_comments = bool(re.search(r'^\s*#', content, re.MULTILINE))
elif language in ['javascript', 'typescript']:
has_functions = bool(re.search(r'(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)', content, re.MULTILINE))
has_classes = bool(re.search(r'^\s*class\s+\w+', content, re.MULTILINE))
has_imports = bool(re.search(r'^\s*(import|require)', content, re.MULTILINE))
has_comments = bool(re.search(r'//|/\*', content))
elif language == 'java':
has_functions = bool(re.search(r'(public|private|protected).*\w+\s*\(', content, re.MULTILINE))
has_classes = bool(re.search(r'(public|private)?\s*class\s+\w+', content, re.MULTILINE))
has_imports = bool(re.search(r'^\s*import\s+', content, re.MULTILINE))
has_comments = bool(re.search(r'//|/\*', content))
if language == "python":
has_functions = bool(re.search(r"^\s*def\s+\w+", content, re.MULTILINE))
has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE))
has_imports = bool(re.search(r"^\s*(import|from)\s+", content, re.MULTILINE))
has_comments = bool(re.search(r"^\s*#", content, re.MULTILINE))
elif language in ["javascript", "typescript"]:
has_functions = bool(
re.search(
r"(function\s+\w+|^\s*\w+\s*:\s*function|\w+\s*=>\s*)", content, re.MULTILINE
)
)
has_classes = bool(re.search(r"^\s*class\s+\w+", content, re.MULTILINE))
has_imports = bool(re.search(r"^\s*(import|require)", content, re.MULTILINE))
has_comments = bool(re.search(r"//|/\*", content))
elif language == "java":
has_functions = bool(
re.search(r"(public|private|protected).*\w+\s*\(", content, re.MULTILINE)
)
has_classes = bool(re.search(r"(public|private)?\s*class\s+\w+", content, re.MULTILINE))
has_imports = bool(re.search(r"^\s*import\s+", content, re.MULTILINE))
has_comments = bool(re.search(r"//|/\*", content))
return {
'has_functions': has_functions,
'has_classes': has_classes,
'has_imports': has_imports,
'has_comments': has_comments,
'line_count': len(lines),
'non_empty_lines': len([line for line in lines if line.strip()]),
"has_functions": has_functions,
"has_classes": has_classes,
"has_imports": has_imports,
"has_comments": has_comments,
"line_count": len(lines),
"non_empty_lines": len([line for line in lines if line.strip()]),
}
@staticmethod
@@ -399,18 +406,16 @@ class RepomixIngestor(BaseIngestor):
org_name = None
# Handle different URL formats
if 'github.com' in repo_url or 'gitlab.com' in repo_url:
if path_match := re.search(
r'/([^/]+)/([^/]+?)(?:\.git)?/?$', repo_url
):
if "github.com" in repo_url or "gitlab.com" in repo_url:
if path_match := re.search(r"/([^/]+)/([^/]+?)(?:\.git)?/?$", repo_url):
org_name = path_match[1]
repo_name = path_match[2]
elif path_match := re.search(r'/([^/]+?)(?:\.git)?/?$', repo_url):
elif path_match := re.search(r"/([^/]+?)(?:\.git)?/?$", repo_url):
repo_name = path_match[1]
return {
'repository_name': repo_name or 'unknown',
'organization': org_name or 'unknown',
"repository_name": repo_name or "unknown",
"organization": org_name or "unknown",
}
async def _get_git_metadata(self, repo_path: Path) -> dict[str, str | None]:
@@ -418,31 +423,35 @@ class RepomixIngestor(BaseIngestor):
try:
# Get current branch
branch_result = await self._run_command(
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
cwd=str(repo_path),
timeout=5
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=str(repo_path), timeout=5
)
branch_name = (
branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None
)
branch_name = branch_result.stdout.decode().strip() if branch_result.returncode == 0 else None
# Get current commit hash
commit_result = await self._run_command(
['git', 'rev-parse', 'HEAD'],
cwd=str(repo_path),
timeout=5
["git", "rev-parse", "HEAD"], cwd=str(repo_path), timeout=5
)
commit_hash = (
commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None
)
commit_hash = commit_result.stdout.decode().strip() if commit_result.returncode == 0 else None
return {
'branch_name': branch_name,
'commit_hash': commit_hash[:8] if commit_hash else None, # Short hash
"branch_name": branch_name,
"commit_hash": commit_hash[:8] if commit_hash else None, # Short hash
}
except Exception:
return {'branch_name': None, 'commit_hash': None}
return {"branch_name": None, "commit_hash": None}
def _create_document(
self, file_path: str, content: str, job: IngestionJob, chunk_index: int = 0,
self,
file_path: str,
content: str,
job: IngestionJob,
chunk_index: int = 0,
git_metadata: dict[str, str | None] | None = None,
repo_info: dict[str, str] | None = None
repo_info: dict[str, str] | None = None,
) -> Document:
"""
Create a Document from repository content with enriched metadata.
@@ -466,19 +475,19 @@ class RepomixIngestor(BaseIngestor):
# Determine content type based on language
content_type_map = {
'python': 'text/x-python',
'javascript': 'text/javascript',
'typescript': 'text/typescript',
'java': 'text/x-java-source',
'html': 'text/html',
'css': 'text/css',
'json': 'application/json',
'yaml': 'text/yaml',
'markdown': 'text/markdown',
'shell': 'text/x-shellscript',
'sql': 'text/x-sql',
"python": "text/x-python",
"javascript": "text/javascript",
"typescript": "text/typescript",
"java": "text/x-java-source",
"html": "text/html",
"css": "text/css",
"json": "application/json",
"yaml": "text/yaml",
"markdown": "text/markdown",
"shell": "text/x-shellscript",
"sql": "text/x-sql",
}
content_type = content_type_map.get(programming_language or 'text', 'text/plain')
content_type = content_type_map.get(programming_language or "text", "text/plain")
# Build rich metadata
metadata: DocumentMetadata = {
@@ -487,8 +496,7 @@ class RepomixIngestor(BaseIngestor):
"content_type": content_type,
"word_count": len(content.split()),
"char_count": len(content),
"title": f"{file_path}"
+ (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
"title": f"{file_path}" + (f" (chunk {chunk_index})" if chunk_index > 0 else ""),
"description": f"Repository file: {file_path}",
"category": "source_code" if programming_language else "documentation",
"language": programming_language or "text",
@@ -500,19 +508,23 @@ class RepomixIngestor(BaseIngestor):
# Add repository info if available
if repo_info:
metadata["repository_name"] = repo_info.get('repository_name')
metadata["repository_name"] = repo_info.get("repository_name")
# Add git metadata if available
if git_metadata:
metadata["branch_name"] = git_metadata.get('branch_name')
metadata["commit_hash"] = git_metadata.get('commit_hash')
metadata["branch_name"] = git_metadata.get("branch_name")
metadata["commit_hash"] = git_metadata.get("commit_hash")
# Add code-specific metadata for programming files
if programming_language and structure_info:
# Calculate code quality score
total_lines = structure_info.get('line_count', 1)
non_empty_lines = structure_info.get('non_empty_lines', 0)
if isinstance(total_lines, int) and isinstance(non_empty_lines, int) and total_lines > 0:
total_lines = structure_info.get("line_count", 1)
non_empty_lines = structure_info.get("non_empty_lines", 0)
if (
isinstance(total_lines, int)
and isinstance(non_empty_lines, int)
and total_lines > 0
):
completeness_score = (non_empty_lines / total_lines) * 100
metadata["completeness_score"] = completeness_score
@@ -566,5 +578,6 @@ class RepomixIngestor(BaseIngestor):
await asyncio.wait_for(proc.wait(), timeout=2.0)
except TimeoutError:
import logging
logging.warning(f"Process {proc.pid} did not terminate cleanly")
raise IngestionError(f"Command timed out: {' '.join(cmd)}") from e

View File

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
try:
from .r2r.storage import R2RStorage as _RuntimeR2RStorage
R2RStorage: type[BaseStorage] | None = _RuntimeR2RStorage
except ImportError:
R2RStorage = None

View File

@@ -1,10 +1,12 @@
"""Base storage interface."""
import asyncio
import logging
import random
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Final
from types import TracebackType
from typing import Final
import httpx
from pydantic import SecretStr
@@ -37,6 +39,8 @@ class TypedHttpClient:
api_key: SecretStr | None = None,
timeout: float = 30.0,
headers: dict[str, str] | None = None,
max_connections: int = 100,
max_keepalive_connections: int = 20,
):
"""
Initialize the typed HTTP client.
@@ -46,6 +50,8 @@ class TypedHttpClient:
api_key: Optional API key for authentication
timeout: Request timeout in seconds
headers: Additional headers to include with requests
max_connections: Maximum total connections in pool
max_keepalive_connections: Maximum keepalive connections
"""
self._base_url = base_url
@@ -54,14 +60,14 @@ class TypedHttpClient:
if api_key:
client_headers["Authorization"] = f"Bearer {api_key.get_secret_value()}"
# Note: Pylance incorrectly reports "No parameter named 'base_url'"
# but base_url is a valid AsyncClient parameter (see HTTPX docs)
client_kwargs: dict[str, str | dict[str, str] | float] = {
"base_url": base_url,
"headers": client_headers,
"timeout": timeout,
}
self.client = httpx.AsyncClient(**client_kwargs) # type: ignore
# Create typed client configuration with connection pooling
limits = httpx.Limits(
max_connections=max_connections, max_keepalive_connections=max_keepalive_connections
)
timeout_config = httpx.Timeout(connect=5.0, read=timeout, write=30.0, pool=10.0)
self.client = httpx.AsyncClient(
base_url=base_url, headers=client_headers, timeout=timeout_config, limits=limits
)
async def request(
self,
@@ -73,44 +79,92 @@ class TypedHttpClient:
data: dict[str, object] | None = None,
files: dict[str, tuple[str, bytes, str]] | None = None,
params: dict[str, str | bool] | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
) -> httpx.Response | None:
"""
Perform an HTTP request with consistent error handling.
Perform an HTTP request with consistent error handling and retries.
Args:
method: HTTP method (GET, POST, DELETE, etc.)
path: URL path relative to base_url
allow_404: If True, return None for 404 responses instead of raising
**kwargs: Arguments passed to httpx request
json: JSON data to send
data: Form data to send
files: Files to upload
params: Query parameters
max_retries: Maximum number of retry attempts
retry_delay: Base delay between retries in seconds
Returns:
HTTP response object, or None if allow_404=True and status is 404
Raises:
StorageError: If request fails
StorageError: If request fails after retries
"""
try:
response = await self.client.request( # type: ignore
method, path, json=json, data=data, files=files, params=params
)
response.raise_for_status() # type: ignore
return response # type: ignore
except Exception as e:
# Handle 404 as special case if requested
if allow_404 and hasattr(e, 'response') and getattr(e.response, 'status_code', None) == 404: # type: ignore
LOGGER.debug("Resource not found (404): %s %s", method, path)
return None
last_exception: Exception | None = None
# Convert all HTTP-related exceptions to StorageError
error_name = e.__class__.__name__
if 'HTTP' in error_name or 'Connect' in error_name or 'Request' in error_name:
if hasattr(e, 'response') and hasattr(e.response, 'status_code'): # type: ignore
status_code = getattr(e.response, 'status_code', 'unknown') # type: ignore
raise StorageError(f"HTTP {status_code} error from {self._base_url}: {e}") from e
else:
raise StorageError(f"Request failed to {self._base_url}: {e}") from e
# Re-raise non-HTTP exceptions
raise
for attempt in range(max_retries + 1):
try:
response = await self.client.request(
method, path, json=json, data=data, files=files, params=params
)
response.raise_for_status()
return response
except httpx.HTTPStatusError as e:
# Handle 404 as special case if requested
if allow_404 and e.response.status_code == 404:
LOGGER.debug("Resource not found (404): %s %s", method, path)
return None
# Don't retry client errors (4xx except for specific cases)
if 400 <= e.response.status_code < 500 and e.response.status_code not in [429, 408]:
raise StorageError(
f"HTTP {e.response.status_code} error from {self._base_url}: {e}"
) from e
last_exception = e
if attempt < max_retries:
# Exponential backoff with jitter for retryable errors
delay = retry_delay * (2**attempt) + random.uniform(0, 1)
LOGGER.warning(
"HTTP %d error on attempt %d/%d, retrying in %.2fs: %s",
e.response.status_code,
attempt + 1,
max_retries + 1,
delay,
e,
)
await asyncio.sleep(delay)
except httpx.HTTPError as e:
last_exception = e
if attempt < max_retries:
# Retry transport errors with backoff
delay = retry_delay * (2**attempt) + random.uniform(0, 1)
LOGGER.warning(
"HTTP transport error on attempt %d/%d, retrying in %.2fs: %s",
attempt + 1,
max_retries + 1,
delay,
e,
)
await asyncio.sleep(delay)
# All retries exhausted - last_exception should always be set if we reach here
if last_exception is None:
raise StorageError(
f"Request to {self._base_url} failed after {max_retries + 1} attempts with unknown error"
)
if isinstance(last_exception, httpx.HTTPStatusError):
raise StorageError(
f"HTTP {last_exception.response.status_code} error from {self._base_url} after {max_retries + 1} attempts: {last_exception}"
) from last_exception
else:
raise StorageError(
f"HTTP transport error to {self._base_url} after {max_retries + 1} attempts: {last_exception}"
) from last_exception
async def close(self) -> None:
"""Close the HTTP client and cleanup resources."""
@@ -127,7 +181,7 @@ class TypedHttpClient:
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
exc_tb: TracebackType | None,
) -> None:
"""Async context manager exit."""
await self.close()
@@ -224,11 +278,30 @@ class BaseStorage(ABC):
# Check staleness if timestamp is available
if "timestamp" in document.metadata:
from datetime import UTC, datetime, timedelta
timestamp_obj = document.metadata["timestamp"]
# Handle both datetime objects and ISO strings
if isinstance(timestamp_obj, datetime):
timestamp = timestamp_obj
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
return timestamp >= cutoff
# Ensure timezone awareness
if timestamp.tzinfo is None:
timestamp = timestamp.replace(tzinfo=UTC)
elif isinstance(timestamp_obj, str):
try:
timestamp = datetime.fromisoformat(timestamp_obj)
# Ensure timezone awareness
if timestamp.tzinfo is None:
timestamp = timestamp.replace(tzinfo=UTC)
except ValueError:
# If parsing fails, assume document is stale
return False
else:
# Unknown timestamp format, assume stale
return False
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
return timestamp >= cutoff
# If no timestamp, assume it exists and is valid
return True
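The staleness check above accepts either a `datetime` or an ISO-8601 string and normalizes both to timezone-aware values before comparing against the cutoff. A compact sketch of the same decision, assuming the 30-day default:

```python
from datetime import UTC, datetime, timedelta


def is_fresh(timestamp_obj: object, stale_after_days: int = 30) -> bool:
    """True if the timestamp is recent enough; unparseable or unknown values count as stale."""
    if isinstance(timestamp_obj, datetime):
        timestamp = timestamp_obj
    elif isinstance(timestamp_obj, str):
        try:
            timestamp = datetime.fromisoformat(timestamp_obj)
        except ValueError:
            return False
    else:
        return False
    if timestamp.tzinfo is None:  # treat naive values as UTC
        timestamp = timestamp.replace(tzinfo=UTC)
    return timestamp >= datetime.now(UTC) - timedelta(days=stale_after_days)


print(is_fresh(datetime.now(UTC)))            # True
print(is_fresh("2001-01-01T00:00:00+00:00"))  # False
```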

View File

@@ -1,10 +1,10 @@
"""Open WebUI storage adapter."""
import asyncio
import contextlib
import logging
from typing import Final, TypedDict, cast
import time
from typing import Final, NamedTuple, TypedDict
from typing_extensions import override
@@ -18,6 +18,7 @@ LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class OpenWebUIFileResponse(TypedDict, total=False):
"""OpenWebUI API file response structure."""
id: str
filename: str
name: str
@@ -29,6 +30,7 @@ class OpenWebUIFileResponse(TypedDict, total=False):
class OpenWebUIKnowledgeBase(TypedDict, total=False):
"""OpenWebUI knowledge base response structure."""
id: str
name: str
description: str
@@ -38,13 +40,19 @@ class OpenWebUIKnowledgeBase(TypedDict, total=False):
updated_at: str
class CacheEntry(NamedTuple):
"""Cache entry with value and expiration time."""
value: str
expires_at: float
class OpenWebUIStorage(BaseStorage):
"""Storage adapter for Open WebUI knowledge endpoints."""
http_client: TypedHttpClient
_knowledge_cache: dict[str, str]
_knowledge_cache: dict[str, CacheEntry]
_cache_ttl: float
def __init__(self, config: StorageConfig):
"""
@@ -61,6 +69,7 @@ class OpenWebUIStorage(BaseStorage):
timeout=30.0,
)
self._knowledge_cache = {}
self._cache_ttl = 300.0 # 5 minutes TTL
@override
async def initialize(self) -> None:
@@ -106,9 +115,25 @@ class OpenWebUIStorage(BaseStorage):
return []
normalized: list[OpenWebUIKnowledgeBase] = []
for item in data:
if isinstance(item, dict):
# Cast to our expected structure
kb_item = cast(OpenWebUIKnowledgeBase, item)
if (
isinstance(item, dict)
and "id" in item
and "name" in item
and isinstance(item["id"], str)
and isinstance(item["name"], str)
):
# Create a new dict with known structure
kb_item: OpenWebUIKnowledgeBase = {
"id": item["id"],
"name": item["name"],
"description": item.get("description", ""),
"created_at": item.get("created_at", ""),
"updated_at": item.get("updated_at", ""),
}
if "files" in item and isinstance(item["files"], list):
kb_item["files"] = item["files"]
if "data" in item and isinstance(item["data"], dict):
kb_item["data"] = item["data"]
normalized.append(kb_item)
return normalized
@@ -124,22 +149,29 @@ class OpenWebUIStorage(BaseStorage):
if not target:
raise StorageError("Knowledge base name is required")
if cached := self._knowledge_cache.get(target):
return cached
# Check cache with TTL
if cached_entry := self._knowledge_cache.get(target):
if time.time() < cached_entry.expires_at:
return cached_entry.value
else:
# Entry expired, remove it
del self._knowledge_cache[target]
knowledge_bases = await self._fetch_knowledge_bases()
for kb in knowledge_bases:
if kb.get("name") == target:
kb_id = kb.get("id")
if isinstance(kb_id, str):
self._knowledge_cache[target] = kb_id
expires_at = time.time() + self._cache_ttl
self._knowledge_cache[target] = CacheEntry(kb_id, expires_at)
return kb_id
if not create:
return None
knowledge_id = await self._create_collection(target)
self._knowledge_cache[target] = knowledge_id
expires_at = time.time() + self._cache_ttl
self._knowledge_cache[target] = CacheEntry(knowledge_id, expires_at)
return knowledge_id
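The cache above pairs each value with an absolute expiry time. A minimal sketch of that TTL pattern on its own, independent of the OpenWebUI client; the 300-second TTL mirrors the value set in the constructor hunk.

```python
import time
from typing import NamedTuple


class CacheEntry(NamedTuple):
    value: str
    expires_at: float


class TTLCache:
    """Tiny name -> id cache where entries silently expire after `ttl` seconds."""

    def __init__(self, ttl: float = 300.0) -> None:
        self._ttl = ttl
        self._entries: dict[str, CacheEntry] = {}

    def get(self, key: str) -> str | None:
        entry = self._entries.get(key)
        if entry is None:
            return None
        if time.time() >= entry.expires_at:  # expired: drop the entry and report a miss
            del self._entries[key]
            return None
        return entry.value

    def put(self, key: str, value: str) -> None:
        self._entries[key] = CacheEntry(value, time.time() + self._ttl)


cache = TTLCache(ttl=1.0)
cache.put("default", "kb-123")
print(cache.get("default"))  # "kb-123" while fresh, None after the TTL elapses
```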
@override
@@ -165,7 +197,7 @@ class OpenWebUIStorage(BaseStorage):
# Use document title from metadata if available, otherwise fall back to ID
filename = document.metadata.get("title") or f"doc_{document.id}"
# Ensure filename has proper extension
if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
filename = f"{filename}.txt"
files = {"file": (filename, document.content.encode(), "text/plain")}
response = await self.http_client.request(
@@ -185,11 +217,9 @@ class OpenWebUIStorage(BaseStorage):
# Step 2: Add file to knowledge base
response = await self.http_client.request(
"POST",
f"/api/v1/knowledge/{knowledge_id}/file/add",
json={"file_id": file_id}
"POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
)
return str(file_id)
except Exception as e:
@@ -220,7 +250,7 @@ class OpenWebUIStorage(BaseStorage):
# Use document title from metadata if available, otherwise fall back to ID
filename = doc.metadata.get("title") or f"doc_{doc.id}"
# Ensure filename has proper extension
if not filename.endswith(('.txt', '.md', '.pdf', '.doc', '.docx')):
if not filename.endswith((".txt", ".md", ".pdf", ".doc", ".docx")):
filename = f"{filename}.txt"
files = {"file": (filename, doc.content.encode(), "text/plain")}
upload_response = await self.http_client.request(
@@ -230,7 +260,9 @@ class OpenWebUIStorage(BaseStorage):
params={"process": True, "process_in_background": False},
)
if upload_response is None:
raise StorageError(f"Unexpected None response from file upload for document {doc.id}")
raise StorageError(
f"Unexpected None response from file upload for document {doc.id}"
)
file_data = upload_response.json()
file_id = file_data.get("id")
@@ -241,9 +273,7 @@ class OpenWebUIStorage(BaseStorage):
)
await self.http_client.request(
"POST",
f"/api/v1/knowledge/{knowledge_id}/file/add",
json={"file_id": file_id}
"POST", f"/api/v1/knowledge/{knowledge_id}/file/add", json={"file_id": file_id}
)
return str(file_id)
@@ -259,7 +289,8 @@ class OpenWebUIStorage(BaseStorage):
if isinstance(result, Exception):
failures.append(f"{doc.id}: {result}")
else:
file_ids.append(cast(str, result))
if isinstance(result, str):
file_ids.append(result)
if failures:
LOGGER.warning(
@@ -289,10 +320,86 @@ class OpenWebUIStorage(BaseStorage):
"""
_ = document_id, collection_name # Mark as used
# OpenWebUI uses file-based storage without direct document retrieval
# This will cause the base check_exists method to return False,
# which means documents will always be re-scraped for OpenWebUI
raise NotImplementedError("OpenWebUI doesn't support document retrieval by ID")
@override
async def check_exists(
self, document_id: str, *, collection_name: str | None = None, stale_after_days: int = 30
) -> bool:
"""
Check if a document exists in OpenWebUI knowledge base by searching files.
Args:
document_id: Document ID to check (usually based on source URL)
collection_name: Knowledge base name
stale_after_days: Consider document stale after this many days
Returns:
True if document exists and is not stale, False otherwise
"""
try:
from datetime import UTC, datetime, timedelta
# Get knowledge base
knowledge_id = await self._get_knowledge_id(collection_name, create=False)
if not knowledge_id:
return False
# Get detailed knowledge base info to access files
response = await self.http_client.request("GET", f"/api/v1/knowledge/{knowledge_id}")
if response is None:
return False
kb_data = response.json()
files = kb_data.get("files", [])
# Look for file with matching document ID or source URL in metadata
cutoff = datetime.now(UTC) - timedelta(days=stale_after_days)
def _parse_openwebui_timestamp(timestamp_str: str) -> datetime | None:
"""Parse OpenWebUI timestamp with proper timezone handling."""
try:
# Handle both 'Z' suffix and explicit timezone
normalized = timestamp_str.replace("Z", "+00:00")
parsed = datetime.fromisoformat(normalized)
# Ensure timezone awareness
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=UTC)
return parsed
except (ValueError, AttributeError):
return None
def _check_file_freshness(file_info: dict[str, object]) -> bool:
"""Check if file is fresh enough based on creation date."""
created_at = file_info.get("created_at")
if not isinstance(created_at, str):
# No date info available, consider stale to be safe
return False
file_date = _parse_openwebui_timestamp(created_at)
return file_date is not None and file_date >= cutoff
for file_info in files:
if not isinstance(file_info, dict):
continue
file_id = file_info.get("id")
if str(file_id) == document_id:
return _check_file_freshness(file_info)
# Also check meta.source_url if available for URL-based document IDs
meta = file_info.get("meta", {})
if isinstance(meta, dict):
source_url = meta.get("source_url")
if source_url and document_id in str(source_url):
return _check_file_freshness(file_info)
return False
except Exception as e:
LOGGER.debug("Error checking document existence in OpenWebUI: %s", e)
return False
@override
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
"""
@@ -316,14 +423,10 @@ class OpenWebUIStorage(BaseStorage):
await self.http_client.request(
"POST",
f"/api/v1/knowledge/{knowledge_id}/file/remove",
json={"file_id": document_id}
json={"file_id": document_id},
)
await self.http_client.request(
"DELETE",
f"/api/v1/files/{document_id}",
allow_404=True
)
await self.http_client.request("DELETE", f"/api/v1/files/{document_id}", allow_404=True)
return True
except Exception as exc:
LOGGER.error("Error deleting file %s from OpenWebUI", document_id, exc_info=exc)
@@ -366,9 +469,7 @@ class OpenWebUIStorage(BaseStorage):
# Delete the knowledge base using the OpenWebUI API
await self.http_client.request(
"DELETE",
f"/api/v1/knowledge/{knowledge_id}/delete",
allow_404=True
"DELETE", f"/api/v1/knowledge/{knowledge_id}/delete", allow_404=True
)
# Remove from cache if it exists
@@ -379,13 +480,16 @@ class OpenWebUIStorage(BaseStorage):
return True
except Exception as e:
if hasattr(e, 'response'):
response_attr = getattr(e, 'response', None)
if response_attr is not None and hasattr(response_attr, 'status_code'):
if hasattr(e, "response"):
response_attr = getattr(e, "response", None)
if response_attr is not None and hasattr(response_attr, "status_code"):
with contextlib.suppress(Exception):
status_code = response_attr.status_code # type: ignore[attr-defined]
status_code = response_attr.status_code
if status_code == 404:
LOGGER.info("Knowledge base %s was already deleted or not found", collection_name)
LOGGER.info(
"Knowledge base %s was already deleted or not found",
collection_name,
)
return True
LOGGER.error(
"Error deleting knowledge base %s from OpenWebUI",
@@ -394,8 +498,6 @@ class OpenWebUIStorage(BaseStorage):
)
return False
async def _get_knowledge_base_count(self, kb: OpenWebUIKnowledgeBase) -> int:
"""Get the file count for a knowledge base."""
kb_id = kb.get("id")
@@ -411,14 +513,13 @@ class OpenWebUIStorage(BaseStorage):
files = kb.get("files", [])
return len(files) if isinstance(files, list) and files is not None else 0
async def _count_files_from_detailed_info(self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase) -> int:
async def _count_files_from_detailed_info(
self, kb_id: str, name: str, kb: OpenWebUIKnowledgeBase
) -> int:
"""Count files by fetching detailed knowledge base info."""
try:
LOGGER.debug(f"Fetching detailed info for KB '{name}' from /api/v1/knowledge/{kb_id}")
detail_response = await self.http_client.request(
"GET",
f"/api/v1/knowledge/{kb_id}"
)
detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
if detail_response is None:
LOGGER.warning(f"Knowledge base '{name}' (ID: {kb_id}) not found")
return self._count_files_from_basic_info(kb)
@@ -489,10 +590,7 @@ class OpenWebUIStorage(BaseStorage):
return 0
# Get detailed knowledge base information to get accurate file count
detail_response = await self.http_client.request(
"GET",
f"/api/v1/knowledge/{kb_id}"
)
detail_response = await self.http_client.request("GET", f"/api/v1/knowledge/{kb_id}")
if detail_response is None:
LOGGER.warning(f"Knowledge base '{collection_name}' (ID: {kb_id}) not found")
return self._count_files_from_basic_info(kb)
@@ -524,14 +622,28 @@ class OpenWebUIStorage(BaseStorage):
return None
knowledge_bases = response.json()
return next(
(
cast(OpenWebUIKnowledgeBase, kb)
for kb in knowledge_bases
if isinstance(kb, dict) and kb.get("name") == name
),
None,
)
# Find and properly type the knowledge base
for kb in knowledge_bases:
if (
isinstance(kb, dict)
and kb.get("name") == name
and "id" in kb
and isinstance(kb["id"], str)
):
# Create properly typed response
result: OpenWebUIKnowledgeBase = {
"id": kb["id"],
"name": str(kb["name"]),
"description": kb.get("description", ""),
"created_at": kb.get("created_at", ""),
"updated_at": kb.get("updated_at", ""),
}
if "files" in kb and isinstance(kb["files"], list):
result["files"] = kb["files"]
if "data" in kb and isinstance(kb["data"], dict):
result["data"] = kb["data"]
return result
return None
except Exception as e:
raise StorageError(f"Failed to get knowledge base by name: {e}") from e
@@ -623,7 +735,9 @@ class OpenWebUIStorage(BaseStorage):
elif isinstance(file_info.get("meta"), dict):
meta = file_info.get("meta")
if isinstance(meta, dict):
filename = meta.get("name")
filename_value = meta.get("name")
if isinstance(filename_value, str):
filename = filename_value
# Final fallback
if not filename:
@@ -635,9 +749,11 @@ class OpenWebUIStorage(BaseStorage):
size = 0
meta = file_info.get("meta")
if isinstance(meta, dict):
size = meta.get("size", 0)
size_value = meta.get("size", 0)
size = int(size_value) if isinstance(size_value, (int, float)) else 0
else:
size = file_info.get("size", 0)
size_value = file_info.get("size", 0)
size = int(size_value) if isinstance(size_value, (int, float)) else 0
# Estimate word count from file size (very rough approximation)
word_count = max(1, int(size / 6)) if isinstance(size, (int, float)) else 0
@@ -650,9 +766,7 @@ class OpenWebUIStorage(BaseStorage):
"content_type": str(file_info.get("content_type", "text/plain")),
"content_preview": f"File uploaded to OpenWebUI: {filename}",
"word_count": word_count,
"timestamp": str(
file_info.get("created_at") or file_info.get("timestamp", "")
),
"timestamp": str(file_info.get("created_at") or file_info.get("timestamp", "")),
}
documents.append(doc_info)

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from datetime import UTC, datetime
from typing import Self, TypeVar, cast
from typing import Final, Self, TypeVar, cast
from uuid import UUID, uuid4
# Direct imports for runtime and type checking
@@ -19,6 +20,8 @@ from ...core.models import Document, DocumentMetadata, IngestionSource, StorageC
from ..base import BaseStorage
from ..types import DocumentInfo
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
T = TypeVar("T")
@@ -183,8 +186,10 @@ class R2RStorage(BaseStorage):
) -> list[str]:
"""Store multiple documents efficiently with connection reuse."""
collection_id = await self._resolve_collection_id(collection_name)
print(
f"Using collection ID: {collection_id} for collection: {collection_name or self.config.collection_name}"
LOGGER.info(
"Using collection ID: %s for collection: %s",
collection_id,
collection_name or self.config.collection_name,
)
# Filter valid documents upfront
@@ -218,7 +223,7 @@ class R2RStorage(BaseStorage):
if isinstance(result, str):
stored_ids.append(result)
elif isinstance(result, Exception):
print(f"Document upload failed: {result}")
LOGGER.error("Document upload failed: %s", result)
return stored_ids
@@ -239,12 +244,14 @@ class R2RStorage(BaseStorage):
requested_id = str(document.id)
if not document.content or not document.content.strip():
print(f"Skipping document {requested_id}: empty content")
LOGGER.warning("Skipping document %s: empty content", requested_id)
return False
if len(document.content) > 1_000_000: # 1MB limit
print(
f"Skipping document {requested_id}: content too large ({len(document.content)} chars)"
LOGGER.warning(
"Skipping document %s: content too large (%d chars)",
requested_id,
len(document.content),
)
return False
@@ -263,7 +270,7 @@ class R2RStorage(BaseStorage):
) -> str | None:
"""Store a single document with retry logic using provided HTTP client."""
requested_id = str(document.id)
print(f"Creating document with ID: {requested_id}")
LOGGER.debug("Creating document with ID: %s", requested_id)
max_retries = 3
retry_delay = 1.0
@@ -313,7 +320,7 @@ class R2RStorage(BaseStorage):
requested_id = str(document.id)
metadata = self._build_metadata(document)
print(f"Built metadata for document {requested_id}: {metadata}")
LOGGER.debug("Built metadata for document %s: %s", requested_id, metadata)
files = {
"raw_text": (None, document.content),
@@ -324,10 +331,12 @@ class R2RStorage(BaseStorage):
if collection_id:
files["collection_ids"] = (None, json.dumps([collection_id]))
print(f"Creating document {requested_id} with collection_ids: [{collection_id}]")
LOGGER.debug(
"Creating document %s with collection_ids: [%s]", requested_id, collection_id
)
print(f"Sending to R2R - files keys: {list(files.keys())}")
print(f"Metadata JSON: {files['metadata'][1]}")
LOGGER.debug("Sending to R2R - files keys: %s", list(files.keys()))
LOGGER.debug("Metadata JSON: %s", files["metadata"][1])
response = await http_client.post(f"{self.endpoint}/v3/documents", files=files) # type: ignore[call-arg]
@@ -346,15 +355,17 @@ class R2RStorage(BaseStorage):
error_detail = (
getattr(response, "json", lambda: {})() if hasattr(response, "json") else {}
)
print(f"R2R validation error for document {requested_id}: {error_detail}")
print(f"Document metadata sent: {metadata}")
print(f"Response status: {getattr(response, 'status_code', 'unknown')}")
print(f"Response headers: {dict(getattr(response, 'headers', {}))}")
LOGGER.error("R2R validation error for document %s: %s", requested_id, error_detail)
LOGGER.error("Document metadata sent: %s", metadata)
LOGGER.error("Response status: %s", getattr(response, "status_code", "unknown"))
LOGGER.error("Response headers: %s", dict(getattr(response, "headers", {})))
except Exception:
print(
f"R2R validation error for document {requested_id}: {getattr(response, 'text', 'unknown error')}"
LOGGER.error(
"R2R validation error for document %s: %s",
requested_id,
getattr(response, "text", "unknown error"),
)
print(f"Document metadata sent: {metadata}")
LOGGER.error("Document metadata sent: %s", metadata)
def _process_document_response(
self, doc_response: dict[str, object], requested_id: str, collection_id: str
@@ -363,14 +374,16 @@ class R2RStorage(BaseStorage):
response_payload = doc_response.get("results", doc_response)
doc_id = _extract_id(response_payload, requested_id)
print(f"R2R returned document ID: {doc_id}")
LOGGER.info("R2R returned document ID: %s", doc_id)
if doc_id != requested_id:
print(f"Warning: Requested ID {requested_id} but got {doc_id}")
LOGGER.warning("Requested ID %s but got %s", requested_id, doc_id)
if collection_id:
print(
f"Document {doc_id} should be assigned to collection {collection_id} via creation API"
LOGGER.info(
"Document %s should be assigned to collection %s via creation API",
doc_id,
collection_id,
)
return doc_id
@@ -387,7 +400,7 @@ class R2RStorage(BaseStorage):
if attempt >= max_retries - 1:
return False
print(f"Timeout for document {requested_id}, retrying in {retry_delay}s...")
LOGGER.warning("Timeout for document %s, retrying in %ss...", requested_id, retry_delay)
await asyncio.sleep(retry_delay)
return True
@@ -404,23 +417,26 @@ class R2RStorage(BaseStorage):
if status_code < 500 or attempt >= max_retries - 1:
return False
print(
f"Server error {status_code} for document {requested_id}, retrying in {retry_delay}s..."
LOGGER.warning(
"Server error %s for document %s, retrying in %ss...",
status_code,
requested_id,
retry_delay,
)
await asyncio.sleep(retry_delay)
return True
def _log_document_error(self, document_id: object, exc: Exception) -> None:
"""Log document storage errors with specific categorization."""
print(f"Failed to store document {document_id}: {exc}")
LOGGER.error("Failed to store document %s: %s", document_id, exc)
exc_str = str(exc)
if "422" in exc_str:
print(" → Data validation issue - check document content and metadata format")
LOGGER.error(" → Data validation issue - check document content and metadata format")
elif "timeout" in exc_str.lower():
print(" → Network timeout - R2R may be overloaded")
LOGGER.error(" → Network timeout - R2R may be overloaded")
elif "500" in exc_str:
print(" → Server error - R2R internal issue")
LOGGER.error(" → Server error - R2R internal issue")
else:
import traceback
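
Swapping `print` for the module-level `LOGGER` means these diagnostics are only visible once the application configures logging. A minimal sketch of doing that in an entry point; the handler, format, and logger names below are assumptions, not something this diff prescribes:

```python
import logging


def configure_ingestion_logging(verbose: bool = False) -> None:
    """Send adapter loggers (such as the R2R storage LOGGER) to the console."""
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    # Optionally quiet chatty HTTP client logs while keeping adapter output.
    logging.getLogger("httpx").setLevel(logging.WARNING)


if __name__ == "__main__":
    configure_ingestion_logging(verbose=True)
    logging.getLogger("ingest_pipeline.storage.r2r.storage").debug("logging wired up")
```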

View File

@@ -5,6 +5,7 @@ from typing import TypedDict
class CollectionSummary(TypedDict):
"""Collection metadata for describe_collections."""
name: str
count: int
size_mb: float
@@ -12,6 +13,7 @@ class CollectionSummary(TypedDict):
class DocumentInfo(TypedDict):
"""Document information for list_documents."""
id: str
title: str
source_url: str
@@ -19,4 +21,4 @@ class DocumentInfo(TypedDict):
content_type: str
content_preview: str
word_count: int
timestamp: str
timestamp: str
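
At runtime these TypedDicts are plain dicts, so consumers of `describe_collections()` use ordinary key access while the type checker enforces the declared keys. A small illustrative sketch with an invented stand-in for the storage adapter:

```python
import asyncio


class _FakeStorage:
    """Invented stand-in that returns the CollectionSummary shape shown above."""

    async def describe_collections(self) -> list[dict[str, object]]:
        return [{"name": "docs", "count": 42, "size_mb": 1.5}]


async def print_collection_sizes(storage: _FakeStorage) -> None:
    # Each summary is a CollectionSummary-shaped dict: name, count, size_mb.
    for summary in await storage.describe_collections():
        print(f"{summary['name']}: {summary['count']} docs, {summary['size_mb']} MB")


asyncio.run(print_collection_sizes(_FakeStorage()))
```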

View File

@@ -1,8 +1,9 @@
"""Weaviate storage adapter."""
from collections.abc import AsyncGenerator, Mapping, Sequence
import asyncio
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from datetime import UTC, datetime
from typing import Literal, Self, TypeAlias, cast, overload
from typing import Literal, Self, TypeAlias, TypeVar, cast, overload
from uuid import UUID
import weaviate
@@ -24,6 +25,7 @@ from .base import BaseStorage
from .types import CollectionSummary, DocumentInfo
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
T = TypeVar("T")
class WeaviateStorage(BaseStorage):
@@ -45,6 +47,28 @@ class WeaviateStorage(BaseStorage):
self.vectorizer = Vectorizer(config)
self._default_collection = self._normalize_collection_name(config.collection_name)
async def _run_sync(self, func: Callable[..., T], *args: object, **kwargs: object) -> T:
"""
Run synchronous Weaviate operations in thread pool to avoid blocking event loop.
Args:
func: Synchronous function to run
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Result of the function call
Raises:
StorageError: If the operation fails
"""
try:
return await asyncio.to_thread(func, *args, **kwargs)
except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
raise StorageError(f"Weaviate operation failed: {e}") from e
except Exception as e:
raise StorageError(f"Unexpected error in Weaviate operation: {e}") from e
@override
async def initialize(self) -> None:
"""Initialize Weaviate client and create collection if needed."""
@@ -53,7 +77,7 @@ class WeaviateStorage(BaseStorage):
self.client = weaviate.WeaviateClient(
connection_params=weaviate.connect.ConnectionParams.from_url(
url=str(self.config.endpoint),
grpc_port=50051, # Default gRPC port
grpc_port=self.config.grpc_port or 50051,
),
additional_config=weaviate.classes.init.AdditionalConfig(
timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120),
@@ -61,7 +85,7 @@ class WeaviateStorage(BaseStorage):
)
# Connect to the client
self.client.connect()
await self._run_sync(self.client.connect)
# Ensure the default collection exists
await self._ensure_collection(self._default_collection)
@@ -76,8 +100,8 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
try:
client = cast(weaviate.WeaviateClient, self.client)
client.collections.create(
await self._run_sync(
self.client.collections.create,
name=collection_name,
properties=[
Property(
@@ -106,7 +130,7 @@ class WeaviateStorage(BaseStorage):
],
vectorizer_config=Configure.Vectorizer.none(),
)
except Exception as e:
except (WeaviateConnectionError, WeaviateBatchError) as e:
raise StorageError(f"Failed to create collection: {e}") from e
@staticmethod
@@ -114,13 +138,9 @@ class WeaviateStorage(BaseStorage):
"""Normalize vector payloads returned by Weaviate into a float list."""
if isinstance(vector_raw, Mapping):
default_vector = vector_raw.get("default")
return WeaviateStorage._extract_vector(
cast(VectorContainer, default_vector)
)
return WeaviateStorage._extract_vector(cast(VectorContainer, default_vector))
if not isinstance(vector_raw, Sequence) or isinstance(
vector_raw, (str, bytes, bytearray)
):
if not isinstance(vector_raw, Sequence) or isinstance(vector_raw, (str, bytes, bytearray)):
return None
items = list(vector_raw)
@@ -135,9 +155,7 @@ class WeaviateStorage(BaseStorage):
except (TypeError, ValueError):
return None
if isinstance(first_item, Sequence) and not isinstance(
first_item, (str, bytes, bytearray)
):
if isinstance(first_item, Sequence) and not isinstance(first_item, (str, bytes, bytearray)):
inner_items = list(first_item)
if all(isinstance(item, (int, float)) for item in inner_items):
try:
@@ -168,8 +186,7 @@ class WeaviateStorage(BaseStorage):
properties: object,
*,
context: str,
) -> Mapping[str, object]:
...
) -> Mapping[str, object]: ...
@staticmethod
@overload
@@ -178,8 +195,7 @@ class WeaviateStorage(BaseStorage):
*,
context: str,
allow_missing: Literal[False],
) -> Mapping[str, object]:
...
) -> Mapping[str, object]: ...
@staticmethod
@overload
@@ -188,8 +204,7 @@ class WeaviateStorage(BaseStorage):
*,
context: str,
allow_missing: Literal[True],
) -> Mapping[str, object] | None:
...
) -> Mapping[str, object] | None: ...
@staticmethod
def _coerce_properties(
@@ -211,6 +226,29 @@ class WeaviateStorage(BaseStorage):
return cast(Mapping[str, object], properties)
@staticmethod
def _build_document_properties(doc: Document) -> dict[str, object]:
"""
Build Weaviate properties dict from document.
Args:
doc: Document to build properties for
Returns:
Properties dict suitable for Weaviate
"""
return {
"content": doc.content,
"source_url": doc.metadata["source_url"],
"title": doc.metadata.get("title", ""),
"description": doc.metadata.get("description", ""),
"timestamp": doc.metadata["timestamp"].isoformat(),
"content_type": doc.metadata["content_type"],
"word_count": doc.metadata["word_count"],
"char_count": doc.metadata["char_count"],
"source": doc.source.value,
}
def _normalize_collection_name(self, collection_name: str | None) -> str:
"""Return a canonicalized collection name, defaulting to configured value."""
candidate = collection_name or self.config.collection_name
@@ -227,7 +265,7 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
client = cast(weaviate.WeaviateClient, self.client)
client = self.client
existing = client.collections.list_all()
if collection_name not in existing:
await self._create_collection(collection_name)
@@ -247,7 +285,7 @@ class WeaviateStorage(BaseStorage):
if ensure_exists:
await self._ensure_collection(normalized)
client = cast(weaviate.WeaviateClient, self.client)
client = self.client
return client.collections.get(normalized), normalized
@override
@@ -271,26 +309,19 @@ class WeaviateStorage(BaseStorage):
)
# Prepare properties
properties = {
"content": document.content,
"source_url": document.metadata["source_url"],
"title": document.metadata.get("title", ""),
"description": document.metadata.get("description", ""),
"timestamp": document.metadata["timestamp"].isoformat(),
"content_type": document.metadata["content_type"],
"word_count": document.metadata["word_count"],
"char_count": document.metadata["char_count"],
"source": document.source.value,
}
properties = self._build_document_properties(document)
# Insert with vector
result = collection.data.insert(
properties=properties, vector=document.vector, uuid=str(document.id)
result = await self._run_sync(
collection.data.insert,
properties=properties,
vector=document.vector,
uuid=str(document.id),
)
return str(result)
except Exception as e:
except (WeaviateConnectionError, WeaviateBatchError, WeaviateQueryError) as e:
raise StorageError(f"Failed to store document: {e}") from e
@override
@@ -311,32 +342,24 @@ class WeaviateStorage(BaseStorage):
collection_name, ensure_exists=True
)
# Vectorize documents without vectors
for doc in documents:
if doc.vector is None:
doc.vector = await self.vectorizer.vectorize(doc.content)
# Vectorize documents without vectors using batch processing
to_vectorize = [(i, doc) for i, doc in enumerate(documents) if doc.vector is None]
if to_vectorize:
contents = [doc.content for _, doc in to_vectorize]
vectors = await self.vectorizer.vectorize_batch(contents)
for (idx, _), vector in zip(to_vectorize, vectors, strict=False):
documents[idx].vector = vector
# Prepare batch data for insert_many
batch_objects = []
for doc in documents:
properties = {
"content": doc.content,
"source_url": doc.metadata["source_url"],
"title": doc.metadata.get("title", ""),
"description": doc.metadata.get("description", ""),
"timestamp": doc.metadata["timestamp"].isoformat(),
"content_type": doc.metadata["content_type"],
"word_count": doc.metadata["word_count"],
"char_count": doc.metadata["char_count"],
"source": doc.source.value,
}
properties = self._build_document_properties(doc)
batch_objects.append(
DataObject(properties=properties, vector=doc.vector, uuid=str(doc.id))
)
# Insert batch using insert_many
response = collection.data.insert_many(batch_objects)
response = await self._run_sync(collection.data.insert_many, batch_objects)
successful_ids: list[str] = []
error_indices = set(response.errors.keys()) if response else set()
@@ -361,11 +384,7 @@ class WeaviateStorage(BaseStorage):
return successful_ids
except WeaviateBatchError as e:
raise StorageError(f"Batch operation failed: {e}") from e
except WeaviateConnectionError as e:
raise StorageError(f"Connection to Weaviate failed: {e}") from e
except Exception as e:
except (WeaviateBatchError, WeaviateConnectionError, WeaviateQueryError) as e:
raise StorageError(f"Failed to store batch: {e}") from e
@override
@@ -385,7 +404,7 @@ class WeaviateStorage(BaseStorage):
collection, resolved_name = await self._prepare_collection(
collection_name, ensure_exists=False
)
result = collection.query.fetch_object_by_id(document_id)
result = await self._run_sync(collection.query.fetch_object_by_id, document_id)
if not result:
return None
@@ -395,13 +414,30 @@ class WeaviateStorage(BaseStorage):
result.properties,
context="fetch_object_by_id",
)
# Parse timestamp to datetime for consistent metadata format
from datetime import UTC, datetime
timestamp_raw = props.get("timestamp")
timestamp_parsed: datetime
try:
if isinstance(timestamp_raw, str):
timestamp_parsed = datetime.fromisoformat(timestamp_raw)
if timestamp_parsed.tzinfo is None:
timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
elif isinstance(timestamp_raw, datetime):
timestamp_parsed = timestamp_raw
if timestamp_parsed.tzinfo is None:
timestamp_parsed = timestamp_parsed.replace(tzinfo=UTC)
else:
timestamp_parsed = datetime.now(UTC)
except (ValueError, TypeError):
timestamp_parsed = datetime.now(UTC)
metadata_dict = {
"source_url": str(props["source_url"]),
"title": str(props.get("title")) if props.get("title") else None,
"description": str(props.get("description"))
if props.get("description")
else None,
"timestamp": str(props["timestamp"]),
"description": str(props.get("description")) if props.get("description") else None,
"timestamp": timestamp_parsed,
"content_type": str(props["content_type"]),
"word_count": int(str(props["word_count"])),
"char_count": int(str(props["char_count"])),
@@ -424,11 +460,13 @@ class WeaviateStorage(BaseStorage):
except WeaviateConnectionError as e:
# Connection issues should be logged and return None
import logging
logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}")
return None
except Exception as e:
# Log unexpected errors for debugging
import logging
logging.warning(f"Unexpected error retrieving document {document_id}: {e}")
return None
@@ -437,9 +475,7 @@ class WeaviateStorage(BaseStorage):
metadata_dict = {
"source_url": str(props["source_url"]),
"title": str(props.get("title")) if props.get("title") else None,
"description": str(props.get("description"))
if props.get("description")
else None,
"description": str(props.get("description")) if props.get("description") else None,
"timestamp": str(props["timestamp"]),
"content_type": str(props["content_type"]),
"word_count": int(str(props["word_count"])),
@@ -462,6 +498,7 @@ class WeaviateStorage(BaseStorage):
return max(0.0, 1.0 - distance_value)
except (TypeError, ValueError) as e:
import logging
logging.debug(f"Invalid distance value {raw_distance}: {e}")
return None
@@ -506,37 +543,39 @@ class WeaviateStorage(BaseStorage):
collection_name: str | None = None,
) -> AsyncGenerator[Document, None]:
"""
Search for documents in Weaviate.
Search for documents in Weaviate using hybrid search.
Args:
query: Search query
limit: Maximum results
threshold: Similarity threshold
threshold: Similarity threshold (not used in hybrid search)
Yields:
Matching documents
"""
try:
query_vector = await self.vectorizer.vectorize(query)
if not self.client:
raise StorageError("Weaviate client not initialized")
collection, resolved_name = await self._prepare_collection(
collection_name, ensure_exists=False
)
results = collection.query.near_vector(
near_vector=query_vector,
limit=limit,
distance=1 - threshold,
return_metadata=["distance"],
)
# Try hybrid search first, fall back to BM25 keyword search
try:
response = await self._run_sync(
collection.query.hybrid, query=query, limit=limit, return_metadata=["score"]
)
except (WeaviateQueryError, StorageError):
# Fall back to BM25 if hybrid search is not supported or fails
response = await self._run_sync(
collection.query.bm25, query=query, limit=limit, return_metadata=["score"]
)
for result in results.objects:
yield self._build_search_document(result, resolved_name)
for obj in response.objects:
yield self._build_document_from_search(obj, resolved_name)
except WeaviateQueryError as e:
raise StorageError(f"Search query failed: {e}") from e
except WeaviateConnectionError as e:
raise StorageError(f"Connection to Weaviate failed during search: {e}") from e
except Exception as e:
except (WeaviateQueryError, WeaviateConnectionError) as e:
raise StorageError(f"Search failed: {e}") from e
@override
@@ -552,7 +591,7 @@ class WeaviateStorage(BaseStorage):
"""
try:
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
collection.data.delete_by_id(document_id)
await self._run_sync(collection.data.delete_by_id, document_id)
return True
except WeaviateQueryError as e:
raise StorageError(f"Delete operation failed: {e}") from e
@@ -589,10 +628,10 @@ class WeaviateStorage(BaseStorage):
if not self.client:
raise StorageError("Weaviate client not initialized")
client = cast(weaviate.WeaviateClient, self.client)
client = self.client
return list(client.collections.list_all())
except Exception as e:
except WeaviateConnectionError as e:
raise StorageError(f"Failed to list collections: {e}") from e
async def describe_collections(self) -> list[CollectionSummary]:
@@ -601,7 +640,7 @@ class WeaviateStorage(BaseStorage):
raise StorageError("Weaviate client not initialized")
try:
client = cast(weaviate.WeaviateClient, self.client)
client = self.client
collections: list[CollectionSummary] = []
for name in client.collections.list_all():
collection_obj = client.collections.get(name)
@@ -639,7 +678,7 @@ class WeaviateStorage(BaseStorage):
)
# Query for sample documents
response = collection.query.fetch_objects(limit=limit)
response = await self._run_sync(collection.query.fetch_objects, limit=limit)
documents = []
for obj in response.objects:
@@ -712,9 +751,7 @@ class WeaviateStorage(BaseStorage):
return {
"source_url": str(props.get("source_url", "")),
"title": str(props.get("title", "")) if props.get("title") else None,
"description": str(props.get("description", ""))
if props.get("description")
else None,
"description": str(props.get("description", "")) if props.get("description") else None,
"timestamp": datetime.fromisoformat(
str(props.get("timestamp", datetime.now(UTC).isoformat()))
),
@@ -737,6 +774,7 @@ class WeaviateStorage(BaseStorage):
return float(raw_score)
except (TypeError, ValueError) as e:
import logging
logging.debug(f"Invalid score value {raw_score}: {e}")
return None
@@ -780,31 +818,11 @@ class WeaviateStorage(BaseStorage):
Returns:
List of matching documents
"""
try:
if not self.client:
raise StorageError("Weaviate client not initialized")
collection, resolved_name = await self._prepare_collection(
collection_name, ensure_exists=False
)
# Try hybrid search first, fall back to BM25 keyword search
try:
response = collection.query.hybrid(
query=query, limit=limit, return_metadata=["score"]
)
except Exception:
response = collection.query.bm25(
query=query, limit=limit, return_metadata=["score"]
)
return [
self._build_document_from_search(obj, resolved_name)
for obj in response.objects
]
except Exception as e:
raise StorageError(f"Failed to search documents: {e}") from e
# Delegate to the unified search method
results = []
async for document in self.search(query, limit=limit, collection_name=collection_name):
results.append(document)
return results
async def list_documents(
self,
@@ -830,8 +848,11 @@ class WeaviateStorage(BaseStorage):
collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
# Query documents with pagination
response = collection.query.fetch_objects(
limit=limit, offset=offset, return_metadata=["creation_time"]
response = await self._run_sync(
collection.query.fetch_objects,
limit=limit,
offset=offset,
return_metadata=["creation_time"],
)
documents: list[DocumentInfo] = []
@@ -896,7 +917,9 @@ class WeaviateStorage(BaseStorage):
)
delete_filter = Filter.by_id().contains_any(document_ids)
response = collection.data.delete_many(where=delete_filter, verbose=True)
response = await self._run_sync(
collection.data.delete_many, where=delete_filter, verbose=True
)
if objects := getattr(response, "objects", None):
for result_obj in objects:
@@ -938,20 +961,22 @@ class WeaviateStorage(BaseStorage):
# Get documents matching filter
if where_filter:
response = collection.query.fetch_objects(
response = await self._run_sync(
collection.query.fetch_objects,
filters=where_filter,
limit=1000, # Max batch size
)
else:
response = collection.query.fetch_objects(
limit=1000 # Max batch size
response = await self._run_sync(
collection.query.fetch_objects,
limit=1000, # Max batch size
)
# Delete matching documents
deleted_count = 0
for obj in response.objects:
try:
collection.data.delete_by_id(obj.uuid)
await self._run_sync(collection.data.delete_by_id, obj.uuid)
deleted_count += 1
except Exception:
continue
@@ -975,7 +1000,7 @@ class WeaviateStorage(BaseStorage):
target = self._normalize_collection_name(collection_name)
# Delete the collection using the client's collections API
client = cast(weaviate.WeaviateClient, self.client)
client = self.client
client.collections.delete(target)
return True
@@ -997,20 +1022,29 @@ class WeaviateStorage(BaseStorage):
await self.close()
async def close(self) -> None:
"""Close client connection."""
"""Close client connection and vectorizer HTTP client."""
if self.client:
try:
client = cast(weaviate.WeaviateClient, self.client)
client = self.client
client.close()
except Exception as e:
except (WeaviateConnectionError, AttributeError) as e:
import logging
logging.warning(f"Error closing Weaviate client: {e}")
# Close vectorizer HTTP client to prevent resource leaks
try:
await self.vectorizer.close()
except (AttributeError, OSError) as e:
import logging
logging.warning(f"Error closing vectorizer client: {e}")
def __del__(self) -> None:
"""Clean up client connection as fallback."""
if self.client:
try:
client = cast(weaviate.WeaviateClient, self.client)
client = self.client
client.close()
except Exception:
pass # Ignore errors in destructor
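
The `_run_sync` helper is the heart of this file's change: every synchronous Weaviate client call is handed to a worker thread via `asyncio.to_thread`, so the event loop driving the Textual UI never stalls on network I/O. A stripped-down sketch of the same pattern, independent of the Weaviate client; `slow_query` is a stand-in for a blocking call:

```python
import asyncio
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


async def run_sync(func: Callable[..., T], *args: object, **kwargs: object) -> T:
    """Run a blocking callable in the default thread pool, as _run_sync does."""
    return await asyncio.to_thread(func, *args, **kwargs)


def slow_query(collection: str) -> list[str]:
    time.sleep(0.5)  # stands in for a blocking Weaviate client call
    return [f"{collection}-doc-1"]


async def main() -> None:
    # Other tasks keep running while slow_query blocks its worker thread.
    heartbeat = asyncio.create_task(asyncio.sleep(0.1))
    results = await run_sync(slow_query, "Documents")
    await heartbeat
    print(results)


asyncio.run(main())
```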

View File

@@ -0,0 +1,117 @@
"""Async utilities for task management and backpressure control."""
import asyncio
import logging
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
from contextlib import asynccontextmanager
from typing import Final, TypeVar
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
T = TypeVar("T")
class AsyncTaskManager:
"""Manages concurrent tasks with backpressure control."""
def __init__(self, max_concurrent: int = 10):
"""
Initialize task manager.
Args:
max_concurrent: Maximum number of concurrent tasks
"""
self.semaphore = asyncio.Semaphore(max_concurrent)
self.max_concurrent = max_concurrent
@asynccontextmanager
async def acquire(self) -> AsyncGenerator[None, None]:
"""Acquire a slot for task execution."""
async with self.semaphore:
yield
async def run_tasks(
self, tasks: Iterable[Awaitable[T]], return_exceptions: bool = False
) -> list[T | BaseException]:
"""
Run multiple tasks with backpressure control.
Args:
tasks: Iterable of awaitable tasks
return_exceptions: Whether to return exceptions or raise them
Returns:
List of task results or exceptions
"""
async def _controlled_task(task: Awaitable[T]) -> T:
async with self.acquire():
return await task
controlled_tasks = [_controlled_task(task) for task in tasks]
if return_exceptions:
results = await asyncio.gather(*controlled_tasks, return_exceptions=True)
return list(results)
else:
results = await asyncio.gather(*controlled_tasks)
return list(results)
async def map_async(
self, func: Callable[[T], Awaitable[T]], items: Iterable[T], return_exceptions: bool = False
) -> list[T | BaseException]:
"""
Apply async function to items with backpressure control.
Args:
func: Async function to apply
items: Items to process
return_exceptions: Whether to return exceptions or raise them
Returns:
List of processed results or exceptions
"""
tasks = [func(item) for item in items]
return await self.run_tasks(tasks, return_exceptions=return_exceptions)
async def run_with_semaphore(semaphore: asyncio.Semaphore, coro: Awaitable[T]) -> T:
"""Run coroutine with semaphore-controlled concurrency."""
async with semaphore:
return await coro
async def batch_process(
items: list[T],
processor: Callable[[T], Awaitable[T]],
batch_size: int = 50,
max_concurrent: int = 5,
) -> list[T]:
"""
Process items in batches with controlled concurrency.
Args:
items: Items to process
processor: Async function to process each item
batch_size: Number of items per batch
max_concurrent: Maximum concurrent tasks per batch
Returns:
List of processed results
"""
task_manager = AsyncTaskManager(max_concurrent)
results: list[T] = []
for i in range(0, len(items), batch_size):
batch = items[i : i + batch_size]
LOGGER.debug(
"Processing batch %d-%d of %d items", i, min(i + batch_size, len(items)), len(items)
)
batch_results = await task_manager.map_async(processor, batch, return_exceptions=False)
# With return_exceptions=False any exception has already propagated, so every result is a T.
# The type checker cannot prove that, so narrow the list by filtering out BaseException.
successful_results: list[T] = [r for r in batch_results if not isinstance(r, BaseException)]
results.extend(successful_results)
return results
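
A usage sketch for this new utilities module; the import path is an assumption (the diff shows only the file body), and the `enrich` coroutine is a placeholder for real per-document work such as vectorization:

```python
import asyncio

# Path assumed for illustration; adjust to wherever this module actually lives.
from ingest_pipeline.utils.async_utils import batch_process


async def enrich(doc_id: str) -> str:
    await asyncio.sleep(0.01)  # stands in for an embedding or HTTP call
    return doc_id.upper()


async def main() -> None:
    ids = [f"doc-{i}" for i in range(200)]
    # Batches of 50, with at most 5 coroutines in flight inside each batch.
    enriched = await batch_process(ids, enrich, batch_size=50, max_concurrent=5)
    print(len(enriched), enriched[:3])


asyncio.run(main())
```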

View File

@@ -6,12 +6,12 @@ from typing import Final, Protocol, TypedDict, cast
import httpx
from ..config import get_settings
from ..core.exceptions import IngestionError
from ..core.models import Document
JSON_CONTENT_TYPE: Final[str] = "application/json"
AUTHORIZATION_HEADER: Final[str] = "Authorization"
from ..config import get_settings
class HttpResponse(Protocol):
@@ -24,12 +24,7 @@ class HttpResponse(Protocol):
class AsyncHttpClient(Protocol):
"""Protocol for async HTTP client."""
async def post(
self,
url: str,
*,
json: dict[str, object] | None = None
) -> HttpResponse: ...
async def post(self, url: str, *, json: dict[str, object] | None = None) -> HttpResponse: ...
async def aclose(self) -> None: ...
@@ -45,16 +40,19 @@ class AsyncHttpClient(Protocol):
class LlmResponse(TypedDict):
"""Type for LLM API response structure."""
choices: list[dict[str, object]]
class LlmChoice(TypedDict):
"""Type for individual choice in LLM response."""
message: dict[str, object]
class LlmMessage(TypedDict):
"""Type for message in LLM choice."""
content: str
@@ -96,7 +94,7 @@ class MetadataTagger:
"""
settings = get_settings()
endpoint_value = llm_endpoint or str(settings.llm_endpoint)
self.endpoint = endpoint_value.rstrip('/')
self.endpoint = endpoint_value.rstrip("/")
self.model = model or settings.metadata_model
resolved_timeout = timeout if timeout is not None else float(settings.request_timeout)
@@ -161,12 +159,16 @@ class MetadataTagger:
# Build a proper DocumentMetadata instance with only valid keys
new_metadata: CoreDocumentMetadata = {
"source_url": str(updated_metadata.get("source_url", "")),
"timestamp": (
lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC)
)(updated_metadata.get("timestamp", datetime.now(UTC))),
"timestamp": (lambda ts: ts if isinstance(ts, datetime) else datetime.now(UTC))(
updated_metadata.get("timestamp", datetime.now(UTC))
),
"content_type": str(updated_metadata.get("content_type", "text/plain")),
"word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(updated_metadata.get("word_count", 0)),
"char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(updated_metadata.get("char_count", 0)),
"word_count": (lambda wc: int(wc) if isinstance(wc, (int, str)) else 0)(
updated_metadata.get("word_count", 0)
),
"char_count": (lambda cc: int(cc) if isinstance(cc, (int, str)) else 0)(
updated_metadata.get("char_count", 0)
),
}
# Add optional fields if they exist
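
The inline lambdas above only coerce loosely typed values into the shapes `DocumentMetadata` expects; the same logic written as named helpers, purely for readability (helper names are hypothetical, not part of the tagger):

```python
from datetime import UTC, datetime


def _as_datetime(value: object) -> datetime:
    """Keep datetimes as-is; fall back to the current time for anything else."""
    return value if isinstance(value, datetime) else datetime.now(UTC)


def _as_int(value: object) -> int:
    """Accept ints or numeric strings, mirroring the lambda's behaviour."""
    return int(value) if isinstance(value, (int, str)) else 0


updated_metadata: dict[str, object] = {"word_count": "42", "timestamp": "not-a-date"}
print(_as_datetime(updated_metadata.get("timestamp", datetime.now(UTC))))  # falls back to now
print(_as_int(updated_metadata.get("word_count", 0)))  # 42
```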

View File

@@ -1,20 +1,79 @@
"""Vectorizer utility for generating embeddings."""
import asyncio
from types import TracebackType
from typing import Final, Self, cast
from typing import Final, NotRequired, Self, TypedDict
import httpx
from typings import EmbeddingResponse
from ..config import get_settings
from ..core.exceptions import VectorizationError
from ..core.models import StorageConfig, VectorConfig
from ..config import get_settings
JSON_CONTENT_TYPE: Final[str] = "application/json"
AUTHORIZATION_HEADER: Final[str] = "Authorization"
class EmbeddingData(TypedDict):
"""Structure for embedding data from providers."""
embedding: list[float]
index: NotRequired[int]
object: NotRequired[str]
class EmbeddingResponse(TypedDict):
"""Embedding response format for multiple providers."""
data: list[EmbeddingData]
model: NotRequired[str]
object: NotRequired[str]
usage: NotRequired[dict[str, int]]
# Alternative formats
embedding: NotRequired[list[float]]
vector: NotRequired[list[float]]
embeddings: NotRequired[list[list[float]]]
def _extract_embedding_from_response(response_data: dict[str, object]) -> list[float]:
"""Extract embedding vector from provider response."""
# OpenAI/Ollama format: {"data": [{"embedding": [...]}]}
if "data" in response_data:
data_list = response_data["data"]
if isinstance(data_list, list) and data_list:
first_item = data_list[0]
if isinstance(first_item, dict) and "embedding" in first_item:
embedding = first_item["embedding"]
if isinstance(embedding, list) and all(
isinstance(x, (int, float)) for x in embedding
):
return [float(x) for x in embedding]
# Direct embedding format: {"embedding": [...]}
if "embedding" in response_data:
embedding = response_data["embedding"]
if isinstance(embedding, list) and all(isinstance(x, (int, float)) for x in embedding):
return [float(x) for x in embedding]
# Vector format: {"vector": [...]}
if "vector" in response_data:
vector = response_data["vector"]
if isinstance(vector, list) and all(isinstance(x, (int, float)) for x in vector):
return [float(x) for x in vector]
# Embeddings array format: {"embeddings": [[...]]}
if "embeddings" in response_data:
embeddings = response_data["embeddings"]
if isinstance(embeddings, list) and embeddings:
first_embedding = embeddings[0]
if isinstance(first_embedding, list) and all(
isinstance(x, (int, float)) for x in first_embedding
):
return [float(x) for x in first_embedding]
raise VectorizationError("Unrecognized embedding response format")
class Vectorizer:
"""Handles text vectorization using LLM endpoints."""
@@ -72,21 +131,34 @@ class Vectorizer:
async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
"""
Generate embeddings for multiple texts.
Generate embeddings for multiple texts in parallel.
Args:
texts: List of texts to vectorize
Returns:
List of embedding vectors
Raises:
VectorizationError: If any vectorization fails
"""
vectors: list[list[float]] = []
for text in texts:
vector = await self.vectorize(text)
vectors.append(vector)
if not texts:
return []
return vectors
# Use semaphore to limit concurrent requests and prevent overwhelming the endpoint
semaphore = asyncio.Semaphore(20)
async def vectorize_with_semaphore(text: str) -> list[float]:
async with semaphore:
return await self.vectorize(text)
try:
# Execute all vectorization requests concurrently
vectors = await asyncio.gather(*[vectorize_with_semaphore(text) for text in texts])
return list(vectors)
except Exception as e:
raise VectorizationError(f"Batch vectorization failed: {e}") from e
async def _ollama_embed(self, text: str) -> list[float]:
"""
@@ -112,23 +184,12 @@ class Vectorizer:
_ = response.raise_for_status()
response_json = response.json()
# Response is expected to be dict[str, object] from our type stub
if not isinstance(response_json, dict):
raise VectorizationError("Invalid JSON response format")
response_data = cast(EmbeddingResponse, cast(object, response_json))
# Extract embedding using type-safe helper
embedding = _extract_embedding_from_response(response_json)
# Parse OpenAI-compatible response format
embeddings_list = response_data.get("data", [])
if not embeddings_list:
raise VectorizationError("No embeddings returned")
first_embedding = embeddings_list[0]
embedding_raw = first_embedding.get("embedding")
if not embedding_raw:
raise VectorizationError("Invalid embedding format")
# Convert to float list and validate
embedding: list[float] = []
embedding.extend(float(item) for item in embedding_raw)
# Ensure correct dimension
if len(embedding) != self.dimension:
raise VectorizationError(
@@ -157,22 +218,12 @@ class Vectorizer:
_ = response.raise_for_status()
response_json = response.json()
# Response is expected to be dict[str, object] from our type stub
if not isinstance(response_json, dict):
raise VectorizationError("Invalid JSON response format")
response_data = cast(EmbeddingResponse, cast(object, response_json))
# Extract embedding using type-safe helper
embedding = _extract_embedding_from_response(response_json)
embeddings_list = response_data.get("data", [])
if not embeddings_list:
raise VectorizationError("No embeddings returned")
first_embedding = embeddings_list[0]
embedding_raw = first_embedding.get("embedding")
if not embedding_raw:
raise VectorizationError("Invalid embedding format")
# Convert to float list and validate
embedding: list[float] = []
embedding.extend(float(item) for item in embedding_raw)
# Ensure correct dimension
if len(embedding) != self.dimension:
raise VectorizationError(
@@ -185,6 +236,14 @@ class Vectorizer:
"""Async context manager entry."""
return self
async def close(self) -> None:
"""Close the HTTP client connection."""
try:
await self.client.aclose()
except Exception:
# Already closed or connection lost
pass
async def __aexit__(
self,
exc_type: type[BaseException] | None,
@@ -192,4 +251,4 @@ class Vectorizer:
exc_tb: TracebackType | None,
) -> None:
"""Async context manager exit."""
await self.client.aclose()
await self.close()
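
The batch path now fans embedding requests out under a semaphore instead of awaiting them one by one. The shape of that change in isolation, with a fake `embed` standing in for the real HTTP call and the concurrency limit of 20 taken from the hunk above:

```python
import asyncio


async def embed(text: str) -> list[float]:
    await asyncio.sleep(0.01)  # stands in for the embedding HTTP request
    return [float(len(text))]


async def embed_batch(texts: list[str], max_concurrent: int = 20) -> list[list[float]]:
    """Vectorize texts concurrently with at most max_concurrent requests in flight."""
    if not texts:
        return []
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(text: str) -> list[float]:
        async with semaphore:
            return await embed(text)

    # gather preserves input order, so vectors line up with their texts.
    return list(await asyncio.gather(*(bounded(t) for t in texts)))


print(asyncio.run(embed_batch(["alpha", "beta", "gamma"])))
```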

notebooks/welcome.py (new file, 27 lines)
View File

@@ -0,0 +1,27 @@
import marimo
__generated_with = "0.16.0"
app = marimo.App()
@app.cell
def __():
import marimo as mo
return (mo,)
@app.cell
def __(mo):
mo.md("# Welcome to Marimo!")
return
@app.cell
def __(mo):
mo.md("This is your interactive notebook environment.")
return
if __name__ == "__main__":
app.run()

File diff suppressed because it is too large.

View File

@@ -59,6 +59,7 @@ class StubbedResponse:
def raise_for_status(self) -> None:
if self.status_code < 400:
return
# Create a minimal exception for testing - we don't need full httpx objects for tests
class TestHTTPError(Exception):
request: object | None
@@ -232,7 +233,7 @@ class AsyncClientStub:
# Convert params to the format expected by other methods
converted_params: dict[str, object] | None = None
if params:
converted_params = {k: v for k, v in params.items()}
converted_params = dict(params)
method_upper = method.upper()
if method_upper == "GET":
@@ -441,7 +442,6 @@ def r2r_spec() -> OpenAPISpec:
return OpenAPISpec.from_file(PROJECT_ROOT / "r2r.json")
@pytest.fixture(scope="session")
def firecrawl_spec() -> OpenAPISpec:
"""Load Firecrawl OpenAPI specification."""
@@ -539,7 +539,9 @@ def firecrawl_client_stub(
links = [SimpleNamespace(url=cast(str, item.get("url", ""))) for item in links_data]
return SimpleNamespace(success=payload.get("success", True), links=links)
async def scrape(self, url: str, formats: list[str] | None = None, **_: object) -> SimpleNamespace:
async def scrape(
self, url: str, formats: list[str] | None = None, **_: object
) -> SimpleNamespace:
payload = cast(MockResponseData, self._service.scrape_response(url, formats))
data = cast(MockResponseData, payload.get("data", {}))
metadata_payload = cast(MockResponseData, data.get("metadata", {}))

View File

@@ -158,7 +158,9 @@ class OpenAPISpec:
return []
return [segment for segment in path.strip("/").split("/") if segment]
def find_operation(self, method: str, path: str) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
def find_operation(
self, method: str, path: str
) -> tuple[Mapping[str, Any] | None, dict[str, str]]:
method = method.lower()
normalized = "/" if path in {"", "/"} else "/" + path.strip("/")
actual_segments = self._split_path(normalized)
@@ -188,9 +190,7 @@ class OpenAPISpec:
status: str | None = None,
) -> tuple[int, Any]:
responses = operation.get("responses", {})
target_status = status or (
"200" if "200" in responses else next(iter(responses), "200")
)
target_status = status or ("200" if "200" in responses else next(iter(responses), "200"))
status_code = 200
try:
status_code = int(target_status)
@@ -222,7 +222,7 @@ class OpenAPISpec:
base = self.generate({"$ref": ref})
if overrides is None or not isinstance(base, Mapping):
return copy.deepcopy(base)
merged = copy.deepcopy(base)
merged: dict[str, Any] = dict(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
merged[key] = self.generate_from_mapping(
@@ -237,8 +237,8 @@ class OpenAPISpec:
self,
base: Mapping[str, Any],
overrides: Mapping[str, Any],
) -> Mapping[str, Any]:
result = copy.deepcopy(base)
) -> dict[str, Any]:
result: dict[str, Any] = dict(base)
for key, value in overrides.items():
if isinstance(value, Mapping) and isinstance(result.get(key), Mapping):
result[key] = self.generate_from_mapping(result[key], value)
@@ -380,18 +380,20 @@ class OpenWebUIMockService(OpenAPIMockService):
def _knowledge_user_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
payload = self.spec.generate_from_ref("#/components/schemas/KnowledgeUserResponse")
payload.update({
"id": knowledge["id"],
"user_id": knowledge["user_id"],
"name": knowledge["name"],
"description": knowledge.get("description", ""),
"data": copy.deepcopy(knowledge.get("data", {})),
"meta": copy.deepcopy(knowledge.get("meta", {})),
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
"created_at": knowledge.get("created_at", self._timestamp()),
"updated_at": knowledge.get("updated_at", self._timestamp()),
"files": copy.deepcopy(knowledge.get("files", [])),
})
payload.update(
{
"id": knowledge["id"],
"user_id": knowledge["user_id"],
"name": knowledge["name"],
"description": knowledge.get("description", ""),
"data": copy.deepcopy(knowledge.get("data", {})),
"meta": copy.deepcopy(knowledge.get("meta", {})),
"access_control": copy.deepcopy(knowledge.get("access_control", {})),
"created_at": knowledge.get("created_at", self._timestamp()),
"updated_at": knowledge.get("updated_at", self._timestamp()),
"files": copy.deepcopy(knowledge.get("files", [])),
}
)
return payload
def _knowledge_files_response(self, knowledge: Mapping[str, Any]) -> dict[str, Any]:
@@ -435,7 +437,9 @@ class OpenWebUIMockService(OpenAPIMockService):
# Delegate to parent
return super().handle(method=method, path=path, json=json, params=params, files=files)
def _handle_knowledge_endpoints(self, method: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]:
def _handle_knowledge_endpoints(
self, method: str, segments: list[str], json: Mapping[str, Any] | None
) -> tuple[int, Any]:
"""Handle knowledge-related API endpoints."""
method_upper = method.upper()
segment_count = len(segments)
@@ -471,7 +475,9 @@ class OpenWebUIMockService(OpenAPIMockService):
)
return 200, entry
def _handle_specific_knowledge(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None) -> tuple[int, Any]:
def _handle_specific_knowledge(
self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None
) -> tuple[int, Any]:
"""Handle endpoints for specific knowledge entries."""
knowledge_id = segments[3]
knowledge = self._knowledge.get(knowledge_id)
@@ -486,7 +492,9 @@ class OpenWebUIMockService(OpenAPIMockService):
# File operations
if segment_count == 6 and segments[4] == "file":
return self._handle_knowledge_file_operations(method_upper, segments[5], knowledge, json or {})
return self._handle_knowledge_file_operations(
method_upper, segments[5], knowledge, json or {}
)
# Delete knowledge
if method_upper == "DELETE" and segment_count == 5 and segments[4] == "delete":
@@ -495,7 +503,13 @@ class OpenWebUIMockService(OpenAPIMockService):
return 404, {"detail": "Knowledge operation not found"}
def _handle_knowledge_file_operations(self, method_upper: str, operation: str, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
def _handle_knowledge_file_operations(
self,
method_upper: str,
operation: str,
knowledge: dict[str, Any],
payload: Mapping[str, Any],
) -> tuple[int, Any]:
"""Handle file operations on knowledge entries."""
if method_upper != "POST":
return 405, {"detail": "Method not allowed"}
@@ -507,7 +521,9 @@ class OpenWebUIMockService(OpenAPIMockService):
return 404, {"detail": "File operation not found"}
def _add_file_to_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
def _add_file_to_knowledge(
self, knowledge: dict[str, Any], payload: Mapping[str, Any]
) -> tuple[int, Any]:
"""Add a file to a knowledge entry."""
file_id = str(payload.get("file_id", ""))
if not file_id:
@@ -524,15 +540,21 @@ class OpenWebUIMockService(OpenAPIMockService):
return 200, self._knowledge_files_response(knowledge)
def _remove_file_from_knowledge(self, knowledge: dict[str, Any], payload: Mapping[str, Any]) -> tuple[int, Any]:
def _remove_file_from_knowledge(
self, knowledge: dict[str, Any], payload: Mapping[str, Any]
) -> tuple[int, Any]:
"""Remove a file from a knowledge entry."""
file_id = str(payload.get("file_id", ""))
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
knowledge["files"] = [
item for item in knowledge.get("files", []) if item.get("id") != file_id
]
knowledge["updated_at"] = self._timestamp()
return 200, self._knowledge_files_response(knowledge)
def _handle_file_endpoints(self, method: str, segments: list[str], files: Any) -> tuple[int, Any]:
def _handle_file_endpoints(
self, method: str, segments: list[str], files: Any
) -> tuple[int, Any]:
"""Handle file-related API endpoints."""
method_upper = method.upper()
@@ -561,7 +583,9 @@ class OpenWebUIMockService(OpenAPIMockService):
"""Delete a file and remove from all knowledge entries."""
self._files.pop(file_id, None)
for knowledge in self._knowledge.values():
knowledge["files"] = [item for item in knowledge.get("files", []) if item.get("id") != file_id]
knowledge["files"] = [
item for item in knowledge.get("files", []) if item.get("id") != file_id
]
return 200, {"deleted": True}
@@ -620,10 +644,14 @@ class R2RMockService(OpenAPIMockService):
None,
)
def _set_collection_document_ids(self, collection_id: str, document_id: str, *, add: bool) -> None:
def _set_collection_document_ids(
self, collection_id: str, document_id: str, *, add: bool
) -> None:
collection = self._collections.get(collection_id)
if collection is None:
collection = self.create_collection(name=f"Collection {collection_id}", collection_id=collection_id)
collection = self.create_collection(
name=f"Collection {collection_id}", collection_id=collection_id
)
documents = collection.setdefault("documents", [])
if add:
if document_id not in documents:
@@ -677,7 +705,9 @@ class R2RMockService(OpenAPIMockService):
self._set_collection_document_ids(collection_id, document_id, add=False)
return True
def append_document_metadata(self, document_id: str, metadata_list: list[dict[str, Any]]) -> dict[str, Any] | None:
def append_document_metadata(
self, document_id: str, metadata_list: list[dict[str, Any]]
) -> dict[str, Any] | None:
entry = self._documents.get(document_id)
if entry is None:
return None
@@ -712,7 +742,9 @@ class R2RMockService(OpenAPIMockService):
# Delegate to parent
return super().handle(method=method, path=path, json=json, params=params, files=files)
def _handle_collections_endpoint(self, method_upper: str, json: Mapping[str, Any] | None) -> tuple[int, Any]:
def _handle_collections_endpoint(
self, method_upper: str, json: Mapping[str, Any] | None
) -> tuple[int, Any]:
"""Handle collection-related endpoints."""
if method_upper == "GET":
return self._list_collections()
@@ -732,10 +764,12 @@ class R2RMockService(OpenAPIMockService):
clone["document_count"] = len(clone.get("documents", []))
results.append(clone)
payload.update({
"results": results,
"total_entries": len(results),
})
payload.update(
{
"results": results,
"total_entries": len(results),
}
)
return 200, payload
def _create_collection_endpoint(self, body: Mapping[str, Any]) -> tuple[int, Any]:
@@ -748,7 +782,9 @@ class R2RMockService(OpenAPIMockService):
payload.update({"results": entry})
return 200, payload
def _handle_documents_endpoint(self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any) -> tuple[int, Any]:
def _handle_documents_endpoint(
self, method_upper: str, segments: list[str], json: Mapping[str, Any] | None, files: Any
) -> tuple[int, Any]:
"""Handle document-related endpoints."""
if method_upper == "POST" and len(segments) == 2:
return self._create_document_endpoint(files)
@@ -774,11 +810,13 @@ class R2RMockService(OpenAPIMockService):
"#/components/schemas/R2RResults_IngestionResponse_"
)
ingestion = self.spec.generate_from_ref("#/components/schemas/IngestionResponse")
ingestion.update({
"message": ingestion.get("message") or "Ingestion task queued successfully.",
"document_id": document["id"],
"task_id": ingestion.get("task_id") or str(uuid4()),
})
ingestion.update(
{
"message": ingestion.get("message") or "Ingestion task queued successfully.",
"document_id": document["id"],
"task_id": ingestion.get("task_id") or str(uuid4()),
}
)
response_payload["results"] = ingestion
return 202, response_payload
@@ -835,7 +873,11 @@ class R2RMockService(OpenAPIMockService):
def _extract_content_from_files(self, raw_text_entry: Any) -> str:
"""Extract content from files entry."""
if isinstance(raw_text_entry, tuple) and len(raw_text_entry) >= 2 and raw_text_entry[1] is not None:
if (
isinstance(raw_text_entry, tuple)
and len(raw_text_entry) >= 2
and raw_text_entry[1] is not None
):
return str(raw_text_entry[1])
return ""
@@ -886,7 +928,7 @@ class R2RMockService(OpenAPIMockService):
results = response_payload.get("results")
if isinstance(results, Mapping):
results = copy.deepcopy(results)
results = dict(results)
results.update({"success": success})
else:
results = {"success": success}
@@ -894,7 +936,9 @@ class R2RMockService(OpenAPIMockService):
response_payload["results"] = results
return (200 if success else 404), response_payload
def _update_document_metadata(self, doc_id: str, json: Mapping[str, Any] | None) -> tuple[int, Any]:
def _update_document_metadata(
self, doc_id: str, json: Mapping[str, Any] | None
) -> tuple[int, Any]:
"""Update document metadata."""
metadata_list = [dict(item) for item in json] if isinstance(json, list) else []
document = self.append_document_metadata(doc_id, metadata_list)
@@ -919,7 +963,15 @@ class FirecrawlMockService(OpenAPIMockService):
def register_map_result(self, origin: str, links: list[str]) -> None:
self._maps[origin] = list(links)
def register_page(self, url: str, *, markdown: str | None = None, html: str | None = None, metadata: Mapping[str, Any] | None = None, links: list[str] | None = None) -> None:
def register_page(
self,
url: str,
*,
markdown: str | None = None,
html: str | None = None,
metadata: Mapping[str, Any] | None = None,
links: list[str] | None = None,
) -> None:
self._pages[url] = {
"markdown": markdown,
"html": html,

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
import pytest
from ingest_pipeline.automations import AUTOMATION_TEMPLATES, get_automation_yaml_templates
@pytest.fixture(scope="function")
def automation_templates_copy() -> dict[str, str]:
return get_automation_yaml_templates()
@pytest.fixture(scope="function")
def automation_templates_reference() -> dict[str, str]:
return AUTOMATION_TEMPLATES

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import pytest
@pytest.mark.parametrize(
("template_key", "expected_snippet"),
[
("cancel_long_running", "Cancel Long Running Ingestion Flows"),
("retry_failed", "Retry Failed Ingestion Flows"),
("resource_monitoring", "Manage Work Pool Based on Resources"),
],
)
def test_automation_templates_include_expected_entries(
automation_templates_copy,
template_key: str,
expected_snippet: str,
) -> None:
assert expected_snippet in automation_templates_copy[template_key]
def test_get_automation_yaml_templates_returns_copy(
automation_templates_copy,
automation_templates_reference,
) -> None:
automation_templates_copy["cancel_long_running"] = "overridden"
assert "overridden" not in automation_templates_reference["cancel_long_running"]

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from types import SimpleNamespace
from uuid import uuid4
@@ -22,7 +23,9 @@ from ingest_pipeline.core.models import (
),
)
@pytest.mark.asyncio
async def test_run_ingestion_collection_resolution(monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str) -> None:
async def test_run_ingestion_collection_resolution(
monkeypatch: pytest.MonkeyPatch, collection_arg: str | None, expected: str
) -> None:
recorded: dict[str, object] = {}
async def fake_flow(**kwargs: object) -> IngestionResult:
@@ -128,7 +131,7 @@ async def test_run_search_collects_results(monkeypatch: pytest.MonkeyPatch) -> N
threshold: float = 0.7,
*,
collection_name: str | None = None,
) -> object:
) -> AsyncGenerator[SimpleNamespace, None]:
yield SimpleNamespace(title="Title", content="Body text", score=0.91)
dummy_settings = SimpleNamespace(

View File

@@ -6,9 +6,11 @@ from unittest.mock import AsyncMock
import pytest
from ingest_pipeline.core.models import (
Document,
IngestionJob,
IngestionSource,
StorageBackend,
StorageConfig,
)
from ingest_pipeline.flows import ingestion
from ingest_pipeline.flows.ingestion import (
@@ -18,6 +20,7 @@ from ingest_pipeline.flows.ingestion import (
filter_existing_documents_task,
ingest_documents_task,
)
from ingest_pipeline.storage.base import BaseStorage
@pytest.mark.parametrize(
@@ -52,15 +55,34 @@ async def test_filter_existing_documents_task_filters_known_urls(
new_url = "https://new.example.com"
existing_id = str(FirecrawlIngestor.compute_document_id(existing_url))
class StubStorage:
display_name = "stub-storage"
class StubStorage(BaseStorage):
def __init__(self) -> None:
super().__init__(StorageConfig(backend=StorageBackend.WEAVIATE, endpoint="http://test.local"))
@property
def display_name(self) -> str:
return "stub-storage"
async def initialize(self) -> None:
pass
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
return "stub-id"
async def store_batch(
self, documents: list[Document], *, collection_name: str | None = None
) -> list[str]:
return ["stub-id"] * len(documents)
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
return True
async def check_exists(
self,
document_id: str,
*,
collection_name: str | None = None,
stale_after_days: int,
stale_after_days: int = 30,
) -> bool:
return document_id == existing_id
@@ -102,7 +124,9 @@ async def test_ingest_documents_task_invokes_helpers(
) -> object:
return storage_sentinel
async def fake_create_ingestor(job_arg: IngestionJob, config_block_name: str | None = None) -> object:
async def fake_create_ingestor(
job_arg: IngestionJob, config_block_name: str | None = None
) -> object:
return ingestor_sentinel
monkeypatch.setattr(ingestion, "_create_ingestor", fake_create_ingestor)

View File

@@ -4,6 +4,8 @@ from datetime import timedelta
from types import SimpleNamespace
import pytest
from prefect.deployments.runner import RunnerDeployment
from prefect.schedules import Cron, Interval
from ingest_pipeline.flows import scheduler
@@ -17,7 +19,6 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No
captured |= kwargs
return SimpleNamespace(**kwargs)
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
deployment = scheduler.create_scheduled_deployment(
@@ -28,8 +29,14 @@ def test_create_scheduled_deployment_cron(monkeypatch: pytest.MonkeyPatch) -> No
cron_expression="0 * * * *",
)
assert captured["schedule"].cron == "0 * * * *"
assert captured["parameters"]["source_type"] == "web"
schedule = captured["schedule"]
# Check that it's a cron schedule by verifying it has a cron attribute
assert hasattr(schedule, "cron")
assert schedule.cron == "0 * * * *"
parameters = captured["parameters"]
assert isinstance(parameters, dict)
assert parameters["source_type"] == "web"
assert deployment.tags == ["web", "weaviate"]
@@ -42,7 +49,6 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
captured |= kwargs
return SimpleNamespace(**kwargs)
monkeypatch.setattr(scheduler, "create_ingestion_flow", DummyFlow())
deployment = scheduler.create_scheduled_deployment(
@@ -55,7 +61,11 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
tags=["custom"],
)
assert captured["schedule"].interval == timedelta(minutes=15)
schedule = captured["schedule"]
# Check that it's an interval schedule by verifying it has an interval attribute
assert hasattr(schedule, "interval")
assert schedule.interval == timedelta(minutes=15)
assert captured["tags"] == ["custom"]
assert deployment.parameters["storage_backend"] == "open_webui"
@@ -63,12 +73,13 @@ def test_create_scheduled_deployment_interval(monkeypatch: pytest.MonkeyPatch) -
def test_serve_deployments_invokes_prefect(monkeypatch: pytest.MonkeyPatch) -> None:
called: dict[str, object] = {}
def fake_serve(*deployments: object, limit: int) -> None:
def fake_serve(*deployments: RunnerDeployment, limit: int) -> None:
called["deployments"] = deployments
called["limit"] = limit
monkeypatch.setattr(scheduler, "prefect_serve", fake_serve)
# Create a mock deployment using SimpleNamespace to avoid Prefect complexity
deployment = SimpleNamespace(name="only")
scheduler.serve_deployments([deployment])

View File

@@ -9,7 +9,7 @@ from ingest_pipeline.ingestors.repomix import RepomixIngestor
@pytest.mark.parametrize(
("content", "expected_keys"),
(
("## File: src/app.py\nprint(\"hi\")", ["src/app.py"]),
('## File: src/app.py\nprint("hi")', ["src/app.py"]),
("plain content without markers", ["repository"]),
),
)
@@ -44,7 +44,7 @@ def test_chunk_content_respects_max_size(
("file_path", "content", "expected"),
(
("src/app.py", "def feature():\n return True", "python"),
("scripts/run", "#!/usr/bin/env python\nprint(\"ok\")", "python"),
("scripts/run", '#!/usr/bin/env python\nprint("ok")', "python"),
("documentation.md", "# Title", "markdown"),
("unknown.ext", "text", None),
),
@@ -81,5 +81,8 @@ def test_create_document_enriches_metadata() -> None:
assert document.metadata["repository_name"] == "demo"
assert document.metadata["branch_name"] == "main"
assert document.metadata["commit_hash"] == "deadbeef"
assert document.metadata["title"].endswith("(chunk 1)")
title = document.metadata["title"]
assert title is not None
assert title.endswith("(chunk 1)")
assert document.collection == job.storage_backend.value

Some files were not shown because too many files have changed in this diff.