Enhance recovery and summarization services with asset path management

- Added `asset_path` to the `Meeting` entity for audio asset storage.
- Implemented `AudioValidationResult` for audio integrity checks during recovery.
- Updated `RecoveryService` to validate audio file integrity for crashed meetings.
- Enhanced `SummarizationService` to include consent persistence callbacks.
- Introduced new database migrations for `diarization_jobs` and `user_preferences` tables.
- Refactored various components to support the new asset path and audio validation features.
- Improved documentation in `CLAUDE.md` to reflect changes in recovery and summarization functionality.
CLAUDE.md (161 lines changed)
@@ -29,23 +29,46 @@ ruff check . # Lint
ruff check --fix .                  # Autofix
mypy src/noteflow                   # Strict type checks
basedpyright                        # Additional type checks

# Docker development
docker compose up -d postgres       # PostgreSQL with health checks
python scripts/dev_watch_server.py  # Auto-reload server (watches src/)
```

## Docker Development

```bash
# Start PostgreSQL (with pgvector)
docker compose up -d postgres

# Dev container (VS Code) - full GUI environment
# .devcontainer/ includes PortAudio, GTK, pystray, pynput support
code .  # Open in VS Code, select "Reopen in Container"

# Development server with auto-reload
python scripts/dev_watch_server.py  # Uses watchfiles, monitors src/ and alembic.ini
```

Dev container features: dbus-x11, GTK-3, libgl1 for system tray and hotkey support.
## Architecture

```
src/noteflow/
-├── domain/          # Entities (meeting, segment, annotation, summary) + ports (repository interfaces)
-├── application/     # Use-cases/services (MeetingService, RecoveryService, ExportService)
+├── domain/          # Entities (meeting, segment, annotation, summary, triggers) + ports
+├── application/     # Use-cases/services (MeetingService, RecoveryService, ExportService, SummarizationService, TriggerService)
├── infrastructure/  # Implementations
-│   ├── audio/        # sounddevice capture, ring buffer, VU levels, playback
+│   ├── audio/        # sounddevice capture, ring buffer, VU levels, playback, buffered writer
│   ├── asr/          # faster-whisper engine, VAD segmenter, streaming
│   ├── diarization/  # Speaker diarization (streaming: diart, offline: pyannote.audio)
│   ├── summarization/# Multi-provider summarization (CloudProvider, OllamaProvider) + citation verification
│   ├── triggers/     # Auto-start signal providers (calendar, audio activity, foreground app)
│   ├── persistence/  # SQLAlchemy + asyncpg + pgvector, Alembic migrations
│   ├── security/     # keyring keystore, AES-GCM encryption
│   ├── export/       # Markdown/HTML export
│   └── converters/   # ORM ↔ domain entity converters
-├── grpc/            # Proto definitions, server, client, meeting store
-├── client/          # Flet UI app + components (transcript, VU meter, playback)
+├── grpc/            # Proto definitions, server, client, meeting store, modular mixins
+├── client/          # Flet UI app + components (transcript, VU meter, playback, trigger mixin)
└── config/          # Pydantic settings (NOTEFLOW_ env vars)
```

@@ -54,6 +77,26 @@ src/noteflow/
- Repository pattern with Unit of Work (`SQLAlchemyUnitOfWork`)
- gRPC bidirectional streaming for audio → transcript flow
- Protocol-based DI (see `domain/ports/` and infrastructure `protocols.py` files)
- Modular gRPC mixins for separation of concerns (see below)
- `BackgroundWorkerMixin` for standardized thread lifecycle in components

## gRPC Mixin Architecture

The gRPC server uses modular mixins for maintainability:

```
grpc/_mixins/
├── streaming.py     # ASR streaming, audio processing, partial buffers
├── diarization.py   # Speaker diarization jobs (background refinement, job TTL)
├── summarization.py # Summary generation (separates LLM inference from DB transactions)
├── meeting.py       # Meeting lifecycle (create, get, list, delete)
├── annotation.py    # Segment annotations CRUD
├── export.py        # Markdown/HTML document export
├── converters.py    # Protobuf ↔ domain entity converters
└── protocols.py     # ServicerHost protocol for mixin composition
```

Each mixin operates on the `ServicerHost` protocol, enabling clean composition in `NoteFlowServicer`.

## Database

@@ -101,3 +144,111 @@ python -m grpc_tools.protoc -I src/noteflow/grpc/proto \
- `spike_02_audio_capture/` - sounddevice + PortAudio
- `spike_03_asr_latency/` - faster-whisper benchmarks (0.05x real-time)
- `spike_04_encryption/` - keyring + AES-GCM (826 MB/s throughput)

## Key Subsystems
### Speaker Diarization
- **Streaming**: diart for real-time speaker detection during recording
- **Offline**: pyannote.audio for post-meeting refinement (higher quality)
- **gRPC**: `RefineSpeakerDiarization` (background job), `GetDiarizationJobStatus` (polling), `RenameSpeaker`

### Summarization
- **Providers**: CloudProvider (Anthropic/OpenAI), OllamaProvider (local), MockProvider (testing)
- **Citation verification**: Links summary claims to transcript evidence
- **Consent**: Cloud providers require explicit user consent, persisted in the `user_preferences` table

### Trigger Detection
- **Signals**: Calendar proximity, audio activity, foreground app detection
- **Actions**: IGNORE, NOTIFY, AUTO_START with confidence thresholds
- **Client integration**: Background polling with dialog prompts (start/snooze/dismiss)
## Shared Utilities & Factories

### Factories

| Location | Function | Purpose |
|----------|----------|---------|
| `infrastructure/summarization/factory.py` | `create_summarization_service()` | Auto-configured service with provider detection |
| `infrastructure/persistence/database.py` | `create_async_engine()` | SQLAlchemy async engine from settings |
| `infrastructure/persistence/database.py` | `create_async_session_factory()` | Session factory from DB URL |
| `config/settings.py` | `get_settings()` | Cached Settings from env vars |
| `config/settings.py` | `get_trigger_settings()` | Cached TriggerSettings from env vars |

### Converters

| Location | Class/Function | Purpose |
|----------|----------------|---------|
| `infrastructure/converters/orm_converters.py` | `OrmConverter` | ORM ↔ domain entities (Meeting, Segment, Summary, etc.) |
| `infrastructure/converters/asr_converters.py` | `AsrConverter` | ASR DTOs → domain WordTiming |
| `grpc/_mixins/converters.py` | `meeting_to_proto()`, `segment_to_proto_update()` | Domain → protobuf messages |
| `grpc/_mixins/converters.py` | `create_segment_from_asr()` | ASR result → Segment with word timings |
### Repository Base (`persistence/repositories/_base.py`)

| Method | Purpose |
|--------|---------|
| `_execute_scalar()` | Single result query (or None) |
| `_execute_scalars()` | All scalar results from query |
| `_add_and_flush()` | Add model and flush to DB |
| `_delete_and_flush()` | Delete model and flush |

### Security Helpers (`infrastructure/security/keystore.py`)

| Function | Purpose |
|----------|---------|
| `_decode_and_validate_key()` | Validate base64 key, check size |
| `_generate_key()` | Generate 256-bit key as `(bytes, base64_str)` |
### Export Helpers (`infrastructure/export/_formatting.py`)

| Function | Purpose |
|----------|---------|
| `format_timestamp()` | Seconds → `MM:SS` or `HH:MM:SS` |
| `format_datetime()` | Datetime → display string |
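The timestamp helper's contract (`MM:SS` below one hour, `HH:MM:SS` at or above) is small enough to sketch. This is an illustrative reimplementation of the documented behavior, not the module's actual code:

```python
def format_timestamp(seconds: float) -> str:
    """Render elapsed seconds as MM:SS, switching to HH:MM:SS past one hour."""
    total = int(seconds)
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"
```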
### Summarization (`infrastructure/summarization/`)

| Location | Function | Purpose |
|----------|----------|---------|
| `_parsing.py` | `build_transcript_prompt()` | Transcript with segment markers for LLM |
| `_parsing.py` | `parse_llm_response()` | JSON → Summary entity |
| `citation_verifier.py` | `verify_citations()` | Validate segment_ids exist |
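The core check behind `verify_citations()` (validating that cited segment IDs actually exist in the transcript) can be sketched as follows; the signature and return shape here are assumptions, not the actual API:

```python
def verify_citations(cited_ids: list[str], known_segment_ids: set[str]) -> list[str]:
    """Return cited segment IDs that do not exist in the transcript.

    An empty result means every summary citation points at real evidence.
    """
    return [sid for sid in cited_ids if sid not in known_segment_ids]
```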
### Diarization (`infrastructure/diarization/assigner.py`)

| Function | Purpose |
|----------|---------|
| `assign_speaker()` | Speaker for time range from turns |
| `assign_speakers_batch()` | Batch speaker assignment |
### Triggers (`infrastructure/triggers/calendar.py`)

| Function | Purpose |
|----------|---------|
| `parse_calendar_events()` | Parse events from config/env |

### Client Mixins (`client/components/`)

| Class | Purpose |
|-------|---------|
| `BackgroundWorkerMixin` | Thread lifecycle: `_start_worker()`, `_stop_worker()`, `_should_run()` |
| `AsyncOperationMixin[T]` | Async ops with state: `run_async_operation()` |
| `TriggerMixin` | Trigger signal polling |
### Recovery Service (`application/services/recovery_service.py`)

| Name | Purpose |
|------|---------|
| `recover_all()` | Orchestrate meeting + job recovery |
| `RecoveryResult` | Dataclass with recovery counts |
## Known Issues

See `docs/triage.md` for tracked technical debt.

**Resolved:**
- ~~Server-side state volatility~~ → Diarization jobs persisted to DB
- ~~Hardcoded directory paths~~ → `asset_path` column added to meetings
- ~~Synchronous blocking in async gRPC~~ → `run_in_executor` for diarization
- ~~Summarization consent not persisted~~ → Stored in `user_preferences` table
- ~~VU meter update throttling~~ → 20fps throttle implemented
docs/triage.md (363 lines changed)
@@ -1,265 +1,196 @@
-This is a comprehensive code review of the `NoteFlow` repository.
-
-Overall, this codebase demonstrates a high level of engineering maturity. It effectively applies Clean Architecture concepts (entities, use cases, ports/adapters), leveraging strong typing, Pydantic for validation, and SQLAlchemy/Alembic for persistence. The integration test setup using `testcontainers` is particularly robust.
-
-However, there are critical performance bottlenecks in async/sync bridging in the ASR engine, potential concurrency issues in UI state management, and specific security considerations in the encryption implementation.
-
-Below is the review, categorized into actionable feedback and formatted to be convertible into Git issues.
-
----
-
-## 1. Critical Architecture & Performance Issues
+# Triage Review (Validated)
+
+Validated: 2025-12-19
+Legend: Status = Confirmed, Partially confirmed, Not observed, Already implemented
+Citations use [path:line] format.
+
+## 2. Architecture & State Management
+
+### Issue 2.1: Server-Side State Volatility
+
+Status: Confirmed
+Severity: High
+Location: src/noteflow/grpc/service.py, src/noteflow/grpc/_mixins/diarization.py, src/noteflow/grpc/server.py
+
+Evidence:
+- In-memory stream state lives on the servicer (`_active_streams`, `_audio_writers`, `_partial_buffers`, `_diarization_jobs`). [src/noteflow/grpc/service.py:95] [src/noteflow/grpc/service.py:106] [src/noteflow/grpc/service.py:109] [src/noteflow/grpc/service.py:122]
+- Background diarization jobs are stored only in a dict, and status reads come from it. [src/noteflow/grpc/_mixins/diarization.py:241] [src/noteflow/grpc/_mixins/diarization.py:480]
+- Server shutdown only stops gRPC; no servicer cleanup hook is invoked. [src/noteflow/grpc/server.py:132]
+- Crash recovery marks meetings ERROR but does not validate audio assets. [src/noteflow/application/services/recovery_service.py:21] [src/noteflow/application/services/recovery_service.py:71]
-### Issue 1: Blocking ASR Inference in Async gRPC Server
-**Severity:** Critical
-**Location:** `src/noteflow/grpc/service.py`, `src/noteflow/infrastructure/asr/engine.py`
-
-**The Problem:**
-The `NoteFlowServer` uses `grpc.aio` (asyncio), but `FasterWhisperEngine.transcribe` is a blocking, CPU-bound operation. In `NoteFlowServicer._maybe_emit_partial` and `_process_audio_segment`, the code calls:
-```python
-# src/noteflow/grpc/service.py
-partial_text = " ".join(result.text for result in self._asr_engine.transcribe(combined))
-```
-Because `transcribe` performs heavy computation, executing it directly within an `async def` method freezes the entire asyncio event loop. This blocks heartbeats, other RPC calls, and other concurrent meeting streams until inference completes.
-
-**Actionable Solution:**
-Offload the transcription to a thread pool executor:
-1. Keep `FasterWhisperEngine` synchronous (it wraps CTranslate2, which releases the GIL often, but it is still blocking from an asyncio perspective).
-2. Update `NoteFlowServicer` to run transcription in an executor.
+
+Example:
+- If job state is lost (for example, after a restart), polling will fail and surface a fetch error in the UI. [src/noteflow/grpc/_mixins/diarization.py:479] [src/noteflow/grpc/client.py:885] [src/noteflow/client/components/meeting_library.py:534]
+
+Reusable code locations:
+- `_close_audio_writer` already centralizes writer cleanup. [src/noteflow/grpc/service.py:247]
+- Migration patterns for new tables exist (annotations). [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]
+
+Actions:
+- Persist diarization jobs (table or cache) and query from `GetDiarizationJobStatus`.
+- Add a shutdown hook to close all `_audio_writers` and flush buffers.
+- Optional: add asset integrity checks after `RecoveryService` marks a meeting ERROR.
-```python
-# In NoteFlowServicer
-import asyncio
-from functools import partial
-
-# Helper method
-async def _run_transcription(self, audio):
-    loop = asyncio.get_running_loop()
-    # Use a thread pool for compute-heavy tasks
-    return await loop.run_in_executor(
-        None,
-        partial(list, self._asr_engine.transcribe(audio)),
-    )
-
-# Usage in _maybe_emit_partial
-results = await self._run_transcription(combined)
-partial_text = " ".join(r.text for r in results)
-```
+### Issue 2.2: Implicit Meeting Asset Paths
+
+Status: Confirmed
+Severity: Medium
+Location: src/noteflow/infrastructure/audio/reader.py, src/noteflow/infrastructure/audio/writer.py, src/noteflow/infrastructure/persistence/models.py
+
+Evidence:
+- Audio assets are read/written under `meetings_dir / meeting_id`. [src/noteflow/infrastructure/audio/reader.py:72] [src/noteflow/infrastructure/audio/writer.py:94]
+- `MeetingModel` defines meeting fields (id/title/state/metadata/wrapped_dek) but no asset path. [src/noteflow/infrastructure/persistence/models.py:38] [src/noteflow/infrastructure/persistence/models.py:64]
+- Delete logic also assumes `meetings_dir / meeting_id`. [src/noteflow/application/services/meeting_service.py:195]
+
+Example:
+- Record a meeting with `meetings_dir` set to `~/.noteflow/meetings`, then change it to `/mnt/noteflow`. Playback will look under the new base and fail to find older audio.
-### Issue 2: Synchronous `sounddevice` Callbacks in Async Client App
-**Severity:** High
-**Location:** `src/noteflow/infrastructure/audio/capture.py`
-
-**The Problem:**
-The `sounddevice` library invokes the Python callback from a C-level background thread. In `SoundDeviceCapture._stream_callback`, the user-provided callback is invoked directly:
-```python
-self._callback(audio_data, timestamp)
-```
-In `app.py`, this callback (`_on_audio_frames`) interacts with `self._audio_activity.update` and `self._client.send_audio`. While `queue.put` is thread-safe, any heavy logic or object allocation here happens on the real-time audio thread. If Python garbage collection pauses this thread, audio artifacts (dropouts) will occur.
-
-**Actionable Solution:**
-The callback should strictly put bytes into a thread-safe queue and return immediately. A separate consumer thread/task should handle the VAD, VU meter logic, and network sending.
+
+Reusable code locations:
+- `MeetingAudioWriter.open` and `MeetingAudioReader.load_meeting_audio` are the path entry points. [src/noteflow/infrastructure/audio/writer.py:70] [src/noteflow/infrastructure/audio/reader.py:60]
+- Migration templates live in `infrastructure/persistence/migrations/versions`. [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]
+
+Actions:
+- Add `asset_path` (or `storage_path`) column to meetings.
+- Store the relative path at creation time and use it on read/delete.
+
+## 3. Concurrency & Performance
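The "enqueue in the callback, process on a consumer thread" fix recommended for the audio callback can be sketched with the standard library alone. Names like `AudioPipeline` are illustrative, not from the codebase:

```python
import queue
import threading


class AudioPipeline:
    """Real-time-safe handoff: the audio callback only enqueues; a worker consumes."""

    def __init__(self) -> None:
        self._frames = queue.Queue(maxsize=256)  # bounded: never block the audio thread
        self._processed = []  # stand-in for VAD / VU metering / network sends
        self._worker = threading.Thread(target=self._consume, daemon=True)
        self._worker.start()

    def on_audio(self, data: bytes) -> None:
        """Called from the real-time audio thread: no heavy work, just enqueue."""
        try:
            self._frames.put_nowait(data)
        except queue.Full:
            pass  # dropping a frame is safer than stalling the audio thread

    def _consume(self) -> None:
        # All heavy processing happens here, off the audio thread.
        while True:
            frame = self._frames.get()
            if frame is None:  # sentinel: shut down
                return
            self._processed.append(frame)

    def close(self) -> None:
        self._frames.put(None)
        self._worker.join()
```

The bounded queue plus `put_nowait` keeps the callback constant-time, which is the property the review is asking for.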
-### Issue 3: Encryption Key Material in Memory
-**Severity:** Medium
-**Location:** `src/noteflow/infrastructure/security/crypto.py`
-
-**The Problem:**
-The `AesGcmCryptoBox` keeps the master key in memory via `_get_master_cipher`. While this is unavoidable for operation, `secrets.token_bytes` creates immutable bytes objects that cannot be zeroed out (wiped) from memory when no longer needed. Python's GC handles cleanup, but the key lingers in RAM.
-
-**Actionable Solution:**
-While strict memory zeroing is hard in Python, minimize the lifespan of the `dek` (Data Encryption Key):
-1. In `MeetingAudioWriter`, the `dek` is stored as an instance attribute (`self._dek`), which keeps the unencrypted key in memory for the duration of the meeting.
-2. Consider refactoring `ChunkedAssetWriter` to store the `cipher` object (the `AESGCM` context) rather than the raw `dek` bytes, if the underlying C library handles memory better; strictly speaking, the key is still in RAM.
-3. **Critical:** Ensure `writer.close()` sets `self._dek = None` immediately (it currently does, which is good practice).
-
----
-
-## 2. Domain & Infrastructure Logic
+### Issue 3.1: Synchronous Blocking in Async gRPC (Streaming Diarization)
+
+Status: Confirmed
+Severity: Medium
+Location: src/noteflow/grpc/_mixins/diarization.py, src/noteflow/grpc/_mixins/streaming.py
+
+Evidence:
+- `_process_streaming_diarization` calls `process_chunk` synchronously. [src/noteflow/grpc/_mixins/diarization.py:63] [src/noteflow/grpc/_mixins/diarization.py:92]
+- It is invoked in the async streaming loop on every chunk. [src/noteflow/grpc/_mixins/streaming.py:379]
+- The diarization engine uses pyannote/diart pipelines. [src/noteflow/infrastructure/diarization/engine.py:1]
+
+Example:
+- With diarization enabled on CPU, heavy `process_chunk` calls can stall the event loop, delaying transcript updates and heartbeats.
+
+Reusable code locations:
+- ASR already offloads blocking work via `run_in_executor`. [src/noteflow/infrastructure/asr/engine.py:156]
+- Offline diarization uses `asyncio.to_thread`. [src/noteflow/grpc/_mixins/diarization.py:305]
-### Issue 4: Fallback Logic in `SummarizationService`
-**Severity:** Low
-**Location:** `src/noteflow/application/services/summarization_service.py`
-
-**The Problem:**
-`_get_provider_with_fallback` iterates through a hardcoded `fallback_order = [SummarizationMode.LOCAL, SummarizationMode.MOCK]`. This ignores the configured order and user preference when new providers are added.
-
-**Actionable Solution:**
-Allow `SummarizationServiceSettings` to define a `fallback_chain: list[SummarizationMode]`.
+
+Actions:
+- Offload streaming diarization to a thread/process pool, similar to ASR.
+- Consider a bounded queue so diarization lag does not backpressure streaming.
+
+### Issue 3.2: VU Meter UI Updates on Every Audio Chunk
+
+Status: Confirmed
+Severity: Medium
+Location: src/noteflow/client/app.py, src/noteflow/client/components/vu_meter.py
-### Issue 5: Race Condition in `MeetingStore` (In-Memory)
-**Severity:** Medium
-**Location:** `src/noteflow/grpc/meeting_store.py`
-
-**The Problem:**
-The `MeetingStore` uses `threading.RLock`, but its methods return the actual `Meeting` object reference.
-```python
-def get(self, meeting_id: str) -> Meeting | None:
-    with self._lock:
-        return self._meetings.get(meeting_id)
-```
-The caller gets a reference to the mutable `Meeting` entity. If two threads get the meeting and modify it (e.g., `meeting.state = ...`), the lock protects only the dictionary lookups, not the entity itself.
-
-**Actionable Solution:**
-1. Return deep copies of the `Meeting` object (performance impact).
-2. Or implement specific atomic update methods on the store (e.g., `update_status(id, status)`) rather than returning the whole object for modification.
+
+Evidence:
+- Audio capture uses 100ms chunks. [src/noteflow/client/app.py:307]
+- Each chunk triggers `VuMeterComponent.on_audio_frames`, which schedules a UI update. [src/noteflow/client/app.py:552] [src/noteflow/client/components/vu_meter.py:58]
+
+Example:
+- At 100ms chunks, the UI updates about 10 times per second. If chunk duration is lowered (e.g., 20ms), that becomes about 50 updates per second and can cause stutter.
+
+Reusable code locations:
+- Recording timer throttles updates with a fixed interval and background worker. [src/noteflow/client/components/recording_timer.py:14]
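The atomic-update option suggested for the store race can be sketched as follows. This is a hypothetical standalone version, not the repository's `MeetingStore`; the `update_state` name and two-field `Meeting` are illustrative:

```python
import threading
from dataclasses import dataclass


@dataclass
class Meeting:
    meeting_id: str
    state: str


class MeetingStore:
    """Store exposing atomic mutations instead of handing out mutable references."""

    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._meetings: dict[str, Meeting] = {}

    def add(self, meeting: Meeting) -> None:
        with self._lock:
            self._meetings[meeting.meeting_id] = meeting

    def update_state(self, meeting_id: str, state: str) -> bool:
        """Mutate under the lock; callers never touch the shared entity directly."""
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None:
                return False
            meeting.state = state
            return True

    def get_snapshot(self, meeting_id: str) -> "Meeting | None":
        """Return a copy so callers cannot race on the stored entity."""
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None:
                return None
            return Meeting(meeting.meeting_id, meeting.state)
```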
-### Issue 6: `pgvector` Dependency Management
-**Severity:** Low
-**Location:** `src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py`
+
+Actions:
+- Throttle VU updates (for example, 20 fps) or update only when the delta exceeds a threshold.
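The suggested fixed-cadence throttle is a few lines; this hypothetical helper (class name and injectable `clock` are illustrative) shows the shape:

```python
import time


class VuThrottle:
    """Allow at most `max_fps` UI updates per second; drop the rest."""

    def __init__(self, max_fps: float = 20.0, clock=time.monotonic) -> None:
        self._min_interval = 1.0 / max_fps
        self._clock = clock  # injectable for testing
        self._last_update = float("-inf")

    def should_update(self) -> bool:
        now = self._clock()
        if now - self._last_update >= self._min_interval:
            self._last_update = now
            return True
        return False
```

The audio callback would call `should_update()` per chunk and only schedule a UI refresh when it returns True.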
-**The Problem:**
-The migration unconditionally executes `CREATE EXTENSION IF NOT EXISTS vector`. On managed database services (like RDS) or standard Docker Postgres images, the user might not have superuser privileges to install extensions, or the extension binaries might be missing.
-
-**Actionable Solution:**
-Wrap the extension creation in a try/except block or check capabilities first. For the integration tests, ensure the `pgvector/pgvector:pg16` image is strictly pinned (which is already done).
-
----
-
-## 3. Client & UI (Flet)
-
-### Issue 7: Massive `app.py` File Size
-**Severity:** Medium
-**Location:** `src/noteflow/client/app.py`
+## 4. Domain Logic & Reliability
+
+### Issue 4.1: Summarization Consent Persistence
+
+Status: Confirmed
+Severity: Low (UX)
+Location: src/noteflow/application/services/summarization_service.py
+
+Evidence:
+- Consent is stored on `SummarizationServiceSettings` and defaults to False. [src/noteflow/application/services/summarization_service.py:56]
+- `grant_cloud_consent` only mutates the in-memory settings. [src/noteflow/application/services/summarization_service.py:150]
+
+Example:
+- Users who grant cloud consent must re-consent after every server/client restart.
-**The Problem:**
-`app.py` orchestrates too much. It handles UI layout, audio capture orchestration, gRPC client events, and state updates; it serves as a "God Class" controller.
-
-**Actionable Solution:**
-Refactor into a `ClientController` class separate from the UI layout construction:
-1. `src/noteflow/client/controller.py`: Handles `NoteFlowClient`, `SoundDeviceCapture`, and updates `AppState`.
-2. `src/noteflow/client/views.py`: Accepts `AppState` and renders the UI.
+
+Reusable code locations:
+- Existing persistence patterns for per-meeting data live in `infrastructure/persistence`. [src/noteflow/infrastructure/persistence/models.py:32]
+
+Actions:
+- Persist consent in a preferences table or config file and hydrate it on startup.
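The persist-and-hydrate action can be sketched with a file-backed store; the actual fix uses a `user_preferences` table, so this is only an illustration of the load-on-startup/save-on-grant shape, with hypothetical names:

```python
import json
from pathlib import Path


class ConsentStore:
    """Persist cloud-summarization consent so it survives restarts."""

    def __init__(self, path: Path) -> None:
        self._path = path

    def load(self) -> bool:
        # Hydrate on startup; absence of the file means consent was never granted.
        if not self._path.exists():
            return False
        data = json.loads(self._path.read_text())
        return bool(data.get("cloud_consent", False))

    def save(self, granted: bool) -> None:
        # Called from the consent callback when the user grants/revokes consent.
        self._path.write_text(json.dumps({"cloud_consent": granted}))
```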
-### Issue 8: Re-rendering Efficiency in Transcript
-**Severity:** Medium
-**Location:** `src/noteflow/client/components/transcript.py`
-
-**The Problem:**
-`_render_final_segment` appends controls to `self._list_view.controls`. In Flet, modifying a large list of controls becomes slow as the transcript grows (hundreds of segments).
-
-**Actionable Solution:**
-1. Implement a "virtualized" list or pagination if Flet supports it efficiently.
-2. Otherwise, implement a sliding-window rendering approach where only the last N segments plus visible segments are rendered, though this is complex in Flet.
-3. **Immediate fix:** Ensure `auto_scroll` is handled efficiently. The current implementation clears and re-adds specific rows during search, which is heavy.
-
----
+### Issue 4.2: Annotation Validation / Point-in-Time Annotations
+
+Status: Not observed (current code supports point annotations)
+Severity: Low
+Location: src/noteflow/domain/entities/annotation.py, src/noteflow/client/components/annotation_toolbar.py, src/noteflow/client/components/annotation_display.py
+
+Evidence:
+- Validation allows `end_time == start_time`. [src/noteflow/domain/entities/annotation.py:37]
+- UI creates point annotations by setting `start_time == end_time`. [src/noteflow/client/components/annotation_toolbar.py:189]
+- Display uses `start_time` only, not duration. [src/noteflow/client/components/annotation_display.py:164]
+
+Example:
+- Clicking a point annotation seeks to the exact timestamp (no duration needed).
-## 4. Specific Code Feedback (Nitpicks & Bugs)
-
-### 1. Hardcoded Audio Constants
-**File:** `src/noteflow/infrastructure/asr/segmenter.py`
-`SegmenterConfig` defaults to `sample_rate=16000`, and `SoundDeviceCapture` also defaults to `16000`.
-**Risk:** If the server is configured for 44.1kHz, the client still defaults to 16kHz, hardcoded in several places.
-**Fix:** Ensure `DEFAULT_SAMPLE_RATE` from `src/noteflow/config/constants.py` is used everywhere.
+
+Reusable code locations:
+- If range annotations are introduced later, `AnnotationDisplayComponent` is where a start-end range would be rendered. [src/noteflow/client/components/annotation_display.py:148]
+
+Action:
+- None required now; revisit if range annotations are added to the UI or exports.
-### 2. Exception Swallowing in Audio Writer
-**File:** `src/noteflow/grpc/service.py` -> `_write_audio_chunk_safe`
-```python
-except Exception as e:
-    logger.error("Failed to write audio chunk: %s", e)
-```
-If the disk fills up or permissions change, the audio writer fails silently (just logging) while the meeting continues. The user might lose the audio recording entirely while believing it is safe.
-**Fix:** This error should trigger a circuit breaker that stops the recording, or notify the client via a gRPC status update or a metadata stream update.
-
-### 3. Trigger Service Rate Limiting Logic
-**File:** `src/noteflow/application/services/trigger_service.py`
-In `_determine_action`:
-```python
-if self._last_prompt is not None:
-    elapsed = now - self._last_prompt
-    if elapsed < self._settings.rate_limit_seconds:
-        return TriggerAction.IGNORE
-```
-This logic ignores *all* triggers within the rate limit. If a **high confidence** trigger (auto-start) arrives 10 seconds after a low-confidence prompt, it gets ignored.
-**Fix:** The rate limit should apply to `NOTIFY` actions, but `AUTO_START` might need to bypass the rate limit or use a shorter one.
+## 5. Suggested Git Issues (Validated)
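The rate-limit fix proposed above (let high-confidence auto-starts bypass the prompt spacing) can be sketched as a pure function. The `apply_rate_limit` name and its parameters are illustrative, not the service's actual API:

```python
import enum


class TriggerAction(enum.Enum):
    IGNORE = "ignore"
    NOTIFY = "notify"
    AUTO_START = "auto_start"


def apply_rate_limit(
    proposed: TriggerAction,
    now: float,
    last_prompt: "float | None",
    rate_limit_seconds: float,
) -> TriggerAction:
    """Rate-limit NOTIFY prompts, but let AUTO_START through."""
    if proposed is TriggerAction.AUTO_START:
        # High-confidence auto-start bypasses prompt spacing entirely.
        return proposed
    if last_prompt is not None and (now - last_prompt) < rate_limit_seconds:
        return TriggerAction.IGNORE
    return proposed
```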
-### 4. Database Session Lifecycle in UoW
-**File:** `src/noteflow/infrastructure/persistence/unit_of_work.py`
-`__init__` does not create the session; `__aenter__` does. This is correct. However, `SqlAlchemyUnitOfWork` caches repositories:
-```python
-self._annotations_repo = SqlAlchemyAnnotationRepository(self._session)
-```
-If `__aenter__` is called, `__aexit__` closes the session. If the same UoW instance is reused (calling `async with uow:` again), it creates a *new* session and overwrites the repo references. This is generally safe, but verify whether `SqlAlchemyUnitOfWork` instances are intended to be reusable or disposable. Currently they look reusable, which is fine.
+Issue A: Persist Diarization Jobs to Database
+Status: Confirmed
+Evidence: Jobs live in-memory and `GetDiarizationJobStatus` reads from the dict. [src/noteflow/grpc/_mixins/diarization.py:241] [src/noteflow/grpc/_mixins/diarization.py:480]
+Reusable code locations:
+- SQLAlchemy models/migrations patterns. [src/noteflow/infrastructure/persistence/models.py:32] [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]
+Tasks:
+- Add JobModel with status fields.
+- Update the mixin to persist to and query the DB.
+
+Issue B: Implement Audio Writer Buffering
+Status: Already implemented (close or re-scope)
+Evidence: `MeetingAudioWriter` buffers and flushes based on `buffer_size`. [src/noteflow/infrastructure/audio/writer.py:126]
+Reusable code locations:
+- `AUDIO_BUFFER_SIZE_BYTES` constant. [src/noteflow/config/constants.py:26]
+Tasks:
+- None, unless buffer size needs tuning.
-### 5. Frontend Polling vs Events
-**File:** `src/noteflow/client/components/playback_sync.py`
-`POSITION_POLL_INTERVAL = 0.1`.
-Using a thread to poll `self._state.playback.current_position` every 100ms is CPU-inefficient in Python (due to the GIL).
-**Suggestion:** Use the `sounddevice` stream callback time info to update the position state only when audio is actually playing, rather than a separate `while True` loop.
+Issue C: Fallback for Headless Keyring
+Status: Confirmed
+Evidence: `KeyringKeyStore` only falls back to an env var, not file storage. [src/noteflow/infrastructure/security/keystore.py:49]
+Reusable code locations:
+- `KeyringKeyStore` and `InMemoryKeyStore` live in the same module. [src/noteflow/infrastructure/security/keystore.py:35]
+Tasks:
+- Add `FileKeyStore` and wire the fallback in server/service initialization.
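A `FileKeyStore` along the lines of the task above could look like this. This is a sketch under assumptions: the env var name, file layout, and method name are illustrative, and the real implementation would reuse the module's existing key validation helpers:

```python
import base64
import os
import secrets
from pathlib import Path


class FileKeyStore:
    """Last-resort key store for headless hosts where no keyring backend exists."""

    def __init__(self, path: Path, env_var: str = "NOTEFLOW_MASTER_KEY") -> None:
        self._path = path
        self._env_var = env_var

    def get_or_create_key(self) -> bytes:
        # 1. Environment variable wins (mirrors the existing env fallback).
        env_value = os.environ.get(self._env_var)
        if env_value:
            return base64.b64decode(env_value)
        # 2. Previously persisted key file.
        if self._path.exists():
            return base64.b64decode(self._path.read_text())
        # 3. Generate a new 256-bit key and persist it with owner-only permissions.
        key = secrets.token_bytes(32)
        self._path.write_text(base64.b64encode(key).decode("ascii"))
        self._path.chmod(0o600)
        return key
```

Server startup would try the keyring first and fall through to this store on `keyring.errors.KeyringError` instead of raising `RuntimeError`.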
---
|
||||
Issue D: Throttled VU Meter in Client
|
||||
Status: Confirmed
|
||||
Evidence: Each chunk schedules a UI update with no throttle. [src/noteflow/client/components/vu_meter.py:58]
|
||||
Reusable code locations:
|
||||
- Background worker/throttle pattern in `RecordingTimerComponent`. [src/noteflow/client/components/recording_timer.py:14]
|
||||
Tasks:
|
||||
- Add a `last_update_time` and update at fixed cadence.
|
||||
|
||||
## 5. Security Review

Issue E: Explicit Asset Path Storage

Status: Confirmed

Evidence: Meeting paths are derived from `meetings_dir / meeting_id` rather than stored explicitly. [src/noteflow/infrastructure/audio/reader.py:72]

Reusable code locations:

- Meeting model + migrations. [src/noteflow/infrastructure/persistence/models.py:32]

Tasks:

- Add an `asset_path` column and persist it at create time.

### 1. Keyring Headless Failure

**File:** `src/noteflow/infrastructure/security/keystore.py`

**Risk:** The app crashes if `keyring` cannot find a backend (common in Docker/headless Linux servers).

**Fix:**

```python
# Illustrative call site (the real lookup lives in KeyringKeyStore):
try:
    key = keyring.get_password(service_name, key_name)
except keyring.errors.KeyringError:
    logger.warning(
        "Keyring unavailable, falling back to environment variable or temporary key"
    )
    key = None  # implement a fallback strategy or fail with a clear message
```

Currently, it raises `RuntimeError`, which crashes server startup.

Issue F: PGVector Index Creation

Status: Confirmed (requires a product decision)

Evidence: The migration creates an `ivfflat` index immediately. [src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py:95]

Reusable code locations:

- Same migration file for index changes. [src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py:95]

Tasks:

- Consider switching to HNSW, or defer index creation until data exists.

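
For reference, pgvector supports `USING hnsw` since 0.5.0, and unlike `ivfflat` an HNSW index builds incrementally and does not need existing rows for good recall. The DDL the migration could emit instead is sketched below (the index and column names are assumptions, not taken from the actual schema; in Alembic this string would be passed to `op.execute()`):

```python
# Hypothetical HNSW DDL; table/column/index names are illustrative.
HNSW_INDEX_DDL = (
    "CREATE INDEX IF NOT EXISTS ix_segments_embedding_hnsw "
    "ON segments USING hnsw (embedding vector_cosine_ops) "
    "WITH (m = 16, ef_construction = 64)"
)
print(HNSW_INDEX_DDL)
```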
### 2. DEK Handling

**Analysis:** You generate a DEK, wrap it, and store the `wrapped_dek` in the DB. The plaintext DEK stays in memory only for the duration of the stream.

**Verdict:** This is standard envelope-encryption practice. Acceptable for this application tier.

## 6. Code Quality & Nitpicks (Validated)

---

## 6. Generated Issues for Git
### Issue: Asynchronous Transcription Processing

**Title:** Refactor ASR Engine to run in ThreadPoolExecutor

**Description:**
The gRPC server uses `asyncio`, but `FasterWhisperEngine.transcribe` is blocking. This freezes the event loop during transcription segments.

**Task:**

1. Inject `asyncio.get_running_loop()` into `NoteFlowServicer`.
2. Wrap `self._asr_engine.transcribe` calls in `loop.run_in_executor`.

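
The executor pattern in miniature (`FakeAsrEngine` is a stub standing in for the real engine; only the `run_in_executor` wiring is the point):

```python
# Offload a blocking transcribe() call so the event loop stays responsive.
import asyncio
import time


class FakeAsrEngine:
    def transcribe(self, audio: bytes) -> str:
        time.sleep(0.01)  # simulate blocking inference
        return f"transcript of {len(audio)} bytes"


async def transcribe_async(engine: FakeAsrEngine, audio: bytes) -> str:
    loop = asyncio.get_running_loop()
    # None = default ThreadPoolExecutor.
    return await loop.run_in_executor(None, engine.transcribe, audio)


result = asyncio.run(transcribe_async(FakeAsrEngine(), b"\x00" * 1600))
print(result)  # transcript of 1600 bytes
```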
### Issue: Client Audio Callback Optimization

**Title:** Optimize Audio Capture Callback

**Description:**
The `SoundDeviceCapture` callback executes application logic (network sending, VAD updates) in the audio thread.

**Task:**

1. Change the callback to only `queue.put_nowait()`.
2. Move the logic to a dedicated consumer worker thread.

### Issue: Handle Write Errors in Audio Stream

**Title:** Critical Error Handling for Audio Writer

**Description:**
`_write_audio_chunk_safe` catches exceptions and logs them, potentially resulting in silent data loss without user feedback.

**Task:**

1. If writing fails, update the meeting state to `ERROR`.
2. Send an error message back to the client via the transcript stream if possible, or terminate the connection.

### Issue: Database Extension Installation Check

**Title:** Graceful degradation for `pgvector`

**Description:**
Migration script `6a9d9f408f40` attempts to create an extension. This fails if the DB user is not a superuser.

**Task:**

1. Check whether the extension exists or whether the user has permission to create it.
2. If not, fail with a clear message about the required database setup steps.

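
One way to express the pre-flight check: the SQL queries PostgreSQL's standard `pg_extension` catalog, and `check_pgvector` takes any "execute and fetch one row" callable so the sketch stays database-free (the hint text and function shape are assumptions):

```python
# Pre-flight check for the pgvector extension, decoupled from any driver.
EXTENSION_CHECK_SQL = "SELECT 1 FROM pg_extension WHERE extname = 'vector'"

SETUP_HINT = (
    "pgvector is not installed. Run as a superuser: CREATE EXTENSION vector;"
)


def check_pgvector(fetch_one) -> None:
    if fetch_one(EXTENSION_CHECK_SQL) is None:
        raise RuntimeError(SETUP_HINT)


# Simulated fetchers standing in for a DB cursor:
check_pgvector(lambda sql: (1,))        # extension present: no error
try:
    check_pgvector(lambda sql: None)    # extension missing
except RuntimeError as exc:
    print(exc)
```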
### Issue: Foreground App Window Detection on Linux/Headless

**Title:** Handle `pywinctl` dependencies

**Description:**
`pywinctl` requires X11/display headers on Linux, but the server might run headless.

**Task:**

1. Wrap `ForegroundAppProvider` imports in try/except blocks.
2. Ensure the app doesn't crash if `pywinctl` fails to load.

---
## 7. Packaging & Deployment (Future)

Since you mentioned packaging is a WIP:

1. **Dependencies:** Separating `server` deps (torch, faster-whisper) from `client` deps (flet, sounddevice) is crucial. Use `pyproject.toml` extras: `pip install "noteflow[server]"` vs `pip install "noteflow[client]"`.
2. **Model Management:** The server's Docker image will be huge due to Torch/Whisper. Consider a build stage that pre-downloads the "base" model so the container starts faster.

## Conclusion

The code is high quality, well-typed, and structurally sound. Fixing the **Blocking ASR** issue is the only mandatory change before any serious load testing or deployment. The rest are robustness and architectural improvements.

Remaining validated nitpicks:

- The HTML export template uses minimal inline CSS (no external framework). [src/noteflow/infrastructure/export/html.py:33]

  Example: optionally add a lightweight stylesheet (with an offline fallback) for nicer exports.

- The transcript partial row is updated in place (no remove/re-add), so flicker risk is already mitigated. [src/noteflow/client/components/transcript.py:182]

- `Segment.word_count` uses `text.split()` when no words are present. [src/noteflow/domain/entities/segment.py:72]

  Example: for very large transcripts, a streaming count (for example, a regex iterator) avoids allocating a full list.

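
The streaming count mentioned above in a few lines: `re.finditer` yields matches lazily, so no intermediate list of words is built.

```python
# Count non-whitespace runs without materializing a list of words.
import re


def word_count(text: str) -> int:
    return sum(1 for _ in re.finditer(r"\S+", text))


print(word_count("hello  world\nfoo"))  # 3
print(word_count(""))                   # 0
```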
repomix-output.md (28577): file diff suppressed because one or more lines are too long.
@@ -6,12 +6,12 @@
"output": {
    "filePath": "repomix-output.md",
    "style": "markdown",
    "parsableStyle": false,
    "parsableStyle": true,
    "fileSummary": true,
    "directoryStructure": true,
    "files": true,
    "removeComments": false,
    "removeEmptyLines": false,
    "removeComments": true,
    "removeEmptyLines": true,
    "compress": false,
    "topFilesLength": 5,
    "showLineNumbers": false,
@@ -21,7 +21,9 @@ def run_server() -> None:
def main() -> None:
    root = Path(__file__).resolve().parents[1]
    watch_paths = [root / "src" / "noteflow", root / "alembic.ini"]
    existing_paths = [str(path) for path in watch_paths if path.exists()] or [str(root / "src" / "noteflow")]
    existing_paths = [str(path) for path in watch_paths if path.exists()] or [
        str(root / "src" / "noteflow")
    ]

    run_process(
        *existing_paths,
@@ -1,12 +1,15 @@
"""Recovery service for crash recovery on startup.

Detect and recover meetings left in active states after server restart.
Optionally validate audio file integrity for crashed meetings.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar

from noteflow.domain.value_objects import MeetingState
@@ -18,11 +21,37 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class AudioValidationResult:
    """Result of audio file validation for a meeting."""

    is_valid: bool
    manifest_exists: bool
    audio_exists: bool
    error_message: str | None = None


@dataclass(frozen=True)
class RecoveryResult:
    """Result of crash recovery operations."""

    meetings_recovered: int
    diarization_jobs_failed: int
    audio_validation_failures: int = 0

    @property
    def total_recovered(self) -> int:
        """Total items recovered across all types."""
        return self.meetings_recovered + self.diarization_jobs_failed


class RecoveryService:
    """Recover meetings from crash states on server startup.

    Find meetings left in RECORDING or STOPPING state and mark them as ERROR.
    This handles the case where the server crashed during an active meeting.

    Optionally validates audio file integrity if crypto and meetings_dir are provided.
    """

    ACTIVE_STATES: ClassVar[list[MeetingState]] = [
@@ -30,22 +59,92 @@ class RecoveryService:
        MeetingState.STOPPING,
    ]

    def __init__(self, uow: UnitOfWork) -> None:
    def __init__(
        self,
        uow: UnitOfWork,
        meetings_dir: Path | None = None,
    ) -> None:
        """Initialize recovery service.

        Args:
            uow: Unit of work for persistence.
            meetings_dir: Optional meetings directory for audio validation.
                If provided, validates that audio files exist for crashed meetings.
        """
        self._uow = uow
        self._meetings_dir = meetings_dir

    async def recover_crashed_meetings(self) -> list[Meeting]:
    def _validate_meeting_audio(self, meeting: Meeting) -> AudioValidationResult:
        """Validate audio files for a crashed meeting.

        Check that manifest.json and audio.enc exist in the meeting directory.

        Args:
            meeting: Meeting to validate.

        Returns:
            AudioValidationResult with validation status.
        """
        if self._meetings_dir is None:
            return AudioValidationResult(
                is_valid=True,
                manifest_exists=True,
                audio_exists=True,
                error_message="Audio validation skipped (no meetings_dir configured)",
            )

        # Prefer explicit asset_path; fall back to metadata for backward compatibility
        default_path = str(meeting.id)
        asset_path = meeting.asset_path or default_path
        if asset_path == default_path:
            asset_path = meeting.metadata.get("asset_path") or asset_path
        meeting_dir = self._meetings_dir / asset_path

        manifest_path = meeting_dir / "manifest.json"
        audio_path = meeting_dir / "audio.enc"

        manifest_exists = manifest_path.exists()
        audio_exists = audio_path.exists()

        if not manifest_exists and not audio_exists:
            return AudioValidationResult(
                is_valid=False,
                manifest_exists=False,
                audio_exists=False,
                error_message="Meeting directory missing or empty",
            )

        if not manifest_exists:
            return AudioValidationResult(
                is_valid=False,
                manifest_exists=False,
                audio_exists=audio_exists,
                error_message="manifest.json not found",
            )

        if not audio_exists:
            return AudioValidationResult(
                is_valid=False,
                manifest_exists=True,
                audio_exists=False,
                error_message="audio.enc not found",
            )

        return AudioValidationResult(
            is_valid=True,
            manifest_exists=True,
            audio_exists=True,
        )

    async def recover_crashed_meetings(self) -> tuple[list[Meeting], int]:
        """Find and recover meetings left in active states.

        Mark all meetings in RECORDING or STOPPING state as ERROR
        with metadata explaining the crash recovery.
        with metadata explaining the crash recovery. Also validates
        audio file integrity if meetings_dir is configured.

        Returns:
            List of recovered meetings.
            Tuple of (recovered meetings, audio validation failure count).
        """
        async with self._uow:
            # Find all meetings in active states
@@ -56,7 +155,7 @@ class RecoveryService:

            if total == 0:
                logger.info("No crashed meetings found during recovery")
                return []
                return [], 0

            logger.warning(
                "Found %d meetings in active state during startup, marking as ERROR",
@@ -64,6 +163,7 @@ class RecoveryService:
            )

            recovered: list[Meeting] = []
            audio_failures = 0
            recovery_time = datetime.now(UTC).isoformat()

            for meeting in meetings:
@@ -75,18 +175,35 @@ class RecoveryService:
                meeting.metadata["crash_recovery_time"] = recovery_time
                meeting.metadata["crash_previous_state"] = previous_state

                # Validate audio files if configured
                validation = self._validate_meeting_audio(meeting)
                meeting.metadata["audio_valid"] = str(validation.is_valid).lower()
                if not validation.is_valid:
                    audio_failures += 1
                    meeting.metadata["audio_error"] = validation.error_message or "unknown"
                    logger.warning(
                        "Audio validation failed for meeting %s: %s",
                        meeting.id,
                        validation.error_message,
                    )

                await self._uow.meetings.update(meeting)
                recovered.append(meeting)

                logger.info(
                    "Recovered crashed meeting: id=%s, previous_state=%s",
                    "Recovered crashed meeting: id=%s, previous_state=%s, audio_valid=%s",
                    meeting.id,
                    previous_state,
                    validation.is_valid,
                )

            await self._uow.commit()
            logger.info("Crash recovery complete: %d meetings recovered", len(recovered))
            return recovered
            logger.info(
                "Crash recovery complete: %d meetings recovered, %d audio failures",
                len(recovered),
                audio_failures,
            )
            return recovered, audio_failures

    async def count_crashed_meetings(self) -> int:
        """Count meetings currently in crash states.
@@ -99,3 +216,53 @@ class RecoveryService:
        for state in self.ACTIVE_STATES:
            total += await self._uow.meetings.count_by_state(state)
        return total

    async def recover_crashed_diarization_jobs(self) -> int:
        """Mark diarization jobs left in running states as failed.

        Find all diarization jobs in QUEUED or RUNNING state and mark them
        as FAILED with an error message explaining the crash recovery.

        Returns:
            Number of jobs marked as failed.
        """
        async with self._uow:
            failed_count = await self._uow.diarization_jobs.mark_running_as_failed()
            await self._uow.commit()

        if failed_count > 0:
            logger.warning(
                "Marked %d diarization jobs as failed during crash recovery",
                failed_count,
            )
        else:
            logger.info("No crashed diarization jobs found during recovery")

        return failed_count

    async def recover_all(self) -> RecoveryResult:
        """Run all crash recovery operations.

        Recovers crashed meetings and failed diarization jobs in a single
        coordinated operation.

        Returns:
            RecoveryResult with counts of recovered items.
        """
        meetings, audio_failures = await self.recover_crashed_meetings()
        jobs = await self.recover_crashed_diarization_jobs()

        result = RecoveryResult(
            meetings_recovered=len(meetings),
            diarization_jobs_failed=jobs,
            audio_validation_failures=audio_failures,
        )

        if result.total_recovered > 0:
            logger.warning(
                "Crash recovery complete: %d meetings, %d diarization jobs",
                result.meetings_recovered,
                result.diarization_jobs_failed,
            )

        return result
@@ -27,6 +27,9 @@ if TYPE_CHECKING:
# Type alias for persistence callback
PersistCallback = Callable[[Summary], Awaitable[None]]

# Type alias for consent persistence callback
ConsentPersistCallback = Callable[[bool], Awaitable[None]]

logger = logging.getLogger(__name__)
@@ -102,6 +105,7 @@ class SummarizationService:
    verifier: CitationVerifier | None = None
    settings: SummarizationServiceSettings = field(default_factory=SummarizationServiceSettings)
    on_persist: PersistCallback | None = None
    on_consent_change: ConsentPersistCallback | None = None

    def register_provider(self, mode: SummarizationMode, provider: SummarizerProvider) -> None:
        """Register a provider for a specific mode.
@@ -147,15 +151,19 @@ class SummarizationService:
        """
        return mode in self.get_available_modes()

    def grant_cloud_consent(self) -> None:
    async def grant_cloud_consent(self) -> None:
        """Grant consent for cloud processing."""
        self.settings.cloud_consent_granted = True
        logger.info("Cloud consent granted")
        if self.on_consent_change:
            await self.on_consent_change(True)

    def revoke_cloud_consent(self) -> None:
    async def revoke_cloud_consent(self) -> None:
        """Revoke consent for cloud processing."""
        self.settings.cloud_consent_granted = False
        logger.info("Cloud consent revoked")
        if self.on_consent_change:
            await self.on_consent_change(False)

    async def summarize(
        self,
@@ -5,8 +5,6 @@ Usage:
    python -m noteflow.cli.retention status
"""

from __future__ import annotations

import argparse
import asyncio
import logging
@@ -4,8 +4,6 @@ Provides standardized thread start/stop patterns for UI components
that need background polling or timer threads.
"""

from __future__ import annotations

import threading
from collections.abc import Callable
@@ -408,7 +408,10 @@ class MeetingLibraryComponent:
            if num_speakers < 1:
                num_speakers = None
        except ValueError:
            logger.debug("Invalid speaker count input '%s', using auto-detection", self._num_speakers_field.value)
            logger.debug(
                "Invalid speaker count input '%s', using auto-detection",
                self._num_speakers_field.value,
            )

        meeting_id = self._state.selected_meeting.id
        self._close_analyze_dialog()
@@ -5,7 +5,8 @@ Uses RmsLevelProvider from AppState (not a new instance).
"""

from __future__ import annotations

from typing import TYPE_CHECKING
import time
from typing import TYPE_CHECKING, Final

import flet as ft
import numpy as np
|
if TYPE_CHECKING:
    from noteflow.client.state import AppState

# Throttle UI updates to 20 fps (50ms interval)
VU_UPDATE_INTERVAL: Final[float] = 0.05


class VuMeterComponent:
    """Audio level visualization component.
|
        # REUSE level_provider from state - do not create new instance
        self._progress_bar: ft.ProgressBar | None = None
        self._label: ft.Text | None = None
        self._last_update_time: float = 0.0

    def build(self) -> ft.Row:
        """Build VU meter UI elements.
|
        """Process incoming audio frames for level metering.

        Uses state.level_provider.get_db() - existing RmsLevelProvider method.
        Throttled to VU_UPDATE_INTERVAL to avoid excessive UI updates.

        Args:
            frames: Audio samples as float32 array.
        """
        now = time.time()
        if now - self._last_update_time < VU_UPDATE_INTERVAL:
            return  # Throttle: skip update if within interval

        self._last_update_time = now

        # REUSE existing RmsLevelProvider from state
        db_level = self._state.level_provider.get_db(frames)
        self._state.current_db_level = db_level
@@ -4,8 +4,6 @@ Composes existing types from grpc.client and infrastructure.audio.
Does not recreate any dataclasses - imports and uses existing ones.
"""

from __future__ import annotations

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
@@ -4,8 +4,6 @@ This module provides shared constants used across the codebase to avoid
magic numbers and ensure consistency.
"""

from __future__ import annotations

from typing import Final

# Audio constants
@@ -21,3 +19,10 @@ DEFAULT_GRPC_PORT: Final[int] = 50051

MAX_GRPC_MESSAGE_SIZE: Final[int] = 100 * 1024 * 1024
"""Maximum gRPC message size in bytes (100 MB)."""

# Audio encryption buffering constants
AUDIO_BUFFER_SIZE_BYTES: Final[int] = 320_000
"""Target audio buffer size before encryption (320 KB = ~10 seconds at 16kHz PCM16)."""

PERIODIC_FLUSH_INTERVAL_SECONDS: Final[float] = 2.0
"""Interval for periodic audio buffer flush to disk (crash resilience)."""
@@ -1,7 +1,5 @@
"""NoteFlow application settings using Pydantic settings."""

from __future__ import annotations

import json
from functools import lru_cache
from pathlib import Path
@@ -32,6 +32,7 @@ class Meeting:
    summary: Summary | None = None
    metadata: dict[str, str] = field(default_factory=dict)
    wrapped_dek: bytes | None = None  # Encrypted data encryption key
    asset_path: str | None = None  # Relative path for audio assets

    @classmethod
    def create(
@@ -60,6 +61,7 @@ class Meeting:
            state=MeetingState.CREATED,
            created_at=now,
            metadata=metadata or {},
            asset_path=str(meeting_id),  # Default to meeting ID
        )

    @classmethod
@@ -73,6 +75,7 @@ class Meeting:
        ended_at: datetime | None = None,
        metadata: dict[str, str] | None = None,
        wrapped_dek: bytes | None = None,
        asset_path: str | None = None,
    ) -> Meeting:
        """Create meeting with existing UUID string.

@@ -85,6 +88,7 @@ class Meeting:
            ended_at: End timestamp.
            metadata: Meeting metadata.
            wrapped_dek: Encrypted data encryption key.
            asset_path: Relative path for audio assets.

        Returns:
            Meeting instance with specified ID.
@@ -99,6 +103,7 @@ class Meeting:
            ended_at=ended_at,
            metadata=metadata or {},
            wrapped_dek=wrapped_dek,
            asset_path=asset_path or uuid_str,
        )

    def start_recording(self) -> None:

@@ -3,8 +3,6 @@

Define trigger signals, decisions, and actions for meeting detection.
"""

from __future__ import annotations

import time
from dataclasses import dataclass, field
from enum import Enum
@@ -4,8 +4,8 @@ from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import partial
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
@@ -17,6 +17,10 @@ from noteflow.domain.entities import Segment
from noteflow.domain.value_objects import MeetingId, MeetingState
from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.diarization import SpeakerTurn, assign_speaker
from noteflow.infrastructure.persistence.repositories import (
    DiarizationJob,
    StreamingTurn,
)

from ..proto import noteflow_pb2
@@ -26,31 +30,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


@dataclass
class _DiarizationJob:
    """Track background diarization job state."""

    job_id: str
    meeting_id: str
    status: int
    segments_updated: int = 0
    speaker_ids: list[str] = field(default_factory=list)
    error_message: str = ""
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    task: asyncio.Task[None] | None = None

    def to_proto(self) -> noteflow_pb2.DiarizationJobStatus:
        """Convert to protobuf message."""
        return noteflow_pb2.DiarizationJobStatus(
            job_id=self.job_id,
            status=self.status,
            segments_updated=self.segments_updated,
            speaker_ids=self.speaker_ids,
            error_message=self.error_message,
        )


class DiarizationMixin:
    """Mixin providing speaker diarization functionality.

@@ -60,12 +39,15 @@ class DiarizationMixin:
    # Job retention constant
    DIARIZATION_JOB_TTL_SECONDS: float = 60 * 60  # 1 hour

    def _process_streaming_diarization(
    async def _process_streaming_diarization(
        self: ServicerHost,
        meeting_id: str,
        audio: NDArray[np.float32],
    ) -> None:
        """Process an audio chunk for streaming diarization (best-effort)."""
        """Process an audio chunk for streaming diarization (best-effort).

        Offloads heavy ML inference to a thread pool to avoid blocking the event loop.
        """
        if self._diarization_engine is None:
            return
        if meeting_id in self._diarization_streaming_failed:
@@ -73,48 +55,77 @@ class DiarizationMixin:
        if audio.size == 0:
            return

        if not self._diarization_engine.is_streaming_loaded:
        loop = asyncio.get_running_loop()

        async with self._diarization_lock:
            if not self._diarization_engine.is_streaming_loaded:
                try:
                    await loop.run_in_executor(
                        None,
                        self._diarization_engine.load_streaming_model,
                    )
                except (RuntimeError, ValueError) as exc:
                    logger.warning(
                        "Streaming diarization disabled for meeting %s: %s",
                        meeting_id,
                        exc,
                    )
                    self._diarization_streaming_failed.add(meeting_id)
                    return

        stream_time = self._diarization_stream_time.get(meeting_id, 0.0)
        duration = len(audio) / self.DEFAULT_SAMPLE_RATE

        try:
            self._diarization_engine.load_streaming_model()
        except (RuntimeError, ValueError) as exc:
            turns = await loop.run_in_executor(
                None,
                partial(
                    self._diarization_engine.process_chunk,
                    audio,
                    sample_rate=self.DEFAULT_SAMPLE_RATE,
                ),
            )
        except Exception as exc:
            logger.warning(
                "Streaming diarization disabled for meeting %s: %s",
                "Streaming diarization failed for meeting %s: %s",
                meeting_id,
                exc,
            )
            self._diarization_streaming_failed.add(meeting_id)
            return

        stream_time = self._diarization_stream_time.get(meeting_id, 0.0)
        duration = len(audio) / self.DEFAULT_SAMPLE_RATE

        try:
            turns = self._diarization_engine.process_chunk(
                audio,
                sample_rate=self.DEFAULT_SAMPLE_RATE,
            )
        except Exception as exc:
            logger.warning(
                "Streaming diarization failed for meeting %s: %s",
                meeting_id,
                exc,
            )
            self._diarization_streaming_failed.add(meeting_id)
            return

        diarization_turns = self._diarization_turns.setdefault(meeting_id, [])
        adjusted_turns: list[SpeakerTurn] = []
        for turn in turns:
            diarization_turns.append(
                SpeakerTurn(
                    speaker=turn.speaker,
                    start=turn.start + stream_time,
                    end=turn.end + stream_time,
                    confidence=turn.confidence,
                )
            adjusted = SpeakerTurn(
                speaker=turn.speaker,
                start=turn.start + stream_time,
                end=turn.end + stream_time,
                confidence=turn.confidence,
            )
            diarization_turns.append(adjusted)
            adjusted_turns.append(adjusted)

        self._diarization_stream_time[meeting_id] = stream_time + duration

        # Persist turns immediately for crash resilience
        if adjusted_turns and self._use_database():
            try:
                async with self._create_uow() as uow:
                    repo_turns = [
                        StreamingTurn(
                            speaker=t.speaker,
                            start_time=t.start,
                            end_time=t.end,
                            confidence=t.confidence,
                        )
                        for t in adjusted_turns
                    ]
                    await uow.diarization_jobs.add_streaming_turns(meeting_id, repo_turns)
                    await uow.commit()
            except Exception:
                logger.exception("Failed to persist streaming turns for %s", meeting_id)

    def _maybe_assign_speaker(
        self: ServicerHost,
        meeting_id: str,
@@ -140,23 +151,41 @@ class DiarizationMixin:
            segment.speaker_id = speaker_id
            segment.speaker_confidence = confidence

    def _prune_diarization_jobs(self: ServicerHost) -> None:
        """Remove completed diarization jobs older than retention window."""
        if not self._diarization_jobs:
            return
        now = time.time()
    async def _prune_diarization_jobs(self: ServicerHost) -> None:
        """Remove completed diarization jobs older than retention window.

        Prunes both in-memory task references and database records.
        """
        # Clean up in-memory task references for completed tasks
        completed_tasks = [
            job_id for job_id, task in self._diarization_tasks.items() if task.done()
        ]
        for job_id in completed_tasks:
            self._diarization_tasks.pop(job_id, None)

        terminal_statuses = {
            noteflow_pb2.JOB_STATUS_COMPLETED,
            noteflow_pb2.JOB_STATUS_FAILED,
        }
        expired = [
            job_id
            for job_id, job in self._diarization_jobs.items()
            if job.status in terminal_statuses
            and now - job.updated_at > self.DIARIZATION_JOB_TTL_SECONDS
        ]
        for job_id in expired:
            self._diarization_jobs.pop(job_id, None)

        # Prune old completed jobs from database
        if self._use_database():
            async with self._create_uow() as uow:
                pruned = await uow.diarization_jobs.prune_completed(
                    self.DIARIZATION_JOB_TTL_SECONDS
                )
                await uow.commit()
            if pruned > 0:
                logger.debug("Pruned %d completed diarization jobs", pruned)
        else:
            cutoff = datetime.now() - timedelta(seconds=self.DIARIZATION_JOB_TTL_SECONDS)
            expired = [
                job_id
                for job_id, job in self._diarization_jobs.items()
                if job.status in terminal_statuses and job.updated_at < cutoff
            ]
            for job_id in expired:
                self._diarization_jobs.pop(job_id, None)

    async def RefineSpeakerDiarization(
        self: ServicerHost,
async def RefineSpeakerDiarization(
        """Run post-meeting speaker diarization refinement.

        Load the full meeting audio, run offline diarization, and update
        segment speaker assignments.
        segment speaker assignments. Job state is persisted to database.
        """
        self._prune_diarization_jobs()
        await self._prune_diarization_jobs()

        if not self._diarization_refinement_enabled:
            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
@@ -233,16 +262,23 @@ class DiarizationMixin:
        num_speakers = request.num_speakers if request.num_speakers > 0 else None

        job_id = str(uuid4())
        job = _DiarizationJob(
        job = DiarizationJob(
            job_id=job_id,
            meeting_id=request.meeting_id,
            status=noteflow_pb2.JOB_STATUS_QUEUED,
        )
        self._diarization_jobs[job_id] = job

        # Task runs in background, no need to await
        # Persist job to database
        if self._use_database():
            async with self._create_uow() as uow:
                await uow.diarization_jobs.create(job)
                await uow.commit()
        else:
            self._diarization_jobs[job_id] = job

        # Create background task and store reference for potential cancellation
        task = asyncio.create_task(self._run_diarization_job(job_id, num_speakers))
        job.task = task
        self._diarization_tasks[job_id] = task

        response = noteflow_pb2.RefineSpeakerDiarizationResponse()
        response.segments_updated = 0
@@ -257,29 +293,75 @@ class DiarizationMixin:
        job_id: str,
        num_speakers: int | None,
    ) -> None:
        """Run background diarization job."""
        job = self._diarization_jobs.get(job_id)
        if job is None:
            return
        """Run background diarization job.

        job.status = noteflow_pb2.JOB_STATUS_RUNNING
        job.updated_at = time.time()
        Updates job status in database as the job progresses.
        """
        # Get meeting_id from database
        meeting_id: str | None = None
        job: DiarizationJob | None = None
        if self._use_database():
            async with self._create_uow() as uow:
                job = await uow.diarization_jobs.get(job_id)
                if job is None:
                    logger.warning("Diarization job %s not found in database", job_id)
                    return
                meeting_id = job.meeting_id
                # Update status to RUNNING
                await uow.diarization_jobs.update_status(
                    job_id,
                    noteflow_pb2.JOB_STATUS_RUNNING,
                )
                await uow.commit()
        else:
            job = self._diarization_jobs.get(job_id)
            if job is None:
                logger.warning("Diarization job %s not found in memory", job_id)
                return
            meeting_id = job.meeting_id
            job.status = noteflow_pb2.JOB_STATUS_RUNNING
            job.updated_at = datetime.now()

        try:
            updated_count = await self.refine_speaker_diarization(
                meeting_id=job.meeting_id,
                meeting_id=meeting_id,
                num_speakers=num_speakers,
            )
            speaker_ids = await self._collect_speaker_ids(job.meeting_id)
            job.segments_updated = updated_count
            job.speaker_ids = speaker_ids
            job.status = noteflow_pb2.JOB_STATUS_COMPLETED
            speaker_ids = await self._collect_speaker_ids(meeting_id)

            # Update status to COMPLETED
            if self._use_database():
                async with self._create_uow() as uow:
                    await uow.diarization_jobs.update_status(
                        job_id,
                        noteflow_pb2.JOB_STATUS_COMPLETED,
                        segments_updated=updated_count,
                        speaker_ids=speaker_ids,
                    )
                    await uow.commit()
            else:
                if job is not None:
                    job.status = noteflow_pb2.JOB_STATUS_COMPLETED
                    job.segments_updated = updated_count
                    job.speaker_ids = speaker_ids
                    job.updated_at = datetime.now()

        except Exception as exc:
            logger.exception("Diarization failed for meeting %s", job.meeting_id)
            job.error_message = str(exc)
            job.status = noteflow_pb2.JOB_STATUS_FAILED
        finally:
            job.updated_at = time.time()
            logger.exception("Diarization failed for meeting %s", meeting_id)
            # Update status to FAILED
            if self._use_database():
                async with self._create_uow() as uow:
                    await uow.diarization_jobs.update_status(
                        job_id,
                        noteflow_pb2.JOB_STATUS_FAILED,
                        error_message=str(exc),
                    )
                    await uow.commit()
            else:
                if job is not None:
                    job.status = noteflow_pb2.JOB_STATUS_FAILED
                    job.error_message = str(exc)
                    job.updated_at = datetime.now()

    async def refine_speaker_diarization(
        self: ServicerHost,
@@ -302,11 +384,12 @@ class DiarizationMixin:
        Raises:
            RuntimeError: If diarization engine not available or meeting not found.
        """
        turns = await asyncio.to_thread(
            self._run_diarization_inference,
            meeting_id,
            num_speakers,
        )
        async with self._diarization_lock:
            turns = await asyncio.to_thread(
                self._run_diarization_inference,
                meeting_id,
                num_speakers,
            )

        updated_count = await self._apply_diarization_turns(meeting_id, turns)
@@ -475,12 +558,37 @@ class DiarizationMixin:
        request: noteflow_pb2.GetDiarizationJobStatusRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.DiarizationJobStatus:
        """Return current status for a diarization job."""
        self._prune_diarization_jobs()
        """Return current status for a diarization job.

        Queries job state from database for persistence across restarts.
        """
        await self._prune_diarization_jobs()

        if self._use_database():
            async with self._create_uow() as uow:
                job = await uow.diarization_jobs.get(request.job_id)
                if job is None:
                    await context.abort(
                        grpc.StatusCode.NOT_FOUND,
                        "Diarization job not found",
                    )
                return noteflow_pb2.DiarizationJobStatus(
                    job_id=job.job_id,
                    status=job.status,
                    segments_updated=job.segments_updated,
                    speaker_ids=job.speaker_ids,
                    error_message=job.error_message,
                )
        job = self._diarization_jobs.get(request.job_id)
        if job is None:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                "Diarization job not found",
            )
        return job.to_proto()
        return noteflow_pb2.DiarizationJobStatus(
            job_id=job.job_id,
            status=job.status,
            segments_updated=job.segments_updated,
            speaker_ids=job.speaker_ids,
            error_message=job.error_message,
        )
@@ -69,6 +69,8 @@ class MeetingMixin:
            except ValueError as e:
                await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
            await uow.meetings.update(meeting)
            # Clean up streaming diarization turns (no longer needed)
            await uow.diarization_jobs.clear_streaming_turns(meeting_id)
            await uow.commit()
            return meeting_to_proto(meeting)
        store = self._get_memory_store()
@@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
from pathlib import Path
from typing import TYPE_CHECKING, Protocol

@@ -15,6 +16,7 @@ if TYPE_CHECKING:
    from noteflow.infrastructure.asr import FasterWhisperEngine, Segmenter, StreamingVad
    from noteflow.infrastructure.audio.writer import MeetingAudioWriter
    from noteflow.infrastructure.diarization import DiarizationEngine, SpeakerTurn
    from noteflow.infrastructure.persistence.repositories import DiarizationJob
    from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
    from noteflow.infrastructure.security.crypto import AesGcmCryptoBox

@@ -62,6 +64,11 @@ class ServicerHost(Protocol):
    _diarization_stream_time: dict[str, float]
    _diarization_streaming_failed: set[str]

    # Background diarization task references (for cancellation)
    _diarization_jobs: dict[str, DiarizationJob]
    _diarization_tasks: dict[str, asyncio.Task[None]]
    _diarization_lock: asyncio.Lock

    # Constants
    DEFAULT_SAMPLE_RATE: int
    SUPPORTED_SAMPLE_RATES: list[int]

@@ -105,6 +112,7 @@ class ServicerHost(Protocol):
        meeting_id: str,
        dek: bytes,
        wrapped_dek: bytes,
        asset_path: str | None = None,
    ) -> None:
        """Open audio writer for a meeting."""
        ...
@@ -15,6 +15,7 @@ import numpy as np
from numpy.typing import NDArray

from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.diarization import SpeakerTurn

from ..proto import noteflow_pb2
from .converters import create_segment_from_asr, create_vad_update, segment_to_proto_update

@@ -94,6 +95,12 @@ class StreamingMixin:
                yield update
        finally:
            if current_meeting_id:
                # Flush audio buffer before cleanup to minimize data loss
                if current_meeting_id in self._audio_writers:
                    try:
                        self._audio_writers[current_meeting_id].flush()
                    except Exception as e:
                        logger.warning("Failed to flush audio for %s: %s", current_meeting_id, e)
                self._cleanup_streaming_state(current_meeting_id)
                self._close_audio_writer(current_meeting_id)
                self._active_streams.discard(current_meeting_id)

@@ -167,9 +174,36 @@ class StreamingMixin:
            await uow.commit()

            next_segment_id = await uow.segments.get_next_segment_id(meeting.id)
            self._open_meeting_audio_writer(meeting_id, dek, wrapped_dek)
            self._open_meeting_audio_writer(
                meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path
            )
            self._init_streaming_state(meeting_id, next_segment_id)

            # Load any persisted streaming turns (crash recovery)
            persisted_turns = await uow.diarization_jobs.get_streaming_turns(meeting_id)
            if persisted_turns:
                domain_turns = [
                    SpeakerTurn(
                        speaker=t.speaker,
                        start=t.start_time,
                        end=t.end_time,
                        confidence=t.confidence,
                    )
                    for t in persisted_turns
                ]
                self._diarization_turns[meeting_id] = domain_turns
                # Advance stream time to avoid overlapping recovered turns
                last_end = max(t.end_time for t in persisted_turns)
                self._diarization_stream_time[meeting_id] = max(
                    self._diarization_stream_time.get(meeting_id, 0.0),
                    last_end,
                )
                logger.info(
                    "Loaded %d streaming diarization turns for meeting %s",
                    len(domain_turns),
                    meeting_id,
                )

        return _StreamSessionInit(next_segment_id=next_segment_id)

    def _init_stream_session_memory(

@@ -207,7 +241,7 @@ class StreamingMixin:
        store.update(meeting)

        next_segment_id = meeting.next_segment_id
        self._open_meeting_audio_writer(meeting_id, dek, wrapped_dek)
        self._open_meeting_audio_writer(meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path)
        self._init_streaming_state(meeting_id, next_segment_id)

        return _StreamSessionInit(next_segment_id=next_segment_id)

@@ -377,7 +411,7 @@ class StreamingMixin:

        # Streaming diarization (optional) - call mixin method if available
        if hasattr(self, "_process_streaming_diarization"):
            self._process_streaming_diarization(meeting_id, audio)
            await self._process_streaming_diarization(meeting_id, audio)

        # Emit VAD state change events
        was_speaking = self._was_speaking.get(meeting_id, False)
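The stream-time advance for recovered turns reduces to taking the max of the current clock and the last persisted end time, so recovered turns never overlap new ones; a minimal sketch (simplified tuples instead of the real `SpeakerTurn` records):

```python
# Each persisted turn: (speaker, start_time, end_time)
persisted = [("S1", 0.0, 2.5), ("S2", 2.5, 6.0), ("S1", 6.0, 9.2)]

stream_time = {"meeting-1": 1.0}  # diarization clock before recovery

# Advance the clock past the last recovered turn
last_end = max(end for _, _, end in persisted)
stream_time["meeting-1"] = max(stream_time.get("meeting-1", 0.0), last_end)
```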
@@ -177,7 +177,7 @@ class NoteFlowClient:
            True if connected successfully.
        """
        try:
            return self._extracted_from_connect_11(timeout)
            return self._setup_grpc_channel(timeout)
        except grpc.FutureTimeoutError:
            logger.error("Connection timeout: %s", self._server_address)
            self._notify_connection(False, "Connection timeout")
@@ -187,8 +187,15 @@ class NoteFlowClient:
            self._notify_connection(False, str(e))
            return False

    # TODO Rename this here and in `connect`
    def _extracted_from_connect_11(self, timeout):
    def _setup_grpc_channel(self, timeout: float) -> bool:
        """Set up the gRPC channel and stub.

        Args:
            timeout: Connection timeout in seconds.

        Returns:
            True if connection succeeded.
        """
        self._channel = grpc.insecure_channel(
            self._server_address,
            options=[
@@ -137,6 +137,9 @@ class NoteFlowServer:
        """
        if self._server:
            logger.info("Stopping server (grace period: %.1fs)...", grace_period)
            # Clean up servicer state before stopping
            if self._servicer:
                await self._servicer.shutdown()
            await self._server.stop(grace_period)
            logger.info("Server stopped")

@@ -184,19 +187,48 @@ async def run_server(
        logger.info("Database connection pool ready")

        # Run crash recovery on startup
        uow = SqlAlchemyUnitOfWork(session_factory)
        recovery_service = RecoveryService(uow)
        recovered = await recovery_service.recover_crashed_meetings()
        if recovered:
        settings = get_settings()
        recovery_service = RecoveryService(
            SqlAlchemyUnitOfWork(session_factory),
            meetings_dir=settings.meetings_dir,
        )
        recovery_result = await recovery_service.recover_all()
        if recovery_result.meetings_recovered:
            logger.warning(
                "Recovered %d crashed meetings on startup",
                len(recovered),
                recovery_result.meetings_recovered,
            )
        if recovery_result.diarization_jobs_failed:
            logger.warning(
                "Recovered %d crashed diarization jobs on startup",
                recovery_result.diarization_jobs_failed,
            )
        if recovery_result.audio_validation_failures:
            logger.warning(
                "Found %d meetings with missing/invalid audio files",
                recovery_result.audio_validation_failures,
            )

    # Create summarization service - auto-detects LOCAL/MOCK providers
    summarization_service = create_summarization_service()
    logger.info("Summarization service initialized")

    # Load cloud consent from database and set up persistence callback
    if session_factory:
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            cloud_consent = await uow.preferences.get_bool("cloud_consent_granted", False)
            summarization_service.settings.cloud_consent_granted = cloud_consent
            logger.info("Loaded cloud consent from database: %s", cloud_consent)

        # Create consent persistence callback
        async def persist_consent(granted: bool) -> None:
            async with SqlAlchemyUnitOfWork(session_factory) as uow:
                await uow.preferences.set("cloud_consent_granted", granted)
                await uow.commit()
            logger.info("Persisted cloud consent: %s", granted)

        summarization_service.on_consent_change = persist_consent

    # Create diarization engine if enabled
    diarization_engine: DiarizationEngine | None = None
    if diarization_enabled:
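The consent-persistence hookup above follows a plain callback pattern: the service mutates its in-memory flag and invokes an async callback the host installed, so the service stays storage-agnostic. A minimal sketch with a dict standing in for the `user_preferences` table and a stripped-down service (not the real `SummarizationService` API):

```python
import asyncio


class SummarizationService:  # minimal stand-in for the real service
    def __init__(self) -> None:
        self.cloud_consent_granted = False
        self.on_consent_change = None  # async callback, installed by the host

    async def set_consent(self, granted: bool) -> None:
        self.cloud_consent_granted = granted
        if self.on_consent_change is not None:
            await self.on_consent_change(granted)  # persist out-of-band


stored: dict[str, bool] = {}  # stands in for the user_preferences table


async def persist_consent(granted: bool) -> None:
    stored["cloud_consent_granted"] = granted


async def main() -> bool:
    svc = SummarizationService()
    svc.on_consent_change = persist_consent
    await svc.set_consent(True)
    return stored["cloud_consent_granted"]


result = asyncio.run(main())
```

Keeping persistence behind a callback means the same service works with the in-memory store and the database-backed unit of work.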
@@ -2,6 +2,8 @@

from __future__ import annotations

import asyncio
import contextlib
import logging
import time
from pathlib import Path
@@ -15,6 +17,7 @@ from noteflow.domain.entities import Meeting
from noteflow.domain.value_objects import MeetingState
from noteflow.infrastructure.asr import Segmenter, SegmenterConfig, StreamingVad
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from noteflow.infrastructure.security.keystore import KeyringKeyStore
@@ -118,8 +121,10 @@ class NoteFlowServicer(
        # Track audio write failures to avoid log spam
        self._audio_write_failed: set[str] = set()

        # Background diarization jobs
        self._diarization_jobs: dict[str, object] = {}
        # Background diarization task references (for cancellation)
        self._diarization_jobs: dict[str, DiarizationJob] = {}
        self._diarization_tasks: dict[str, asyncio.Task[None]] = {}
        self._diarization_lock = asyncio.Lock()

    @property
    def asr_engine(self) -> FasterWhisperEngine | None:
@@ -226,6 +231,7 @@ class NoteFlowServicer(
        meeting_id: str,
        dek: bytes,
        wrapped_dek: bytes,
        asset_path: str | None = None,
    ) -> None:
        """Open audio writer for a meeting.

@@ -233,6 +239,7 @@ class NoteFlowServicer(
            meeting_id: Meeting ID string.
            dek: Data encryption key.
            wrapped_dek: Wrapped DEK.
            asset_path: Relative path for audio storage (defaults to meeting_id).
        """
        writer = MeetingAudioWriter(self._crypto, self._meetings_dir)
        writer.open(
@@ -240,6 +247,7 @@ class NoteFlowServicer(
            dek=dek,
            wrapped_dek=wrapped_dek,
            sample_rate=self.DEFAULT_SAMPLE_RATE,
            asset_path=asset_path,
        )
        self._audio_writers[meeting_id] = writer
        logger.info("Audio writer opened for meeting %s", meeting_id)
@@ -317,3 +325,39 @@ class NoteFlowServicer(
            diarization_enabled=diarization_enabled,
            diarization_ready=diarization_ready,
        )

    async def shutdown(self) -> None:
        """Clean up servicer state before server stops.

        Cancel in-flight diarization tasks, close audio writers, and mark
        any running jobs as failed in the database.
        """
        logger.info("Shutting down servicer...")

        # Cancel in-flight diarization tasks
        for job_id, task in list(self._diarization_tasks.items()):
            if not task.done():
                logger.debug("Cancelling diarization task %s", job_id)
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await task

        self._diarization_tasks.clear()

        # Close all audio writers
        for meeting_id in list(self._audio_writers.keys()):
            logger.debug("Closing audio writer for meeting %s", meeting_id)
            self._close_audio_writer(meeting_id)

        # Mark running jobs as FAILED in database
        if self._use_database():
            async with self._create_uow() as uow:
                failed_count = await uow.diarization_jobs.mark_running_as_failed()
                await uow.commit()
                if failed_count > 0:
                    logger.warning(
                        "Marked %d running diarization jobs as failed on shutdown",
                        failed_count,
                    )

        logger.info("Servicer shutdown complete")
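The cancel-then-await step in the shutdown path above is the standard asyncio idiom: `task.cancel()` only requests cancellation, so the task must still be awaited (with `CancelledError` suppressed) before its resources are known to be released. A self-contained sketch of that idiom, using a dummy job in place of a diarization run:

```python
import asyncio
import contextlib


async def long_job() -> None:
    await asyncio.sleep(3600)  # stands in for a long diarization run


async def shutdown(tasks: dict[str, asyncio.Task]) -> int:
    """Cancel in-flight tasks and wait for each to actually finish."""
    cancelled = 0
    for job_id, task in list(tasks.items()):
        if not task.done():
            task.cancel()
            # Awaiting the cancelled task lets it unwind; suppress the
            # CancelledError it raises on completion.
            with contextlib.suppress(asyncio.CancelledError):
                await task
            cancelled += 1
    tasks.clear()
    return cancelled


async def main() -> int:
    tasks = {"job-1": asyncio.create_task(long_job())}
    await asyncio.sleep(0)  # let the task start running
    return await shutdown(tasks)


cancelled = asyncio.run(main())
```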
@@ -3,8 +3,6 @@
These DTOs define the data structures used by ASR components.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum

@@ -3,8 +3,6 @@
Define data structures used by audio capture components.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

@@ -3,8 +3,6 @@
Provide RMS and dB level calculation for VU meter display.
"""

from __future__ import annotations

import math
from typing import Final
@@ -54,6 +54,7 @@ class MeetingAudioReader:
    def load_meeting_audio(
        self,
        meeting_id: str,
        asset_path: str | None = None,
    ) -> list[TimestampedAudio]:
        """Load all audio from an archived meeting.

@@ -61,6 +62,8 @@ class MeetingAudioReader:

        Args:
            meeting_id: Meeting UUID string.
            asset_path: Relative path for audio storage (defaults to meeting_id).
                Use the stored asset_path from the database when available.

        Returns:
            List of TimestampedAudio chunks (or empty list if not found/failed).
@@ -69,7 +72,8 @@ class MeetingAudioReader:
            FileNotFoundError: If meeting directory or audio file not found.
            ValueError: If manifest is invalid or audio format unsupported.
        """
        meeting_dir = self._meetings_dir / meeting_id
        storage_path = asset_path or meeting_id
        meeting_dir = self._meetings_dir / storage_path
        self._meeting_dir = meeting_dir

        # Load and parse manifest
@@ -145,31 +149,43 @@ class MeetingAudioReader:

        return chunks

    def get_manifest(self, meeting_id: str) -> dict[str, object] | None:
    def get_manifest(
        self,
        meeting_id: str,
        asset_path: str | None = None,
    ) -> dict[str, object] | None:
        """Get manifest metadata for a meeting.

        Args:
            meeting_id: Meeting UUID string.
            asset_path: Relative path for audio storage (defaults to meeting_id).

        Returns:
            Manifest dict or None if not found.
        """
        manifest_path = self._meetings_dir / meeting_id / "manifest.json"
        storage_path = asset_path or meeting_id
        manifest_path = self._meetings_dir / storage_path / "manifest.json"
        if not manifest_path.exists():
            return None

        return dict(json.loads(manifest_path.read_text()))

    def audio_exists(self, meeting_id: str) -> bool:
    def audio_exists(
        self,
        meeting_id: str,
        asset_path: str | None = None,
    ) -> bool:
        """Check if audio file exists for a meeting.

        Args:
            meeting_id: Meeting UUID string.
            asset_path: Relative path for audio storage (defaults to meeting_id).

        Returns:
            True if audio.enc exists.
        """
        meeting_dir = self._meetings_dir / meeting_id
        storage_path = asset_path or meeting_id
        meeting_dir = self._meetings_dir / storage_path
        audio_path = meeting_dir / "audio.enc"
        manifest_path = meeting_dir / "manifest.json"
        return audio_path.exists() and manifest_path.exists()
@@ -2,15 +2,21 @@

from __future__ import annotations

import io
import json
import logging
import threading
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.config.constants import (
    AUDIO_BUFFER_SIZE_BYTES,
    DEFAULT_SAMPLE_RATE,
    PERIODIC_FLUSH_INTERVAL_SECONDS,
)
from noteflow.infrastructure.security.crypto import ChunkedAssetWriter

if TYPE_CHECKING:
@@ -27,6 +33,14 @@ class MeetingAudioWriter:
    Manage meeting directory creation, manifest file, and encrypted audio storage.
    Uses ChunkedAssetWriter for the actual encryption.

    Audio data is buffered internally to reduce encryption overhead. Each encrypted
    chunk has 28 bytes overhead (12 byte nonce + 16 byte tag) plus 4 byte length
    prefix. Buffering aggregates small writes into larger chunks (~320KB) before
    encryption to minimize this overhead.

    A background thread periodically flushes the buffer every 2 seconds to minimize
    data loss on crashes. All buffer access is protected by a lock.

    Directory structure:
        ~/.noteflow/meetings/<meeting-uuid>/
        ├── manifest.json   # Meeting metadata + wrapped DEK
@@ -37,19 +51,30 @@ class MeetingAudioWriter:
        self,
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
        buffer_size: int = AUDIO_BUFFER_SIZE_BYTES,
    ) -> None:
        """Initialize audio writer.

        Args:
            crypto: CryptoBox instance for encryption operations.
            meetings_dir: Root directory for all meetings (e.g., ~/.noteflow/meetings).
            buffer_size: Buffer size threshold in bytes before flushing to disk.
                Defaults to AUDIO_BUFFER_SIZE_BYTES (~320KB = 10 seconds at 16kHz).
        """
        self._crypto = crypto
        self._meetings_dir = meetings_dir
        self._buffer_size = buffer_size
        self._asset_writer: ChunkedAssetWriter | None = None
        self._meeting_dir: Path | None = None
        self._sample_rate: int = DEFAULT_SAMPLE_RATE
        self._chunk_count: int = 0
        self._write_count: int = 0
        self._buffer: io.BytesIO = io.BytesIO()

        # Thread-safety for periodic flush
        self._buffer_lock = threading.Lock()
        self._flush_thread: threading.Thread | None = None
        self._stop_flush = threading.Event()

    def open(
        self,
@@ -57,6 +82,7 @@ class MeetingAudioWriter:
        dek: bytes,
        wrapped_dek: bytes,
        sample_rate: int = DEFAULT_SAMPLE_RATE,
        asset_path: str | None = None,
    ) -> None:
        """Open meeting for audio writing.

@@ -67,6 +93,8 @@ class MeetingAudioWriter:
            dek: Unwrapped data encryption key (32 bytes).
            wrapped_dek: Encrypted DEK to store in manifest.
            sample_rate: Audio sample rate (default 16000 Hz).
            asset_path: Relative path for audio storage (defaults to meeting_id).
                This allows meetings_dir to change without orphaning files.

        Raises:
            RuntimeError: If already open.
@@ -75,8 +103,11 @@ class MeetingAudioWriter:
        if self._asset_writer is not None:
            raise RuntimeError("Writer already open")

        # Use asset_path if provided, otherwise default to meeting_id
        storage_path = asset_path or meeting_id

        # Create meeting directory
        self._meeting_dir = self._meetings_dir / meeting_id
        self._meeting_dir = self._meetings_dir / storage_path
        self._meeting_dir.mkdir(parents=True, exist_ok=True)

        # Write manifest.json
@@ -98,15 +129,49 @@ class MeetingAudioWriter:

        self._sample_rate = sample_rate
        self._chunk_count = 0
        self._write_count = 0
        self._buffer = io.BytesIO()

        # Start periodic flush thread for crash resilience
        self._stop_flush.clear()
        self._flush_thread = threading.Thread(
            target=self._periodic_flush_loop,
            name=f"AudioFlush-{meeting_id[:8]}",
            daemon=True,
        )
        self._flush_thread.start()

        logger.info(
            "Opened audio writer: meeting=%s, dir=%s",
            "Opened audio writer: meeting=%s, dir=%s, buffer_size=%d",
            meeting_id,
            self._meeting_dir,
            self._buffer_size,
        )

    def _periodic_flush_loop(self) -> None:
        """Background thread: periodically flush buffer for crash resilience."""
        while not self._stop_flush.wait(timeout=PERIODIC_FLUSH_INTERVAL_SECONDS):
            try:
                self._flush_if_open()
            except Exception:
                logger.exception("Periodic flush failed")

    def _flush_if_open(self) -> None:
        """Flush buffer if writer is open (thread-safe, no exception if closed)."""
        with self._buffer_lock:
            if (
                self._asset_writer is not None
                and self._asset_writer.is_open
                and self._buffer.tell() > 0
            ):
                self._flush_buffer_unlocked()

    def write_chunk(self, audio: NDArray[np.float32]) -> None:
        """Write audio chunk (convert float32 → PCM16).
        """Write audio chunk to internal buffer (convert float32 → PCM16).

        Audio is buffered internally and flushed to encrypted storage when the
        buffer exceeds the configured threshold. Call flush() to force immediate
        write, or close() to finalize.

        Args:
            audio: Audio samples as float32 array (-1.0 to 1.0).
@@ -122,29 +187,85 @@ class MeetingAudioWriter:
        audio_clamped = np.clip(audio, -1.0, 1.0)
        pcm16 = (audio_clamped * 32767.0).astype(np.int16)

        # Write as raw bytes (platform-native endianness, typically little-endian)
        self._asset_writer.write_chunk(pcm16.tobytes())
        self._chunk_count += 1
        with self._buffer_lock:
            # Append to buffer
            self._buffer.write(pcm16.tobytes())
            self._write_count += 1

            # Flush buffer if threshold exceeded
            if self._buffer.tell() >= self._buffer_size:
                self._flush_buffer_unlocked()

    def flush(self) -> None:
        """Force flush buffered audio to encrypted storage.

        Call this to ensure all buffered audio is written immediately.
        Normally only needed before a long pause or when precise timing matters.

        Raises:
            RuntimeError: If not open.
        """
        if self._asset_writer is None or not self._asset_writer.is_open:
            raise RuntimeError("Writer not open")

        with self._buffer_lock:
            if self._buffer.tell() > 0:
                self._flush_buffer_unlocked()

    def _flush_buffer_unlocked(self) -> None:
        """Flush internal buffer to encrypted storage.

        Must be called with _buffer_lock held.
        """
        if self._asset_writer is None:
            return

        if buffer_bytes := self._buffer.getvalue():
            self._asset_writer.write_chunk(buffer_bytes)
            self._chunk_count += 1
            logger.debug(
                "Flushed audio buffer: %d bytes, chunk #%d",
                len(buffer_bytes),
                self._chunk_count,
            )

        # Reset buffer
        self._buffer = io.BytesIO()

    def close(self) -> None:
        """Close audio writer and finalize files.

        Stops the periodic flush thread, flushes remaining audio, and closes files.
        Safe to call if already closed or never opened.
        """
        # Stop periodic flush thread first
        self._stop_flush.set()
        if self._flush_thread is not None:
            self._flush_thread.join(timeout=1.0)
            self._flush_thread = None

        if self._asset_writer is not None:
            # Flush remaining buffer under lock
            with self._buffer_lock:
                self._flush_buffer_unlocked()

            bytes_written = self._asset_writer.bytes_written
            self._asset_writer.close()
            self._asset_writer = None

            logger.info(
                "Closed audio writer: dir=%s, chunks=%d, bytes=%d",
                "Closed audio writer: dir=%s, writes=%d, encrypted_chunks=%d, bytes=%d",
                self._meeting_dir,
                self._write_count,
                self._chunk_count,
                bytes_written,
            )

        self._meeting_dir = None
        self._chunk_count = 0
        self._write_count = 0
        with self._buffer_lock:
            self._buffer = io.BytesIO()

    @property
    def is_open(self) -> bool:
@@ -158,9 +279,26 @@ class MeetingAudioWriter:

    @property
    def chunk_count(self) -> int:
        """Number of audio chunks written."""
        """Number of encrypted chunks written to disk.

        Due to buffering, this may be less than write_count.
        """
        return self._chunk_count

    @property
    def write_count(self) -> int:
        """Number of write_chunk() calls made.

        This counts incoming audio frames, not encrypted chunks written to disk.
        """
        return self._write_count

    @property
    def buffered_bytes(self) -> int:
        """Current bytes pending in buffer, not yet written to disk."""
        with self._buffer_lock:
            return self._buffer.tell()

    @property
    def meeting_dir(self) -> Path | None:
        """Current meeting directory, or None if not open."""
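The writer's buffering scheme above (aggregate small writes, flush at a size threshold, allow a forced flush) can be sketched without the encryption layer; the list of chunks below stands in for `ChunkedAssetWriter`, and the lock mirrors the one the real writer shares with its periodic flush thread:

```python
import io
import threading


class BufferedChunkWriter:
    """Sketch: aggregate small writes into larger chunks before handing to a sink."""

    def __init__(self, buffer_size: int = 320_000) -> None:
        self._buffer_size = buffer_size
        self._buffer = io.BytesIO()
        self._lock = threading.Lock()  # buffer is shared with a flush thread in the real writer
        self.chunks: list[bytes] = []  # stands in for the encrypted-chunk sink
        self.write_count = 0

    def write(self, data: bytes) -> None:
        with self._lock:
            self._buffer.write(data)
            self.write_count += 1
            if self._buffer.tell() >= self._buffer_size:
                self._flush_unlocked()

    def flush(self) -> None:
        """Force out any partial chunk (the real writer also does this every 2s)."""
        with self._lock:
            if self._buffer.tell() > 0:
                self._flush_unlocked()

    def _flush_unlocked(self) -> None:
        if data := self._buffer.getvalue():
            self.chunks.append(data)
        self._buffer = io.BytesIO()


w = BufferedChunkWriter(buffer_size=10)
w.write(b"aaaa")    # buffered: 4 bytes, below threshold
w.write(b"bbbbbb")  # hits the 10-byte threshold, emits one aggregated chunk
w.write(b"cc")
w.flush()           # forces the trailing partial chunk out
```

With a 28-byte-plus-length-prefix cost per encrypted chunk, emitting two aggregated chunks instead of three small ones is exactly the overhead saving the buffering is for.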
@@ -100,6 +100,7 @@ class OrmConverter:
            ended_at=model.ended_at,
            metadata=model.metadata_,
            wrapped_dek=model.wrapped_dek,
            asset_path=model.asset_path,
        )

    # --- Segment ---
@@ -4,8 +4,6 @@ Provides functions to assign speaker labels to transcript segments based on
|
||||
diarization output using timestamp overlap matching.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from noteflow.infrastructure.diarization.dto import SpeakerTurn
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
These DTOs define the data structures used by diarization components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Shared formatting utilities for export modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@

from __future__ import annotations

from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING

from sqlalchemy.ext.asyncio import (
@@ -55,21 +54,6 @@ def get_async_session_factory(
    )


async def get_async_session(
    session_factory: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session.

    Args:
        session_factory: Factory for creating sessions.

    Yields:
        Async database session that is closed after use.
    """
    async with session_factory() as session:
        yield session


def create_async_session_factory(
    database_url: str,
    pool_size: int = 5,
@@ -1,7 +1,5 @@
"""Alembic migration environment configuration."""

from __future__ import annotations

import asyncio
import os
from logging.config import fileConfig
@@ -0,0 +1,80 @@
"""add_diarization_jobs_table

Revision ID: d8e5f6a7b2c3
Revises: c7d4e9f3a2b1
Create Date: 2025-12-19 10:00:00.000000

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "d8e5f6a7b2c3"
down_revision: str | Sequence[str] | None = "c7d4e9f3a2b1"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Create diarization_jobs table for tracking background jobs."""
    op.create_table(
        "diarization_jobs",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column(
            "meeting_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("noteflow.meetings.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("status", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("segments_updated", sa.Integer(), nullable=False, server_default="0"),
        sa.Column(
            "speaker_ids",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=False,
            server_default="[]",
        ),
        sa.Column("error_message", sa.Text(), nullable=False, server_default=""),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        schema="noteflow",
    )

    # Create index for meeting_id lookups
    op.create_index(
        "ix_diarization_jobs_meeting_id",
        "diarization_jobs",
        ["meeting_id"],
        schema="noteflow",
    )

    # Create index for status queries (e.g., finding running jobs)
    op.create_index(
        "ix_diarization_jobs_status",
        "diarization_jobs",
        ["status"],
        schema="noteflow",
    )


def downgrade() -> None:
    """Drop diarization_jobs table."""
    op.drop_index("ix_diarization_jobs_status", table_name="diarization_jobs", schema="noteflow")
    op.drop_index(
        "ix_diarization_jobs_meeting_id", table_name="diarization_jobs", schema="noteflow"
    )
    op.drop_table("diarization_jobs", schema="noteflow")
@@ -0,0 +1,45 @@
"""add_asset_path_to_meetings

Revision ID: e9f0a1b2c3d4
Revises: d8e5f6a7b2c3
Create Date: 2025-12-19 08:00:00.000000

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "e9f0a1b2c3d4"
down_revision: str | Sequence[str] | None = "d8e5f6a7b2c3"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Add asset_path column to meetings table.

    Stores the relative path for audio files. This allows the meetings_dir
    to change without orphaning existing recordings.
    """
    op.add_column(
        "meetings",
        sa.Column("asset_path", sa.Text(), nullable=True),
        schema="noteflow",
    )

    # Backfill existing rows: asset_path = id (as string)
    op.execute(
        """
        UPDATE noteflow.meetings
        SET asset_path = id::text
        WHERE asset_path IS NULL
        """
    )


def downgrade() -> None:
    """Remove asset_path column from meetings table."""
    op.drop_column("meetings", "asset_path", schema="noteflow")
@@ -0,0 +1,54 @@
"""add_user_preferences_table

Revision ID: f0a1b2c3d4e5
Revises: e9f0a1b2c3d4
Create Date: 2025-12-19 09:00:00.000000

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import JSONB

# revision identifiers, used by Alembic.
revision: str = "f0a1b2c3d4e5"
down_revision: str | Sequence[str] | None = "e9f0a1b2c3d4"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Create user_preferences table for persisting user settings."""
    op.create_table(
        "user_preferences",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("key", sa.String(64), nullable=False),
        sa.Column("value", JSONB(), nullable=False),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("key"),
        schema="noteflow",
    )
    op.create_index(
        "ix_noteflow_user_preferences_key",
        "user_preferences",
        ["key"],
        schema="noteflow",
    )


def downgrade() -> None:
    """Drop user_preferences table."""
    op.drop_index(
        "ix_noteflow_user_preferences_key",
        table_name="user_preferences",
        schema="noteflow",
    )
    op.drop_table("user_preferences", schema="noteflow")
@@ -0,0 +1,62 @@
"""add_streaming_diarization_turns

Revision ID: g1b2c3d4e5f6
Revises: f0a1b2c3d4e5
Create Date: 2025-12-19 14:00:00.000000

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "g1b2c3d4e5f6"
down_revision: str | Sequence[str] | None = "f0a1b2c3d4e5"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Create streaming_diarization_turns table for crash-resilient speaker turns."""
    op.create_table(
        "streaming_diarization_turns",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column(
            "meeting_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("noteflow.meetings.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("speaker", sa.String(50), nullable=False),
        sa.Column("start_time", sa.Float(), nullable=False),
        sa.Column("end_time", sa.Float(), nullable=False),
        sa.Column("confidence", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        schema="noteflow",
    )

    # Create index for meeting_id lookups
    op.create_index(
        "ix_streaming_diarization_turns_meeting_id",
        "streaming_diarization_turns",
        ["meeting_id"],
        schema="noteflow",
    )


def downgrade() -> None:
    """Drop streaming_diarization_turns table."""
    op.drop_index(
        "ix_streaming_diarization_turns_meeting_id",
        table_name="streaming_diarization_turns",
        schema="noteflow",
    )
    op.drop_table("streaming_diarization_turns", schema="noteflow")
@@ -65,6 +65,10 @@ class MeetingModel(Base):
        LargeBinary,
        nullable=True,
    )
    asset_path: Mapped[str | None] = mapped_column(
        Text,
        nullable=True,
    )

    # Relationships
    segments: Mapped[list[SegmentModel]] = relationship(
@@ -300,3 +304,91 @@ class AnnotationModel(Base):
        "MeetingModel",
        back_populates="annotations",
    )


class UserPreferencesModel(Base):
    """SQLAlchemy model for user_preferences table.

    Stores key-value user preferences for persistence across server restarts.
    Currently used for cloud consent and other settings.
    """

    __tablename__ = "user_preferences"
    __table_args__: ClassVar[dict[str, str]] = {"schema": "noteflow"}

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    key: Mapped[str] = mapped_column(String(64), unique=True, index=True, nullable=False)
    value: Mapped[dict[str, object]] = mapped_column(JSONB, nullable=False, default=dict)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        onupdate=datetime.now,
    )


class DiarizationJobModel(Base):
    """SQLAlchemy model for diarization_jobs table.

    Tracks background speaker diarization jobs. Persisting job state
    allows recovery after server restart and provides client visibility.
    """

    __tablename__ = "diarization_jobs"
    __table_args__: ClassVar[dict[str, str]] = {"schema": "noteflow"}

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    meeting_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("noteflow.meetings.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    status: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    segments_updated: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    speaker_ids: Mapped[list[str]] = mapped_column(
        JSONB,
        nullable=False,
        default=list,
    )
    error_message: Mapped[str] = mapped_column(Text, nullable=False, default="")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        onupdate=datetime.now,
    )


class StreamingDiarizationTurnModel(Base):
    """SQLAlchemy model for streaming_diarization_turns table.

    Stores speaker turns from real-time streaming diarization for crash
    resilience. These turns are persisted as they arrive and can be reloaded
    if the server restarts during a recording session.
    """

    __tablename__ = "streaming_diarization_turns"
    __table_args__: ClassVar[dict[str, str]] = {"schema": "noteflow"}

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    meeting_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("noteflow.meetings.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    speaker: Mapped[str] = mapped_column(String(50), nullable=False)
    start_time: Mapped[float] = mapped_column(Float, nullable=False)
    end_time: Mapped[float] = mapped_column(Float, nullable=False)
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
    )
@@ -1,13 +1,23 @@
"""Repository implementations for NoteFlow."""

from .annotation_repo import SqlAlchemyAnnotationRepository
from .diarization_job_repo import (
    DiarizationJob,
    SqlAlchemyDiarizationJobRepository,
    StreamingTurn,
)
from .meeting_repo import SqlAlchemyMeetingRepository
from .preferences_repo import SqlAlchemyPreferencesRepository
from .segment_repo import SqlAlchemySegmentRepository
from .summary_repo import SqlAlchemySummaryRepository

__all__ = [
    "DiarizationJob",
    "SqlAlchemyAnnotationRepository",
    "SqlAlchemyDiarizationJobRepository",
    "SqlAlchemyMeetingRepository",
    "SqlAlchemyPreferencesRepository",
    "SqlAlchemySegmentRepository",
    "SqlAlchemySummaryRepository",
    "StreamingTurn",
]
@@ -0,0 +1,282 @@
"""SQLAlchemy implementation of DiarizationJobRepository."""

from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import datetime
from typing import Final
from uuid import UUID

from sqlalchemy import delete, select, update

from noteflow.infrastructure.persistence.models import (
    DiarizationJobModel,
    StreamingDiarizationTurnModel,
)
from noteflow.infrastructure.persistence.repositories._base import BaseRepository

# Job status constants (mirrors proto enum)
JOB_STATUS_UNSPECIFIED: Final[int] = 0
JOB_STATUS_QUEUED: Final[int] = 1
JOB_STATUS_RUNNING: Final[int] = 2
JOB_STATUS_COMPLETED: Final[int] = 3
JOB_STATUS_FAILED: Final[int] = 4


@dataclass
class DiarizationJob:
    """Data transfer object for diarization job state.

    Separate from ORM model to allow easy passing between layers.
    """

    job_id: str
    meeting_id: str
    status: int
    segments_updated: int = 0
    speaker_ids: list[str] = field(default_factory=list)
    error_message: str = ""
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)


@dataclass
class StreamingTurn:
    """Data transfer object for streaming diarization turn.

    Represents a speaker turn collected during real-time streaming diarization.
    """

    speaker: str
    start_time: float
    end_time: float
    confidence: float = 0.0


class SqlAlchemyDiarizationJobRepository(BaseRepository):
    """SQLAlchemy implementation of DiarizationJobRepository."""

    @staticmethod
    def _to_domain(model: DiarizationJobModel) -> DiarizationJob:
        """Convert ORM model to domain object."""
        return DiarizationJob(
            job_id=model.id,
            meeting_id=str(model.meeting_id),
            status=model.status,
            segments_updated=model.segments_updated,
            speaker_ids=list(model.speaker_ids),
            error_message=model.error_message,
            created_at=model.created_at,
            updated_at=model.updated_at,
        )

    async def create(self, job: DiarizationJob) -> DiarizationJob:
        """Persist a new diarization job.

        Args:
            job: Job to create.

        Returns:
            Created job.
        """
        model = DiarizationJobModel(
            id=job.job_id,
            meeting_id=UUID(job.meeting_id),
            status=job.status,
            segments_updated=job.segments_updated,
            speaker_ids=job.speaker_ids,
            error_message=job.error_message,
            created_at=job.created_at,
            updated_at=job.updated_at,
        )
        self._session.add(model)
        await self._session.flush()
        return job

    async def get(self, job_id: str) -> DiarizationJob | None:
        """Retrieve a job by ID.

        Args:
            job_id: Job identifier.

        Returns:
            Job if found, None otherwise.
        """
        stmt = select(DiarizationJobModel).where(DiarizationJobModel.id == job_id)
        model = await self._execute_scalar(stmt)

        return None if model is None else self._to_domain(model)

    async def update_status(
        self,
        job_id: str,
        status: int,
        *,
        segments_updated: int | None = None,
        speaker_ids: list[str] | None = None,
        error_message: str | None = None,
    ) -> bool:
        """Update job status and optional fields.

        Args:
            job_id: Job identifier.
            status: New status value.
            segments_updated: Optional segments count.
            speaker_ids: Optional speaker IDs list.
            error_message: Optional error message.

        Returns:
            True if job was updated, False if not found.
        """
        values: dict[str, int | list[str] | str | datetime] = {
            "status": status,
            "updated_at": datetime.now(),
        }
        if segments_updated is not None:
            values["segments_updated"] = segments_updated
        if speaker_ids is not None:
            values["speaker_ids"] = speaker_ids
        if error_message is not None:
            values["error_message"] = error_message

        stmt = update(DiarizationJobModel).where(DiarizationJobModel.id == job_id).values(**values)
        result = await self._session.execute(stmt)
        await self._session.flush()
        return result.rowcount > 0

    async def list_for_meeting(self, meeting_id: str) -> Sequence[DiarizationJob]:
        """List all jobs for a meeting.

        Args:
            meeting_id: Meeting identifier.

        Returns:
            List of jobs ordered by creation time (newest first).
        """
        stmt = (
            select(DiarizationJobModel)
            .where(DiarizationJobModel.meeting_id == UUID(meeting_id))
            .order_by(DiarizationJobModel.created_at.desc())
        )
        models = await self._execute_scalars(stmt)
        return [self._to_domain(model) for model in models]

    async def mark_running_as_failed(self, error_message: str = "Server restarted") -> int:
        """Mark all QUEUED or RUNNING jobs as FAILED.

        Used during crash recovery to mark orphaned jobs.

        Args:
            error_message: Error message to set on failed jobs.

        Returns:
            Number of jobs marked as failed.
        """
        stmt = (
            update(DiarizationJobModel)
            .where(DiarizationJobModel.status.in_([JOB_STATUS_QUEUED, JOB_STATUS_RUNNING]))
            .values(
                status=JOB_STATUS_FAILED,
                error_message=error_message,
                updated_at=datetime.now(),
            )
        )
        result = await self._session.execute(stmt)
        await self._session.flush()
        return result.rowcount

    async def prune_completed(self, ttl_seconds: float) -> int:
        """Delete completed/failed jobs older than TTL.

        Args:
            ttl_seconds: Time-to-live in seconds.

        Returns:
            Number of jobs deleted.
        """
        cutoff = datetime.now().timestamp() - ttl_seconds
        cutoff_dt = datetime.fromtimestamp(cutoff)

        stmt = delete(DiarizationJobModel).where(
            DiarizationJobModel.status.in_([JOB_STATUS_COMPLETED, JOB_STATUS_FAILED]),
            DiarizationJobModel.updated_at < cutoff_dt,
        )
        result = await self._session.execute(stmt)
        await self._session.flush()
        return result.rowcount

    # Streaming diarization turn methods

    async def add_streaming_turns(self, meeting_id: str, turns: Sequence[StreamingTurn]) -> int:
        """Persist streaming diarization turns for a meeting.

        Immediately stores speaker turns as they arrive during streaming.
        Used for crash resilience.

        Args:
            meeting_id: Meeting identifier.
            turns: Speaker turns to persist.

        Returns:
            Number of turns added.
        """
        if not turns:
            return 0

        meeting_uuid = UUID(meeting_id)
        for turn in turns:
            model = StreamingDiarizationTurnModel(
                meeting_id=meeting_uuid,
                speaker=turn.speaker,
                start_time=turn.start_time,
                end_time=turn.end_time,
                confidence=turn.confidence,
            )
            self._session.add(model)

        await self._session.flush()
        return len(turns)

    async def get_streaming_turns(self, meeting_id: str) -> list[StreamingTurn]:
        """Retrieve streaming diarization turns for a meeting.

        Used to recover streaming state after server restart.

        Args:
            meeting_id: Meeting identifier.

        Returns:
            List of streaming turns ordered by start time.
        """
        stmt = (
            select(StreamingDiarizationTurnModel)
            .where(StreamingDiarizationTurnModel.meeting_id == UUID(meeting_id))
            .order_by(StreamingDiarizationTurnModel.start_time)
        )
        models = await self._execute_scalars(stmt)
        return [
            StreamingTurn(
                speaker=model.speaker,
                start_time=model.start_time,
                end_time=model.end_time,
                confidence=model.confidence,
            )
            for model in models
        ]

    async def clear_streaming_turns(self, meeting_id: str) -> int:
        """Delete streaming diarization turns for a meeting.

        Called when a meeting stops recording to clean up temporary turns.

        Args:
            meeting_id: Meeting identifier.

        Returns:
            Number of turns deleted.
        """
        stmt = delete(StreamingDiarizationTurnModel).where(
            StreamingDiarizationTurnModel.meeting_id == UUID(meeting_id)
        )
        result = await self._session.execute(stmt)
        await self._session.flush()
        return result.rowcount
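The `mark_running_as_failed` update above is the crash-recovery step: any job still QUEUED or RUNNING at startup must have been orphaned by the previous process. An in-memory sketch of that transition, using the same status constants (the trimmed `Job` dataclass is an assumption for brevity):

```python
from dataclasses import dataclass

# Status constants mirroring the repository module above.
JOB_STATUS_QUEUED = 1
JOB_STATUS_RUNNING = 2
JOB_STATUS_COMPLETED = 3
JOB_STATUS_FAILED = 4


@dataclass
class Job:
    job_id: str
    status: int
    error_message: str = ""


def mark_orphaned_failed(jobs: list[Job], reason: str = "Server restarted") -> int:
    """Fail every job that was still queued or running when the server died."""
    count = 0
    for job in jobs:
        if job.status in (JOB_STATUS_QUEUED, JOB_STATUS_RUNNING):
            job.status = JOB_STATUS_FAILED
            job.error_message = reason
            count += 1
    return count
```

Completed and failed jobs are left untouched; they are reaped separately by `prune_completed` once their TTL expires.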
@@ -1,7 +1,5 @@
"""SQLAlchemy implementation of MeetingRepository."""

from __future__ import annotations

from collections.abc import Sequence
from datetime import datetime
from uuid import UUID
@@ -36,6 +34,7 @@ class SqlAlchemyMeetingRepository(BaseRepository):
            ended_at=meeting.ended_at,
            metadata_=meeting.metadata,
            wrapped_dek=meeting.wrapped_dek,
            asset_path=meeting.asset_path,
        )
        self._session.add(model)
        await self._session.flush()
@@ -79,6 +78,7 @@ class SqlAlchemyMeetingRepository(BaseRepository):
        model.ended_at = meeting.ended_at
        model.metadata_ = meeting.metadata
        model.wrapped_dek = meeting.wrapped_dek
        model.asset_path = meeting.asset_path

        await self._session.flush()
        return meeting
@@ -0,0 +1,85 @@
"""SQLAlchemy implementation of PreferencesRepository."""

from sqlalchemy import select

from noteflow.infrastructure.persistence.models import UserPreferencesModel
from noteflow.infrastructure.persistence.repositories._base import BaseRepository


class SqlAlchemyPreferencesRepository(BaseRepository):
    """SQLAlchemy implementation of PreferencesRepository.

    Provides key-value storage for user preferences. Values are stored as JSONB
    for flexibility while maintaining type-safe retrieval.
    """

    async def _get_by_key(self, key: str) -> UserPreferencesModel | None:
        """Get preference model by key.

        Args:
            key: Preference key.

        Returns:
            UserPreferencesModel or None if not found.
        """
        stmt = select(UserPreferencesModel).where(UserPreferencesModel.key == key)
        return await self._execute_scalar(stmt)

    async def get(self, key: str) -> object | None:
        """Get a preference value by key.

        Args:
            key: Preference key.

        Returns:
            Preference value or None if not found.
        """
        model = await self._get_by_key(key)
        return None if model is None else model.value.get("value")

    async def get_bool(self, key: str, default: bool = False) -> bool:
        """Get a boolean preference.

        Args:
            key: Preference key.
            default: Default value if not found.

        Returns:
            Boolean preference value.
        """
        value = await self.get(key)
        return default if value is None else bool(value)

    async def set(self, key: str, value: object) -> None:
        """Set a preference value.

        Args:
            key: Preference key.
            value: Preference value (must be JSON-serializable).
        """
        model = await self._get_by_key(key)

        if model is None:
            model = UserPreferencesModel(key=key, value={"value": value})
            self._session.add(model)
        else:
            model.value = {"value": value}

        await self._session.flush()

    async def delete(self, key: str) -> bool:
        """Delete a preference.

        Args:
            key: Preference key.

        Returns:
            True if deleted, False if not found.
        """
        model = await self._get_by_key(key)

        if model is None:
            return False

        await self._delete_and_flush(model)
        return True
@@ -1,7 +1,5 @@
"""SQLAlchemy implementation of SegmentRepository."""

from __future__ import annotations

from collections.abc import Sequence
from uuid import UUID
@@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from uuid import UUID

@@ -23,6 +24,53 @@ if TYPE_CHECKING:
class SqlAlchemySummaryRepository(BaseRepository):
    """SQLAlchemy implementation of SummaryRepository."""

    async def _add_key_points(self, summary_id: int, key_points: Sequence[KeyPoint]) -> None:
        """Add key points to a summary.

        Args:
            summary_id: Database ID of the summary.
            key_points: Key points to add. Their db_id fields are updated in place.
        """
        models: list[tuple[KeyPointModel, KeyPoint]] = []
        for kp in key_points:
            kp_model = KeyPointModel(
                summary_id=summary_id,
                text=kp.text,
                start_time=kp.start_time,
                end_time=kp.end_time,
                segment_ids=kp.segment_ids,
            )
            self._session.add(kp_model)
            models.append((kp_model, kp))

        await self._session.flush()
        for kp_model, kp in models:
            kp.db_id = kp_model.id

    async def _add_action_items(self, summary_id: int, action_items: Sequence[ActionItem]) -> None:
        """Add action items to a summary.

        Args:
            summary_id: Database ID of the summary.
            action_items: Action items to add. Their db_id fields are updated in place.
        """
        models: list[tuple[ActionItemModel, ActionItem]] = []
        for ai in action_items:
            ai_model = ActionItemModel(
                summary_id=summary_id,
                text=ai.text,
                assignee=ai.assignee,
                due_date=ai.due_date,
                priority=ai.priority,
                segment_ids=ai.segment_ids,
            )
            self._session.add(ai_model)
            models.append((ai_model, ai))

        await self._session.flush()
        for ai_model, ai in models:
            ai.db_id = ai_model.id

    async def save(self, summary: Summary) -> Summary:
        """Save or update a meeting summary.

@@ -50,38 +98,9 @@ class SqlAlchemySummaryRepository(BaseRepository):
                delete(ActionItemModel).where(ActionItemModel.summary_id == existing.id)
            )

            # Add new key points
            kp_models: list[tuple[KeyPointModel, KeyPoint]] = []
            for kp in summary.key_points:
                kp_model = KeyPointModel(
                    summary_id=existing.id,
                    text=kp.text,
                    start_time=kp.start_time,
                    end_time=kp.end_time,
                    segment_ids=kp.segment_ids,
                )
                self._session.add(kp_model)
                kp_models.append((kp_model, kp))

            # Add new action items
            ai_models: list[tuple[ActionItemModel, ActionItem]] = []
            for ai in summary.action_items:
                ai_model = ActionItemModel(
                    summary_id=existing.id,
                    text=ai.text,
                    assignee=ai.assignee,
                    due_date=ai.due_date,
                    priority=ai.priority,
                    segment_ids=ai.segment_ids,
                )
                self._session.add(ai_model)
                ai_models.append((ai_model, ai))

            await self._session.flush()
            for kp_model, kp in kp_models:
                kp.db_id = kp_model.id
            for ai_model, ai in ai_models:
                ai.db_id = ai_model.id
            # Add new key points and action items
            await self._add_key_points(existing.id, summary.key_points)
            await self._add_action_items(existing.id, summary.action_items)
            summary.db_id = existing.id
        else:
            # Create new summary
@@ -94,33 +113,9 @@ class SqlAlchemySummaryRepository(BaseRepository):
            self._session.add(model)
            await self._session.flush()

            # Add key points
            for kp in summary.key_points:
                kp_model = KeyPointModel(
                    summary_id=model.id,
                    text=kp.text,
                    start_time=kp.start_time,
                    end_time=kp.end_time,
                    segment_ids=kp.segment_ids,
                )
                self._session.add(kp_model)
                await self._session.flush()
                kp.db_id = kp_model.id

            # Add action items
            for ai in summary.action_items:
                ai_model = ActionItemModel(
                    summary_id=model.id,
                    text=ai.text,
                    assignee=ai.assignee,
                    due_date=ai.due_date,
                    priority=ai.priority,
                    segment_ids=ai.segment_ids,
                )
                self._session.add(ai_model)
                await self._session.flush()
                ai.db_id = ai_model.id

            # Add key points and action items
            await self._add_key_points(model.id, summary.key_points)
            await self._add_action_items(model.id, summary.action_items)
            summary.db_id = model.id

        return summary
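The summary-repository refactor replaces the old create path's per-item `flush()` (one round trip per key point or action item) with a single `flush()` per batch inside the new `_add_key_points`/`_add_action_items` helpers. A sketch of that batching shape, with a counter standing in for the session (`FakeSession` is an assumption for illustration):

```python
class FakeSession:
    """Counts flushes so the batching behaviour is observable."""

    def __init__(self) -> None:
        self.flushes = 0
        self.pending: list[object] = []

    def add(self, obj: object) -> None:
        self.pending.append(obj)

    def flush(self) -> None:
        self.flushes += 1


def add_batch(session: FakeSession, items: list[object]) -> None:
    # Stage every row first, then flush once - the shape of the new helpers.
    for item in items:
        session.add(item)
    session.flush()
```

With the old per-item pattern, saving a summary with N key points cost N flushes; the batched helpers cost one, and still let the caller read generated primary keys back afterwards.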
@@ -15,7 +15,9 @@ from noteflow.infrastructure.persistence.database import (
|
||||
|
||||
from .repositories import (
|
||||
SqlAlchemyAnnotationRepository,
|
||||
SqlAlchemyDiarizationJobRepository,
|
||||
SqlAlchemyMeetingRepository,
|
||||
SqlAlchemyPreferencesRepository,
|
||||
SqlAlchemySegmentRepository,
|
||||
SqlAlchemySummaryRepository,
|
||||
)
|
||||
@@ -43,7 +45,9 @@ class SqlAlchemyUnitOfWork:
|
||||
self._session_factory = session_factory
|
||||
self._session: AsyncSession | None = None
|
||||
self._annotations_repo: SqlAlchemyAnnotationRepository | None = None
|
||||
self._diarization_jobs_repo: SqlAlchemyDiarizationJobRepository | None = None
|
||||
self._meetings_repo: SqlAlchemyMeetingRepository | None = None
|
||||
self._preferences_repo: SqlAlchemyPreferencesRepository | None = None
|
||||
self._segments_repo: SqlAlchemySegmentRepository | None = None
|
||||
self._summaries_repo: SqlAlchemySummaryRepository | None = None
|
||||
|
||||
@@ -87,6 +91,13 @@ class SqlAlchemyUnitOfWork:
|
||||
raise RuntimeError("UnitOfWork not in context")
|
||||
return self._annotations_repo
|
||||
|
||||
@property
|
||||
def diarization_jobs(self) -> SqlAlchemyDiarizationJobRepository:
|
||||
"""Get diarization jobs repository."""
|
||||
if self._diarization_jobs_repo is None:
|
||||
raise RuntimeError("UnitOfWork not in context")
|
||||
return self._diarization_jobs_repo
|
||||
|
||||
@property
|
||||
def meetings(self) -> SqlAlchemyMeetingRepository:
|
||||
"""Get meetings repository."""
|
||||
@@ -101,6 +112,13 @@ class SqlAlchemyUnitOfWork:
|
||||
raise RuntimeError("UnitOfWork not in context")
|
||||
return self._segments_repo
|
||||
|
||||
@property
|
||||
def preferences(self) -> SqlAlchemyPreferencesRepository:
|
||||
"""Get preferences repository."""
|
||||
if self._preferences_repo is None:
|
||||
raise RuntimeError("UnitOfWork not in context")
|
||||
return self._preferences_repo
|
||||
|
||||
@property
|
||||
def summaries(self) -> SqlAlchemySummaryRepository:
|
||||
"""Get summaries repository."""
|
||||
@@ -118,7 +136,9 @@ class SqlAlchemyUnitOfWork:
|
||||
"""
|
||||
self._session = self._session_factory()
|
||||
self._annotations_repo = SqlAlchemyAnnotationRepository(self._session)
|
||||
self._diarization_jobs_repo = SqlAlchemyDiarizationJobRepository(self._session)
|
||||
self._meetings_repo = SqlAlchemyMeetingRepository(self._session)
|
||||
self._preferences_repo = SqlAlchemyPreferencesRepository(self._session)
|
||||
self._segments_repo = SqlAlchemySegmentRepository(self._session)
|
||||
self._summaries_repo = SqlAlchemySummaryRepository(self._session)
|
||||
return self
|
||||
@@ -147,7 +167,9 @@ class SqlAlchemyUnitOfWork:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._annotations_repo = None
|
||||
self._diarization_jobs_repo = None
|
||||
self._meetings_repo = None
|
||||
self._preferences_repo = None
|
||||
self._segments_repo = None
|
||||
self._summaries_repo = None
|
||||
|
||||
|
||||
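The new `diarization_jobs` and `preferences` properties above follow the same lazy-guard pattern as the existing repositories: attributes are populated on context entry, cleared on exit, and guarded by a `RuntimeError` outside the context. A minimal standalone sketch of that pattern (class and attribute names here are illustrative, not NoteFlow's API):

```python
class UnitOfWork:
    """Sketch of the guard pattern: repos exist only inside the context."""

    def __init__(self) -> None:
        self._repo: object | None = None

    def __enter__(self) -> "UnitOfWork":
        self._repo = object()  # stand-in for a session-bound repository
        return self

    def __exit__(self, *exc: object) -> None:
        self._repo = None  # repositories are torn down on exit

    @property
    def repo(self) -> object:
        # Same guard as the diff's repository properties
        if self._repo is None:
            raise RuntimeError("UnitOfWork not in context")
        return self._repo
```

Accessing `repo` outside a `with` block raises, matching the `RuntimeError("UnitOfWork not in context")` guards in the diff.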
@@ -3,13 +3,13 @@

Provides secure master key storage using OS credential stores.
"""

from __future__ import annotations

import base64
import binascii
import logging
import os
import secrets
import stat
from pathlib import Path
from typing import Final

import keyring

@@ -21,6 +21,42 @@ KEY_SIZE: Final[int] = 32  # 256-bit key
SERVICE_NAME: Final[str] = "noteflow"
KEY_NAME: Final[str] = "master_key"
ENV_VAR_NAME: Final[str] = "NOTEFLOW_MASTER_KEY"
DEFAULT_KEY_FILE: Final[Path] = Path.home() / ".noteflow" / ".master_key"


def _decode_and_validate_key(encoded: str, source_name: str) -> bytes:
    """Decode and validate a base64-encoded master key.

    Args:
        encoded: Base64-encoded key string.
        source_name: Human-readable source name for error messages.

    Returns:
        Decoded key bytes.

    Raises:
        RuntimeError: If decoding fails or key size is wrong.
    """
    try:
        decoded = base64.b64decode(encoded, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise RuntimeError(f"{source_name} contains invalid base64") from exc
    if len(decoded) != KEY_SIZE:
        raise RuntimeError(
            f"{source_name} has wrong key size: expected {KEY_SIZE}, got {len(decoded)}"
        )
    return decoded


def _generate_key() -> tuple[bytes, str]:
    """Generate a new random master key.

    Returns:
        Tuple of (raw_key_bytes, base64_encoded_string).
    """
    raw_key = secrets.token_bytes(KEY_SIZE)
    encoded = base64.b64encode(raw_key).decode("ascii")
    return raw_key, encoded


class KeyringKeyStore:

@@ -61,17 +97,7 @@ class KeyringKeyStore:
        # Check environment variable first (for headless/container deployments)
        if env_key := os.environ.get(ENV_VAR_NAME):
            logger.debug("Using master key from environment variable")
-           try:
-               decoded = base64.b64decode(env_key, validate=True)
-           except (binascii.Error, ValueError) as exc:
-               raise RuntimeError(
-                   f"{ENV_VAR_NAME} must be base64-encoded {KEY_SIZE}-byte key"
-               ) from exc
-           if len(decoded) != KEY_SIZE:
-               raise RuntimeError(
-                   f"{ENV_VAR_NAME} must decode to {KEY_SIZE} bytes, got {len(decoded)}"
-               )
-           return decoded
+           return _decode_and_validate_key(env_key, f"Environment variable {ENV_VAR_NAME}")

        try:
            # Try to retrieve existing key from keyring
@@ -81,8 +107,7 @@ class KeyringKeyStore:
                return base64.b64decode(stored)

            # Generate new key
-           new_key = secrets.token_bytes(KEY_SIZE)
-           encoded = base64.b64encode(new_key).decode("ascii")
+           new_key, encoded = _generate_key()

            # Store in keyring
            keyring.set_password(self._service_name, self._key_name, encoded)
@@ -90,10 +115,12 @@ class KeyringKeyStore:
            return new_key

        except keyring.errors.KeyringError as e:
-           raise RuntimeError(
-               f"Keyring unavailable: {e}. "
-               f"Set {ENV_VAR_NAME} environment variable for headless mode."
-           ) from e
+           # Fall back to file-based storage for headless environments
+           logger.warning(
+               "Keyring unavailable (%s), falling back to file-based key storage",
+               e,
+           )
+           return FileKeyStore().get_or_create_master_key()

    def delete_master_key(self) -> None:
        """Delete the master key from the keychain.

@@ -157,3 +184,69 @@ class InMemoryKeyStore:
    def has_master_key(self) -> bool:
        """Check if master key exists."""
        return self._key is not None


class FileKeyStore:
    """File-based key storage for headless environments.

    Stores the master key in a restricted-permissions file when keyring is
    unavailable. This is a fallback for headless servers, containers, and
    environments without a desktop session.

    File permissions are set to 0600 (owner read/write only).
    """

    def __init__(self, key_file: Path | None = None) -> None:
        """Initialize the file keystore.

        Args:
            key_file: Path to key file. Defaults to ~/.noteflow/.master_key.
        """
        self._key_file = key_file or DEFAULT_KEY_FILE

    def get_or_create_master_key(self) -> bytes:
        """Retrieve or generate the master encryption key.

        Returns:
            32-byte master key.

        Raises:
            RuntimeError: If key file exists with wrong size or permissions fail.
        """
        if self._key_file.exists():
            logger.debug("Retrieved master key from file: %s", self._key_file)
            content = self._key_file.read_text().strip()
            return _decode_and_validate_key(content, f"Key file {self._key_file}")

        # Generate new key
        new_key, encoded = _generate_key()

        # Create parent directory if needed
        self._key_file.parent.mkdir(parents=True, exist_ok=True)

        # Write key with restricted permissions
        self._key_file.write_text(encoded)
        self._key_file.chmod(stat.S_IRUSR | stat.S_IWUSR)  # 0600

        logger.info("Generated and stored master key in file: %s", self._key_file)
        return new_key

    def delete_master_key(self) -> None:
        """Delete the master key file.

        Safe to call if file doesn't exist.
        """
        if self._key_file.exists():
            self._key_file.unlink()
            logger.info("Deleted master key file: %s", self._key_file)
        else:
            logger.debug("Master key file not found, nothing to delete")

    def has_master_key(self) -> bool:
        """Check if master key file exists."""
        return self._key_file.exists()

    @property
    def key_file(self) -> Path:
        """Get the key file path."""
        return self._key_file
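The `_generate_key` / `_decode_and_validate_key` helpers factored out above form a simple round-trip. A self-contained sketch mirroring the diff's logic (the constant and error messages are copied from the diff; the function names are de-underscored for standalone use):

```python
import base64
import binascii
import secrets

KEY_SIZE = 32  # 256-bit key, as in the diff


def generate_key() -> tuple[bytes, str]:
    """Return (raw_key_bytes, base64_encoded_string)."""
    raw_key = secrets.token_bytes(KEY_SIZE)
    return raw_key, base64.b64encode(raw_key).decode("ascii")


def decode_and_validate_key(encoded: str, source_name: str) -> bytes:
    """Decode a base64 key and enforce the expected size."""
    try:
        decoded = base64.b64decode(encoded, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise RuntimeError(f"{source_name} contains invalid base64") from exc
    if len(decoded) != KEY_SIZE:
        raise RuntimeError(
            f"{source_name} has wrong key size: expected {KEY_SIZE}, got {len(decoded)}"
        )
    return decoded
```

Centralizing this means the environment-variable path and the new file-based path report size and encoding errors identically.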
@@ -3,8 +3,6 @@

These protocols define the contracts for key storage and encryption components.
"""

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path

@@ -1,7 +1,5 @@
"""Factory for creating configured SummarizationService instances."""

from __future__ import annotations

import logging

from noteflow.application.services.summarization_service import (

@@ -1,7 +1,5 @@
"""Mock summarization provider for testing."""

from __future__ import annotations

import time
from datetime import UTC, datetime
@@ -103,7 +103,7 @@ class _SystemOutputSampler:
        try:
            devices = sd.query_devices()
        except Exception:
-           return self._extracted_from__select_device_5(
+           return self._mark_unavailable_with_warning(
                "Failed to query audio devices for app audio detection"
            )
        for idx, dev in enumerate(devices):
@@ -111,21 +111,27 @@ class _SystemOutputSampler:
            if int(dev.get("max_input_channels", 0)) <= 0:
                continue
            if "monitor" in name or "loopback" in name:
-               return self._extracted_from__select_device_24(idx)
+               return self._mark_device_available(idx)
        self._available = False
        logger.warning("No loopback audio device found - app audio detection disabled")

-   # TODO Rename this here and in `_select_device`
-   def _extracted_from__select_device_24(self, arg0):
-       self._device = arg0
-       self._available = True
-       return
+   def _mark_device_available(self, device_index: int) -> None:
+       """Mark the device as available for audio capture.
+
+       Args:
+           device_index: Index of the audio device.
+       """
+       self._device = device_index
+       self._available = True

-   # TODO Rename this here and in `_select_device`
-   def _extracted_from__select_device_5(self, arg0):
-       self._available = False
-       logger.warning(arg0)
-       return
+   def _mark_unavailable_with_warning(self, message: str) -> None:
+       """Mark device as unavailable and log a warning.
+
+       Args:
+           message: Warning message to log.
+       """
+       self._available = False
+       logger.warning(message)

    def _ensure_stream(self) -> bool:
        if self._available is False:
@@ -8,7 +8,7 @@ from __future__ import annotations

import json
import logging
from dataclasses import dataclass
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING

from noteflow.domain.triggers.entities import TriggerSignal, TriggerSource

@@ -63,7 +63,7 @@ class CalendarProvider:
        if not self._settings.events:
            return None

-       now = datetime.now(timezone.utc)
+       now = datetime.now(UTC)
        window_start = now - timedelta(minutes=self._settings.lookbehind_minutes)
        window_end = now + timedelta(minutes=self._settings.lookahead_minutes)

@@ -145,6 +145,5 @@ def _parse_datetime(value: object) -> datetime | None:

def _ensure_tz(value: datetime) -> datetime:
    if value.tzinfo is None:
-       return value.replace(tzinfo=timezone.utc)
-   return value.astimezone(timezone.utc)
+       return value.replace(tzinfo=UTC)
+   return value.astimezone(UTC)

@@ -3,8 +3,6 @@

Detect meeting applications in the foreground window.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
1   support/__init__.py   Normal file
@@ -0,0 +1 @@
"""Shared development/test support utilities."""

186   support/db_utils.py   Normal file
@@ -0,0 +1,186 @@
"""PostgreSQL testcontainer fixtures and utilities."""

from __future__ import annotations

import time
from importlib import import_module
from typing import TYPE_CHECKING
from urllib.parse import quote

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from noteflow.infrastructure.persistence.models import Base

if TYPE_CHECKING:
    from collections.abc import Self

    from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine


class PgTestContainer:
    """Minimal Postgres testcontainer wrapper with custom readiness wait."""

    def __init__(
        self,
        image: str = "pgvector/pgvector:pg16",
        username: str = "test",
        password: str = "test",
        dbname: str = "noteflow_test",
        port: int = 5432,
    ) -> None:
        """Initialize the container configuration.

        Args:
            image: Docker image to use.
            username: PostgreSQL username.
            password: PostgreSQL password.
            dbname: Database name.
            port: PostgreSQL port.
        """
        self.username = username
        self.password = password
        self.dbname = dbname
        self.port = port

        container_module = import_module("testcontainers.core.container")
        docker_container_cls = container_module.DockerContainer
        self._container = (
            docker_container_cls(image)
            .with_env("POSTGRES_USER", username)
            .with_env("POSTGRES_PASSWORD", password)
            .with_env("POSTGRES_DB", dbname)
            .with_exposed_ports(port)
        )

    def start(self) -> Self:
        """Start the container."""
        self._container.start()
        self._wait_until_ready()
        return self

    def stop(self) -> None:
        """Stop the container."""
        self._container.stop()

    def get_connection_url(self) -> str:
        """Return a SQLAlchemy-style connection URL."""
        host = self._container.get_container_host_ip()
        port = self._container._get_exposed_port(self.port)
        quoted_password = quote(self.password, safe=" +")
        return (
            f"postgresql+psycopg2://{self.username}:{quoted_password}@{host}:{port}/{self.dbname}"
        )

    def _wait_until_ready(self, timeout: float = 30.0, interval: float = 0.5) -> None:
        """Wait for Postgres to accept connections by running a simple query."""
        start_time = time.time()
        escaped_password = self.password.replace("'", "'\"'\"'")
        cmd = [
            "sh",
            "-c",
            (
                f"PGPASSWORD='{escaped_password}' "
                f"psql --username {self.username} --dbname {self.dbname} --host 127.0.0.1 "
                "-c 'select 1;'"
            ),
        ]
        last_error: str | None = None

        while True:
            result = self._container.exec(cmd)
            if result.exit_code == 0:
                return
            if result.output:
                last_error = result.output.decode(errors="ignore")
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    "Postgres container did not become ready in time"
                    + (f": {last_error}" if last_error else "")
                )
            time.sleep(interval)


# Module-level container singleton
_container: PgTestContainer | None = None
_database_url: str | None = None


def get_or_create_container() -> tuple[PgTestContainer, str]:
    """Get or create the PostgreSQL container singleton.

    Returns:
        Tuple of (container, async_database_url).
    """
    global _container, _database_url

    if _container is None:
        container = PgTestContainer().start()
        _container = container
        url = container.get_connection_url()
        _database_url = url.replace("postgresql+psycopg2://", "postgresql+asyncpg://")

    assert _container is not None, "Container should be initialized"
    assert _database_url is not None, "Database URL should be initialized"
    return _container, _database_url


def stop_container() -> None:
    """Stop and cleanup the container singleton."""
    global _container
    if _container is not None:
        _container.stop()
        _container = None


async def initialize_test_schema(conn: AsyncConnection) -> None:
    """Initialize test database schema.

    Creates the pgvector extension and noteflow schema with all tables.

    Args:
        conn: Async database connection.
    """
    await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
    await conn.execute(text("CREATE SCHEMA noteflow"))
    await conn.run_sync(Base.metadata.create_all)


async def cleanup_test_schema(conn: AsyncConnection) -> None:
    """Drop the test schema.

    Args:
        conn: Async database connection.
    """
    await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))


def create_test_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Create standard test session factory.

    Args:
        engine: SQLAlchemy async engine.

    Returns:
        Configured session factory.
    """
    return async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
    )


def create_test_engine(database_url: str) -> AsyncEngine:
    """Create test database engine.

    Args:
        database_url: Async database URL.

    Returns:
        SQLAlchemy async engine.
    """
    return create_async_engine(database_url, echo=False)
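`_wait_until_ready` above is an instance of a generic poll-until-ready loop. A minimal standalone sketch of that loop (the `check` callable stands in for the `psql ... -c 'select 1;'` exec probe; `wait_until` is an illustrative name, not part of the file):

```python
import time
from collections.abc import Callable


def wait_until(check: Callable[[], bool], timeout: float = 30.0, interval: float = 0.5) -> None:
    """Poll `check` until it returns True, or raise TimeoutError after `timeout` seconds."""
    start_time = time.time()
    while True:
        if check():
            return
        if time.time() - start_time > timeout:
            raise TimeoutError("resource did not become ready in time")
        time.sleep(interval)
```

Polling an in-container `select 1` rather than relying on log output avoids the classic Postgres false-ready, where the entrypoint starts the server once for initialization and restarts it before accepting external connections.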
@@ -21,19 +21,6 @@ if TYPE_CHECKING:

class TestMeetingServiceCreation:
    """Tests for meeting creation operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.rollback = AsyncMock()
-       uow.meetings = MagicMock()
-       uow.segments = MagicMock()
-       uow.summaries = MagicMock()
-       return uow

    async def test_create_meeting_success(self, mock_uow: MagicMock) -> None:
        """Test successful meeting creation."""
        created_meeting = Meeting.create(title="Test Meeting")

@@ -61,18 +48,6 @@ class TestMeetingServiceCreation:

class TestMeetingServiceRetrieval:
    """Tests for meeting retrieval operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.meetings = MagicMock()
-       uow.segments = MagicMock()
-       uow.summaries = MagicMock()
-       return uow

    async def test_get_meeting_found(self, mock_uow: MagicMock) -> None:
        """Test retrieving existing meeting."""
        meeting_id = MeetingId(uuid4())

@@ -116,16 +91,6 @@ class TestMeetingServiceRetrieval:

class TestMeetingServiceStateTransitions:
    """Tests for meeting state transition operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.meetings = MagicMock()
-       return uow

    async def test_start_recording_success(self, mock_uow: MagicMock) -> None:
        """Test starting recording on existing meeting."""
        meeting = Meeting.create(title="Test")

@@ -224,16 +189,6 @@ class TestMeetingServiceStateTransitions:

class TestMeetingServiceDeletion:
    """Tests for meeting deletion operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.meetings = MagicMock()
-       return uow

    async def test_delete_meeting_success(self, mock_uow: MagicMock) -> None:
        """Test successful meeting deletion."""
        meeting_id = MeetingId(uuid4())

@@ -315,16 +270,6 @@ class TestMeetingServiceDeletion:

class TestMeetingServiceSegments:
    """Tests for segment operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.segments = MagicMock()
-       return uow

    async def test_add_segment_success(self, mock_uow: MagicMock) -> None:
        """Test adding a segment to meeting."""
        meeting_id = MeetingId(uuid4())

@@ -381,16 +326,6 @@ class TestMeetingServiceSegments:

class TestMeetingServiceSummaries:
    """Tests for summary operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.summaries = MagicMock()
-       return uow

    async def test_save_summary_success(self, mock_uow: MagicMock) -> None:
        """Test saving a meeting summary."""
        meeting_id = MeetingId(uuid4())

@@ -439,15 +374,6 @@ class TestMeetingServiceSummaries:

class TestMeetingServiceSearch:
    """Tests for semantic search operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.segments = MagicMock()
-       return uow

    async def test_search_segments_delegates(self, mock_uow: MagicMock) -> None:
        """Test search_segments delegates to repository."""
        meeting_id = MeetingId(uuid4())

@@ -466,16 +392,6 @@ class TestMeetingServiceSearch:

class TestMeetingServiceAnnotations:
    """Tests for annotation operations."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.annotations = MagicMock()
-       return uow

    async def test_add_annotation_success(self, mock_uow: MagicMock) -> None:
        """Test adding an annotation commits and returns saved entity."""
        meeting_id = MeetingId(uuid4())

@@ -537,19 +453,6 @@ class TestMeetingServiceAnnotations:

class TestMeetingServiceAdditionalBranches:
    """Additional branch coverage for MeetingService."""

-   @pytest.fixture
-   def mock_uow(self) -> MagicMock:
-       """Create a mock UnitOfWork with all repos."""
-       uow = MagicMock()
-       uow.__aenter__ = AsyncMock(return_value=uow)
-       uow.__aexit__ = AsyncMock(return_value=None)
-       uow.commit = AsyncMock()
-       uow.meetings = MagicMock()
-       uow.segments = MagicMock()
-       uow.summaries = MagicMock()
-       uow.annotations = MagicMock()
-       return uow

    async def test_stop_meeting_not_found(self, mock_uow: MagicMock) -> None:
        """stop_meeting should return None when meeting is missing."""
        mock_uow.meetings.get = AsyncMock(return_value=None)
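The per-class `mock_uow` fixtures removed above all build the same async-context-manager mock, and the pattern works because `MagicMock` wires instance-assigned dunders (`__aenter__`/`__aexit__`) into `async with`. A standalone sketch of that pattern (the `make_mock_uow` name is illustrative):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


def make_mock_uow() -> MagicMock:
    """Build a mock UnitOfWork usable as `async with uow as active:`."""
    uow = MagicMock()
    uow.__aenter__ = AsyncMock(return_value=uow)
    uow.__aexit__ = AsyncMock(return_value=None)
    uow.commit = AsyncMock()
    return uow


async def demo() -> int:
    uow = make_mock_uow()
    async with uow as active:
        await active.commit()
    return uow.commit.await_count  # number of times commit was awaited
```

`asyncio.run(demo())` confirms `commit` was awaited exactly once inside the context.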
@@ -2,26 +2,19 @@

from __future__ import annotations

from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

-from noteflow.application.services.recovery_service import RecoveryService
+from noteflow.application.services.recovery_service import (
+    AudioValidationResult,
+    RecoveryService,
+)
from noteflow.domain.entities import Meeting
from noteflow.domain.value_objects import MeetingState


@pytest.fixture
def mock_uow() -> MagicMock:
    """Create a mock UnitOfWork."""
    uow = MagicMock()
    uow.__aenter__ = AsyncMock(return_value=uow)
    uow.__aexit__ = AsyncMock(return_value=None)
    uow.commit = AsyncMock()
    uow.meetings = MagicMock()
    return uow


class TestRecoveryServiceRecovery:
    """Tests for crash recovery operations."""

@@ -30,9 +23,10 @@ class TestRecoveryServiceRecovery:
        mock_uow.meetings.list_all = AsyncMock(return_value=([], 0))

        service = RecoveryService(mock_uow)
-       result = await service.recover_crashed_meetings()
+       meetings, audio_failures = await service.recover_crashed_meetings()

-       assert result == []
+       assert meetings == []
+       assert audio_failures == 0
        mock_uow.commit.assert_not_called()

    async def test_recover_single_recording_meeting(self, mock_uow: MagicMock) -> None:
@@ -45,13 +39,13 @@ class TestRecoveryServiceRecovery:
        mock_uow.meetings.update = AsyncMock(return_value=meeting)

        service = RecoveryService(mock_uow)
-       result = await service.recover_crashed_meetings()
+       meetings, _ = await service.recover_crashed_meetings()

-       assert len(result) == 1
-       assert result[0].state == MeetingState.ERROR
-       assert result[0].metadata["crash_recovered"] == "true"
-       assert result[0].metadata["crash_previous_state"] == "RECORDING"
-       assert "crash_recovery_time" in result[0].metadata
+       assert len(meetings) == 1
+       assert meetings[0].state == MeetingState.ERROR
+       assert meetings[0].metadata["crash_recovered"] == "true"
+       assert meetings[0].metadata["crash_previous_state"] == "RECORDING"
+       assert "crash_recovery_time" in meetings[0].metadata
        mock_uow.meetings.update.assert_called_once()
        mock_uow.commit.assert_called_once()

@@ -66,11 +60,11 @@ class TestRecoveryServiceRecovery:
        mock_uow.meetings.update = AsyncMock(return_value=meeting)

        service = RecoveryService(mock_uow)
-       result = await service.recover_crashed_meetings()
+       meetings, _ = await service.recover_crashed_meetings()

-       assert len(result) == 1
-       assert result[0].state == MeetingState.ERROR
-       assert result[0].metadata["crash_previous_state"] == "STOPPING"
+       assert len(meetings) == 1
+       assert meetings[0].state == MeetingState.ERROR
+       assert meetings[0].metadata["crash_previous_state"] == "STOPPING"
        mock_uow.commit.assert_called_once()

    async def test_recover_multiple_crashed_meetings(self, mock_uow: MagicMock) -> None:
@@ -85,18 +79,18 @@ class TestRecoveryServiceRecovery:
        meeting3 = Meeting.create(title="Crashed 3")
        meeting3.start_recording()

-       meetings = [meeting1, meeting2, meeting3]
-       mock_uow.meetings.list_all = AsyncMock(return_value=(meetings, 3))
-       mock_uow.meetings.update = AsyncMock(side_effect=meetings)
+       crashed_meetings = [meeting1, meeting2, meeting3]
+       mock_uow.meetings.list_all = AsyncMock(return_value=(crashed_meetings, 3))
+       mock_uow.meetings.update = AsyncMock(side_effect=crashed_meetings)

        service = RecoveryService(mock_uow)
-       result = await service.recover_crashed_meetings()
+       meetings, _ = await service.recover_crashed_meetings()

-       assert len(result) == 3
-       assert all(m.state == MeetingState.ERROR for m in result)
-       assert result[0].metadata["crash_previous_state"] == "RECORDING"
-       assert result[1].metadata["crash_previous_state"] == "STOPPING"
-       assert result[2].metadata["crash_previous_state"] == "RECORDING"
+       assert len(meetings) == 3
+       assert all(m.state == MeetingState.ERROR for m in meetings)
+       assert meetings[0].metadata["crash_previous_state"] == "RECORDING"
+       assert meetings[1].metadata["crash_previous_state"] == "STOPPING"
+       assert meetings[2].metadata["crash_previous_state"] == "RECORDING"
        assert mock_uow.meetings.update.call_count == 3
        mock_uow.commit.assert_called_once()
@@ -147,12 +141,191 @@ class TestRecoveryServiceMetadata:
        mock_uow.meetings.update = AsyncMock(return_value=meeting)

        service = RecoveryService(mock_uow)
-       result = await service.recover_crashed_meetings()
+       meetings, _ = await service.recover_crashed_meetings()

-       assert len(result) == 1
+       assert len(meetings) == 1
        # Verify original metadata preserved
-       assert result[0].metadata["project"] == "NoteFlow"
-       assert result[0].metadata["important"] == "yes"
+       assert meetings[0].metadata["project"] == "NoteFlow"
+       assert meetings[0].metadata["important"] == "yes"
        # Verify recovery metadata added
-       assert result[0].metadata["crash_recovered"] == "true"
-       assert result[0].metadata["crash_previous_state"] == "RECORDING"
+       assert meetings[0].metadata["crash_recovered"] == "true"
+       assert meetings[0].metadata["crash_previous_state"] == "RECORDING"


class TestRecoveryServiceAudioValidation:
    """Tests for audio file validation during recovery."""

    @pytest.fixture
    def meetings_dir(self, tmp_path: Path) -> Path:
        """Create temporary meetings directory."""
        return tmp_path / "meetings"

    def test_audio_validation_skipped_without_meetings_dir(self, mock_uow: MagicMock) -> None:
        """Test audio validation skipped when no meetings_dir configured."""
        meeting = Meeting.create(title="Test Meeting")
        meeting.start_recording()

        service = RecoveryService(mock_uow, meetings_dir=None)
        result = service._validate_meeting_audio(meeting)

        assert result.is_valid is True
        assert result.manifest_exists is True
        assert result.audio_exists is True
        assert "skipped" in (result.error_message or "").lower()

    def test_audio_validation_missing_directory(
        self, mock_uow: MagicMock, meetings_dir: Path
    ) -> None:
        """Test validation fails when meeting directory does not exist."""
        meeting = Meeting.create(title="Missing Dir")
        meeting.start_recording()

        service = RecoveryService(mock_uow, meetings_dir=meetings_dir)
        result = service._validate_meeting_audio(meeting)

        assert result.is_valid is False
        assert result.manifest_exists is False
        assert result.audio_exists is False
        assert "missing" in (result.error_message or "").lower()

    def test_audio_validation_missing_manifest(
        self, mock_uow: MagicMock, meetings_dir: Path
    ) -> None:
        """Test validation fails when only audio.enc exists."""
        meeting = Meeting.create(title="Missing Manifest")
        meeting.start_recording()

        # Create meeting directory with only audio.enc
        meeting_path = meetings_dir / str(meeting.id)
        meeting_path.mkdir(parents=True)
        (meeting_path / "audio.enc").touch()

        service = RecoveryService(mock_uow, meetings_dir=meetings_dir)
        result = service._validate_meeting_audio(meeting)

        assert result.is_valid is False
        assert result.manifest_exists is False
        assert result.audio_exists is True
        assert "manifest.json" in (result.error_message or "")

    def test_audio_validation_missing_audio(self, mock_uow: MagicMock, meetings_dir: Path) -> None:
        """Test validation fails when only manifest.json exists."""
        meeting = Meeting.create(title="Missing Audio")
        meeting.start_recording()

        # Create meeting directory with only manifest.json
        meeting_path = meetings_dir / str(meeting.id)
        meeting_path.mkdir(parents=True)
        (meeting_path / "manifest.json").touch()

        service = RecoveryService(mock_uow, meetings_dir=meetings_dir)
        result = service._validate_meeting_audio(meeting)

        assert result.is_valid is False
        assert result.manifest_exists is True
        assert result.audio_exists is False
        assert "audio.enc" in (result.error_message or "")

    def test_audio_validation_success(self, mock_uow: MagicMock, meetings_dir: Path) -> None:
        """Test validation succeeds when both files exist."""
        meeting = Meeting.create(title="Complete Meeting")
        meeting.start_recording()

        # Create meeting directory with both files
        meeting_path = meetings_dir / str(meeting.id)
        meeting_path.mkdir(parents=True)
        (meeting_path / "manifest.json").touch()
        (meeting_path / "audio.enc").touch()

        service = RecoveryService(mock_uow, meetings_dir=meetings_dir)
        result = service._validate_meeting_audio(meeting)

        assert result.is_valid is True
        assert result.manifest_exists is True
        assert result.audio_exists is True
        assert result.error_message is None

    def test_audio_validation_uses_asset_path_metadata(
        self, mock_uow: MagicMock, meetings_dir: Path
    ) -> None:
        """Test validation uses asset_path from metadata if available."""
        meeting = Meeting.create(
            title="Custom Path",
            metadata={"asset_path": "custom-path-123"},
        )
        meeting.start_recording()

        # Create meeting at custom asset path
        meeting_path = meetings_dir / "custom-path-123"
        meeting_path.mkdir(parents=True)
        (meeting_path / "manifest.json").touch()
        (meeting_path / "audio.enc").touch()

        service = RecoveryService(mock_uow, meetings_dir=meetings_dir)
        result = service._validate_meeting_audio(meeting)

        assert result.is_valid is True

    async def test_recovery_counts_audio_failures(
        self, mock_uow: MagicMock, meetings_dir: Path
|
||||
) -> None:
|
||||
"""Test recovery tracks audio validation failure count."""
|
||||
meeting1 = Meeting.create(title="Has Audio")
|
||||
meeting1.start_recording()
|
||||
|
||||
meeting2 = Meeting.create(title="Missing Audio")
|
||||
meeting2.start_recording()
|
||||
|
||||
# Create directory for meeting1 only
|
||||
meeting1_path = meetings_dir / str(meeting1.id)
|
||||
meeting1_path.mkdir(parents=True)
|
||||
(meeting1_path / "manifest.json").touch()
|
||||
(meeting1_path / "audio.enc").touch()
|
||||
|
||||
mock_uow.meetings.list_all = AsyncMock(return_value=([meeting1, meeting2], 2))
|
||||
mock_uow.meetings.update = AsyncMock(side_effect=[meeting1, meeting2])
|
||||
|
||||
service = RecoveryService(mock_uow, meetings_dir=meetings_dir)
|
||||
meetings, audio_failures = await service.recover_crashed_meetings()
|
||||
|
||||
assert len(meetings) == 2
|
||||
assert audio_failures == 1
|
||||
assert meetings[0].metadata["audio_valid"] == "true"
|
||||
assert meetings[1].metadata["audio_valid"] == "false"
|
||||
assert "audio_error" in meetings[1].metadata
|
||||
|
||||
|
||||
class TestAudioValidationResult:
|
||||
"""Tests for AudioValidationResult dataclass."""
|
||||
|
||||
def test_audio_validation_result_is_frozen(self) -> None:
|
||||
"""Test AudioValidationResult is immutable."""
|
||||
result = AudioValidationResult(
|
||||
is_valid=True,
|
||||
manifest_exists=True,
|
||||
audio_exists=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
result.is_valid = False # type: ignore[misc]
|
||||
|
||||
def test_audio_validation_result_optional_error(self) -> None:
|
||||
"""Test error_message defaults to None."""
|
||||
result = AudioValidationResult(
|
||||
is_valid=True,
|
||||
manifest_exists=True,
|
||||
audio_exists=True,
|
||||
)
|
||||
|
||||
assert result.error_message is None
|
||||
|
||||
def test_audio_validation_result_with_error(self) -> None:
|
||||
"""Test AudioValidationResult stores error message."""
|
||||
result = AudioValidationResult(
|
||||
is_valid=False,
|
||||
manifest_exists=False,
|
||||
audio_exists=False,
|
||||
error_message="Test error",
|
||||
)
|
||||
|
||||
assert result.error_message == "Test error"
|
||||
|
||||
@@ -58,15 +58,6 @@ class TestRetentionServiceProperties:
class TestRetentionServiceFindExpired:
    """Tests for find_expired_meetings method."""

    @pytest.fixture
    def mock_uow(self) -> MagicMock:
        """Create mock UnitOfWork."""
        uow = MagicMock()
        uow.__aenter__ = AsyncMock(return_value=uow)
        uow.__aexit__ = AsyncMock(return_value=None)
        uow.meetings = MagicMock()
        return uow

    @pytest.mark.asyncio
    async def test_find_expired_returns_meetings(self, mock_uow: MagicMock) -> None:
        """find_expired_meetings should return meetings from repository."""
@@ -93,16 +84,6 @@ class TestRetentionServiceFindExpired:
class TestRetentionServiceRunCleanup:
    """Tests for run_cleanup method."""

    @pytest.fixture
    def mock_uow(self) -> MagicMock:
        """Create mock UnitOfWork."""
        uow = MagicMock()
        uow.__aenter__ = AsyncMock(return_value=uow)
        uow.__aexit__ = AsyncMock(return_value=None)
        uow.meetings = MagicMock()
        uow.commit = AsyncMock()
        return uow

    @pytest.mark.asyncio
    async def test_run_cleanup_disabled_returns_empty_report(self, mock_uow: MagicMock) -> None:
        """run_cleanup should return empty report when disabled."""
@@ -144,7 +144,8 @@ class TestSummarizationServiceConfiguration:

        assert SummarizationMode.LOCAL not in available

    def test_cloud_requires_consent(self) -> None:
    @pytest.mark.asyncio
    async def test_cloud_requires_consent(self) -> None:
        """Cloud mode should require consent to be available."""
        service = SummarizationService()
        service.register_provider(
@@ -153,22 +154,23 @@ class TestSummarizationServiceConfiguration:
        )

        available_without_consent = service.get_available_modes()
        service.grant_cloud_consent()
        await service.grant_cloud_consent()
        available_with_consent = service.get_available_modes()

        assert SummarizationMode.CLOUD not in available_without_consent
        assert SummarizationMode.CLOUD in available_with_consent

    def test_revoke_cloud_consent(self) -> None:
    @pytest.mark.asyncio
    async def test_revoke_cloud_consent(self) -> None:
        """Revoking consent should remove cloud from available modes."""
        service = SummarizationService()
        service.register_provider(
            SummarizationMode.CLOUD,
            MockProvider(name="cloud", requires_consent=True),
        )
        service.grant_cloud_consent()
        await service.grant_cloud_consent()

        service.revoke_cloud_consent()
        await service.revoke_cloud_consent()
        available = service.get_available_modes()

        assert SummarizationMode.CLOUD not in available
@@ -10,6 +10,7 @@ from __future__ import annotations

import sys
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

@@ -94,3 +95,24 @@ def mock_optional_extras() -> None:
    pywinctl_module.getAllWindows = lambda: []
    pywinctl_module.getAllTitles = lambda: []
    sys.modules["pywinctl"] = pywinctl_module


@pytest.fixture
def mock_uow() -> MagicMock:
    """Create a mock UnitOfWork for service tests.

    Provides a fully-configured mock UnitOfWork with all repository mocks
    and async context manager support.
    """
    uow = MagicMock()
    uow.__aenter__ = AsyncMock(return_value=uow)
    uow.__aexit__ = AsyncMock(return_value=None)
    uow.commit = AsyncMock()
    uow.rollback = AsyncMock()
    uow.meetings = MagicMock()
    uow.segments = MagicMock()
    uow.summaries = MagicMock()
    uow.annotations = MagicMock()
    uow.preferences = MagicMock()
    uow.diarization_jobs = MagicMock()
    return uow
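The fixture above leans on `MagicMock`'s support for configuring magic methods on instances, which is what lets a plain mock participate in `async with`. A compact sketch of the same pattern in use:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

# Configure async context manager dunders directly on the mock instance;
# MagicMock special-cases supported magic methods so this works at lookup time.
uow = MagicMock()
uow.__aenter__ = AsyncMock(return_value=uow)
uow.__aexit__ = AsyncMock(return_value=None)
uow.commit = AsyncMock()


async def save_something(uow) -> None:
    # Any service code written against the UnitOfWork protocol runs unchanged
    async with uow as active:
        await active.commit()


asyncio.run(save_something(uow))
uow.commit.assert_awaited_once()
```

Because `__aenter__` returns the mock itself, attribute access inside the `async with` block (e.g. `active.meetings`) keeps hitting the same configured mock.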
15 tests/fixtures/__init__.py vendored Normal file
@@ -0,0 +1,15 @@
"""Shared test fixtures and utilities."""

from support.db_utils import (
    PgTestContainer,
    create_test_session_factory,
    get_or_create_container,
    initialize_test_schema,
)

__all__ = [
    "PgTestContainer",
    "create_test_session_factory",
    "get_or_create_container",
    "initialize_test_schema",
]
@@ -59,7 +59,9 @@ async def test_generate_summary_falls_back_when_provider_unavailable() -> None:
    meeting = store.create("Test Meeting")
    store.add_segment(
        str(meeting.id),
        Segment(segment_id=1, text="Action item noted", start_time=0.0, end_time=2.0, language="en"),
        Segment(
            segment_id=1, text="Action item noted", start_time=0.0, end_time=2.0, language="en"
        ),
    )

    response = await servicer.GenerateSummary(
@@ -80,7 +80,8 @@ class TestMeetingAudioWriterBasics:
        meetings_dir: Path,
    ) -> None:
        """Test audio conversion from float32 to PCM16."""
        writer = MeetingAudioWriter(crypto, meetings_dir)
        # Use small buffer to force immediate flush
        writer = MeetingAudioWriter(crypto, meetings_dir, buffer_size=1000)
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)
@@ -91,10 +92,12 @@ class TestMeetingAudioWriterBasics:
        test_audio = np.linspace(-1.0, 1.0, 1600, dtype=np.float32)
        writer.write_chunk(test_audio)

        # Audio is 3200 bytes, buffer is 1000, so should flush
        assert writer.bytes_written > 0
        # PCM16 = 2 bytes/sample = 3200 bytes raw, but encrypted with overhead
        assert writer.bytes_written > 3200
        assert writer.chunk_count == 1
        assert writer.write_count == 1

        writer.close()
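The float32-to-PCM16 conversion and the quantization tolerance the tests assert can be sketched standalone (helper names here are illustrative, not the writer's actual internals):

```python
import numpy as np


def float32_to_pcm16(audio: np.ndarray) -> bytes:
    # Clamp to [-1.0, 1.0], then scale into the int16 range
    clamped = np.clip(audio, -1.0, 1.0)
    return (clamped * 32767.0).astype(np.int16).tobytes()


def pcm16_to_float32(data: bytes) -> np.ndarray:
    return np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32767.0


audio = np.linspace(-1.0, 1.0, 1600, dtype=np.float32)
encoded = float32_to_pcm16(audio)

assert len(encoded) == 3200  # 2 bytes per sample, matching the test's arithmetic
round_trip = pcm16_to_float32(encoded)
# Max quantization error is 1/32767 ~= 0.00003, inside the tests' atol=0.0001
assert np.allclose(audio, round_trip, atol=0.0001)
```

The clamping step is also why `test_write_chunk_clamps_audio_range` can feed `-2.0` and `2.0` and still read back values inside `[-1.0, 1.0]`.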
@@ -103,25 +106,89 @@ class TestMeetingAudioWriterBasics:
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
    ) -> None:
        """Test writing multiple audio chunks."""
        writer = MeetingAudioWriter(crypto, meetings_dir)
        """Test writing multiple audio chunks with buffering."""
        # Use small buffer to test buffering behavior
        writer = MeetingAudioWriter(crypto, meetings_dir, buffer_size=10000)
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer.open(meeting_id, dek, wrapped_dek)

        # Write 100 chunks
        for _ in range(100):
        # Write 100 chunks of 1600 samples each (3200 bytes per write)
        # Buffer is 10000, so ~3 writes per encrypted chunk
        num_writes = 100
        bytes_per_write = 1600 * 2  # 3200 bytes

        for _ in range(num_writes):
            audio = np.random.uniform(-0.5, 0.5, 1600).astype(np.float32)
            writer.write_chunk(audio)

        # Should have written significant data
        assert writer.bytes_written > 100 * 3200  # At least raw PCM16 size
        assert writer.chunk_count == 100
        # write_count tracks incoming audio frames
        assert writer.write_count == num_writes

        # Due to buffering, chunk_count should be much less than write_count
        # 100 writes * 3200 bytes = 320,000 bytes / 10000 buffer = ~32 flushes
        # Some bytes may still be buffered
        assert writer.chunk_count < num_writes

        # Flush remaining and check bytes written before close
        writer.flush()

        # Total raw bytes = 100 * 3200 = 320,000 bytes
        # Encrypted size includes overhead per chunk
        assert writer.bytes_written > num_writes * bytes_per_write

        writer.close()

    def test_buffering_reduces_chunk_overhead(
        self,
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
    ) -> None:
        """Test that buffering reduces encryption overhead."""
        # Create two writers with different buffer sizes
        small_buffer_writer = MeetingAudioWriter(crypto, meetings_dir, buffer_size=1000)
        large_buffer_writer = MeetingAudioWriter(crypto, meetings_dir, buffer_size=1_000_000)

        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        # Write same audio to both
        meeting_id_small = str(uuid4())
        meeting_id_large = str(uuid4())

        small_buffer_writer.open(meeting_id_small, dek, wrapped_dek)
        large_buffer_writer.open(meeting_id_large, dek, wrapped_dek)

        # Generate consistent test data
        np.random.seed(42)

        # Write 50 chunks (160,000 bytes raw)
        for _ in range(50):
            audio = np.random.uniform(-0.5, 0.5, 1600).astype(np.float32)
            small_buffer_writer.write_chunk(audio)

        np.random.seed(42)  # Reset seed to generate same audio

        for _ in range(50):
            audio = np.random.uniform(-0.5, 0.5, 1600).astype(np.float32)
            large_buffer_writer.write_chunk(audio)

        # Flush to ensure all data is written before comparing
        small_buffer_writer.flush()
        large_buffer_writer.flush()

        # Large buffer should have fewer encrypted chunks (less overhead)
        assert large_buffer_writer.chunk_count < small_buffer_writer.chunk_count

        # Large buffer should use less total disk space due to fewer chunks
        # Each chunk has 32 bytes overhead (4 length + 12 nonce + 16 tag)
        assert large_buffer_writer.bytes_written < small_buffer_writer.bytes_written

        small_buffer_writer.close()
        large_buffer_writer.close()

    def test_write_chunk_clamps_audio_range(
        self,
        crypto: AesGcmCryptoBox,
@@ -135,7 +202,7 @@ class TestMeetingAudioWriterBasics:

        writer.open(meeting_id, dek, wrapped_dek)
        writer.write_chunk(np.array([-2.0, 0.0, 2.0], dtype=np.float32))
        writer.close()
        writer.close()  # Flushes buffer to disk

        audio_path = meetings_dir / meeting_id / "audio.enc"
        reader = ChunkedAssetReader(crypto)
@@ -150,6 +217,37 @@ class TestMeetingAudioWriterBasics:

        reader.close()

    def test_flush_writes_buffered_data(
        self,
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
    ) -> None:
        """Test explicit flush writes buffered data to disk."""
        # Large buffer to prevent auto-flush
        writer = MeetingAudioWriter(crypto, meetings_dir, buffer_size=1_000_000)
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer.open(meeting_id, dek, wrapped_dek)

        # Write small audio chunk (won't trigger auto-flush)
        writer.write_chunk(np.zeros(1600, dtype=np.float32))

        # Data should be buffered, not written
        assert writer.buffered_bytes > 0
        assert writer.chunk_count == 0

        # Explicit flush
        writer.flush()

        # Data should now be written
        assert writer.buffered_bytes == 0
        assert writer.chunk_count == 1
        assert writer.bytes_written > 0

        writer.close()


class TestMeetingAudioWriterErrors:
    """Tests for MeetingAudioWriter error handling."""
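The overhead arithmetic the buffering tests above rely on (32 bytes per encrypted chunk: 4-byte length prefix, 12-byte GCM nonce, 16-byte auth tag) can be checked with an idealized model. This assumes each full or final partial buffer becomes exactly one encrypted chunk, which the real writer may not do on every code path:

```python
import math

CHUNK_OVERHEAD = 4 + 12 + 16  # length prefix + AES-GCM nonce + auth tag = 32 bytes


def encrypted_size(raw_bytes: int, buffer_size: int) -> int:
    # One encrypted chunk per full (or final partial) buffer
    chunks = math.ceil(raw_bytes / buffer_size)
    return raw_bytes + chunks * CHUNK_OVERHEAD


raw = 50 * 1600 * 2  # 50 writes of 1600 PCM16 samples = 160,000 bytes

small = encrypted_size(raw, 1000)        # 160 chunks -> 160,000 + 5,120 bytes
large = encrypted_size(raw, 1_000_000)   # 1 chunk    -> 160,000 + 32 bytes

assert small == 160_000 + 160 * 32
assert large == 160_000 + 32
assert large < small  # the inequality the buffering test asserts
```

Under this model a tiny buffer pays roughly 3% overhead on this workload, while a large buffer pays effectively none, which is why `test_buffering_reduces_chunk_overhead` compares both `chunk_count` and `bytes_written`.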
@@ -258,7 +356,7 @@ class TestMeetingAudioWriterIntegration:
        meetings_dir: Path,
    ) -> None:
        """Test writing audio, then reading it back encrypted."""
        # Write audio
        # Write audio (default buffer aggregates chunks)
        writer = MeetingAudioWriter(crypto, meetings_dir)
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
@@ -282,22 +380,21 @@ class TestMeetingAudioWriterIntegration:
        reader = ChunkedAssetReader(crypto)
        reader.open(audio_path, dek)

        read_chunks: list[np.ndarray] = []
        for chunk_bytes in reader.read_chunks():
            # Convert bytes back to PCM16 then to float32
            pcm16 = np.frombuffer(chunk_bytes, dtype=np.int16)
            audio_float = pcm16.astype(np.float32) / 32767.0
            read_chunks.append(audio_float)

        # Collect all decrypted audio bytes (may be fewer chunks due to buffering)
        all_audio_bytes = b"".join(reader.read_chunks())
        reader.close()

        # Verify we read same number of chunks
        assert len(read_chunks) == len(original_chunks)
        # Convert bytes back to float32
        pcm16 = np.frombuffer(all_audio_bytes, dtype=np.int16)
        read_audio = pcm16.astype(np.float32) / 32767.0

        # Concatenate original chunks for comparison
        original_audio = np.concatenate(original_chunks)

        # Verify audio content matches (within quantization error)
        for orig, read in zip(original_chunks, read_chunks, strict=True):
            # PCM16 quantization adds ~0.00003 max error
            assert np.allclose(orig, read, atol=0.0001)
        assert len(read_audio) == len(original_audio)
        # PCM16 quantization adds ~0.00003 max error
        assert np.allclose(original_audio, read_audio, atol=0.0001)

    def test_manifest_wrapped_dek_can_decrypt_audio(
        self,
@@ -332,3 +429,121 @@ class TestMeetingAudioWriterIntegration:
        assert len(chunks) == 1  # Should read the one chunk we wrote

        reader.close()


class TestMeetingAudioWriterPeriodicFlush:
    """Tests for periodic flush thread functionality."""

    def test_periodic_flush_thread_starts_on_open(
        self,
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
    ) -> None:
        """Test periodic flush thread is started when writer is opened."""
        writer = MeetingAudioWriter(crypto, meetings_dir)
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        assert writer._flush_thread is None

        writer.open(meeting_id, dek, wrapped_dek)
        assert writer._flush_thread is not None
        assert writer._flush_thread.is_alive()

        writer.close()
        assert writer._flush_thread is None or not writer._flush_thread.is_alive()

    def test_periodic_flush_thread_stops_on_close(
        self,
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
    ) -> None:
        """Test periodic flush thread stops cleanly on close."""
        writer = MeetingAudioWriter(crypto, meetings_dir)
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer.open(meeting_id, dek, wrapped_dek)
        flush_thread = writer._flush_thread
        assert flush_thread is not None

        writer.close()

        # Thread should be stopped
        assert not flush_thread.is_alive()
        assert writer._stop_flush.is_set()

    def test_flush_is_thread_safe(
        self,
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
    ) -> None:
        """Test concurrent writes and flushes do not corrupt data."""
        import threading

        writer = MeetingAudioWriter(crypto, meetings_dir, buffer_size=1_000_000)
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer.open(meeting_id, dek, wrapped_dek)

        errors: list[Exception] = []
        write_count = 100

        def write_audio() -> None:
            try:
                for _ in range(write_count):
                    audio = np.random.uniform(-0.5, 0.5, 1600).astype(np.float32)
                    writer.write_chunk(audio)
            except Exception as e:
                errors.append(e)

        def flush_repeatedly() -> None:
            try:
                for _ in range(50):
                    writer.flush()
            except Exception as e:
                errors.append(e)

        write_thread = threading.Thread(target=write_audio)
        flush_thread = threading.Thread(target=flush_repeatedly)

        write_thread.start()
        flush_thread.start()

        write_thread.join()
        flush_thread.join()

        # Check write_count before close (close resets it)
        assert writer.write_count == write_count

        writer.close()

        # No exceptions should have occurred
        assert not errors

    def test_flush_when_closed_raises_error(
        self,
        crypto: AesGcmCryptoBox,
        meetings_dir: Path,
    ) -> None:
        """Test flush raises RuntimeError when writer is closed."""
        writer = MeetingAudioWriter(crypto, meetings_dir)

        # Should raise when not open
        with pytest.raises(RuntimeError, match="not open"):
            writer.flush()

        # Open, then close, then flush should also raise
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer.open(meeting_id, dek, wrapped_dek)
        writer.close()

        with pytest.raises(RuntimeError, match="not open"):
            writer.flush()
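The periodic-flush lifecycle these tests exercise (thread starts on open, stops on close via an event) follows a standard pattern. A minimal standalone sketch, with a hypothetical `PeriodicFlusher` name rather than the writer's actual internals:

```python
import threading
import time


class PeriodicFlusher:
    """Background thread that calls flush() on an interval until stopped."""

    def __init__(self, flush, interval: float = 0.05) -> None:
        self._flush = flush
        self._interval = interval
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        # Event.wait doubles as an interruptible sleep: it returns True
        # immediately when stop() sets the event, ending the loop
        while not self._stop.wait(self._interval):
            self._flush()

    def stop(self) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join()


calls: list[int] = []
flusher = PeriodicFlusher(lambda: calls.append(1), interval=0.01)
flusher.start()
time.sleep(0.1)
flusher.stop()

assert calls  # flushed at least once
assert not flusher._thread.is_alive()  # clean shutdown, as the close test asserts
```

Using `Event.wait(interval)` instead of `time.sleep(interval)` is what makes `close()` fast: the thread wakes immediately when the stop event is set instead of finishing its sleep.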
@@ -54,8 +54,11 @@ def test_get_or_create_master_key_creates_and_reuses(monkeypatch: pytest.MonkeyP
    assert ("svc", "key") in storage


def test_get_or_create_master_key_wraps_keyring_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """Keyring errors should surface as RuntimeError."""
def test_get_or_create_master_key_falls_back_to_file(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Any,
) -> None:
    """Keyring errors should fall back to file-based key storage."""

    class DummyErrors:
        class KeyringError(Exception): ...
@@ -73,10 +76,15 @@ def test_get_or_create_master_key_wraps_keyring_errors(monkeypatch: pytest.Monke
            delete_password=raise_error,
        ),
    )
    # Use temp path for file fallback
    key_file = tmp_path / ".master_key"
    monkeypatch.setattr(keystore, "DEFAULT_KEY_FILE", key_file)

    ks = keystore.KeyringKeyStore()
    with pytest.raises(RuntimeError, match="Keyring unavailable"):
        ks.get_or_create_master_key()
    key = ks.get_or_create_master_key()

    assert len(key) == keystore.KEY_SIZE
    assert key_file.exists()


def test_delete_master_key_handles_missing(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -130,3 +138,81 @@ def test_has_master_key_false_on_errors(monkeypatch: pytest.MonkeyPatch) -> None

    ks = keystore.KeyringKeyStore()
    assert ks.has_master_key() is False


class TestFileKeyStore:
    """Tests for FileKeyStore fallback implementation."""

    def test_creates_and_reuses_key(self, tmp_path: Any) -> None:
        """File key store should create key once and reuse it."""
        key_file = tmp_path / ".master_key"
        fks = keystore.FileKeyStore(key_file)

        first = fks.get_or_create_master_key()
        second = fks.get_or_create_master_key()

        assert len(first) == keystore.KEY_SIZE
        assert first == second
        assert key_file.exists()

    def test_creates_parent_directories(self, tmp_path: Any) -> None:
        """File key store should create parent directories."""
        key_file = tmp_path / "nested" / "dir" / ".master_key"
        fks = keystore.FileKeyStore(key_file)

        fks.get_or_create_master_key()

        assert key_file.exists()

    def test_has_master_key_true_when_exists(self, tmp_path: Any) -> None:
        """has_master_key should return True when file exists."""
        key_file = tmp_path / ".master_key"
        fks = keystore.FileKeyStore(key_file)
        fks.get_or_create_master_key()

        assert fks.has_master_key() is True

    def test_has_master_key_false_when_missing(self, tmp_path: Any) -> None:
        """has_master_key should return False when file is missing."""
        key_file = tmp_path / ".master_key"
        fks = keystore.FileKeyStore(key_file)

        assert fks.has_master_key() is False

    def test_delete_master_key_removes_file(self, tmp_path: Any) -> None:
        """delete_master_key should remove the key file."""
        key_file = tmp_path / ".master_key"
        fks = keystore.FileKeyStore(key_file)
        fks.get_or_create_master_key()

        fks.delete_master_key()

        assert not key_file.exists()

    def test_delete_master_key_safe_when_missing(self, tmp_path: Any) -> None:
        """delete_master_key should not raise when file is missing."""
        key_file = tmp_path / ".master_key"
        fks = keystore.FileKeyStore(key_file)

        fks.delete_master_key()  # Should not raise

    def test_invalid_base64_raises_runtime_error(self, tmp_path: Any) -> None:
        """Invalid base64 in key file should raise RuntimeError."""
        key_file = tmp_path / ".master_key"
        key_file.write_text("not-valid-base64!!!")
        fks = keystore.FileKeyStore(key_file)

        with pytest.raises(RuntimeError, match="invalid base64"):
            fks.get_or_create_master_key()

    def test_wrong_size_raises_runtime_error(self, tmp_path: Any) -> None:
        """Wrong key size in file should raise RuntimeError."""
        import base64

        key_file = tmp_path / ".master_key"
        # Write a key that's too short (9 bytes instead of 32)
        key_file.write_text(base64.b64encode(b"short_key").decode())
        fks = keystore.FileKeyStore(key_file)

        with pytest.raises(RuntimeError, match="wrong key size"):
            fks.get_or_create_master_key()
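The behavior the `TestFileKeyStore` suite pins down (create-once/reuse, parent-directory creation, base64 and size validation) can be sketched as follows. This is an illustrative reimplementation consistent with the tests, not the project's actual `FileKeyStore`:

```python
import base64
import binascii
import secrets
from pathlib import Path

KEY_SIZE = 32  # 256-bit master key, matching keystore.KEY_SIZE in the tests


class FileKeyStore:
    """File-backed fallback key store (illustrative sketch)."""

    def __init__(self, key_file: Path) -> None:
        self._key_file = key_file

    def get_or_create_master_key(self) -> bytes:
        if self._key_file.exists():
            try:
                key = base64.b64decode(self._key_file.read_text(), validate=True)
            except (binascii.Error, ValueError) as exc:
                raise RuntimeError("key file contains invalid base64") from exc
            if len(key) != KEY_SIZE:
                raise RuntimeError("key file has wrong key size")
            return key
        key = secrets.token_bytes(KEY_SIZE)
        self._key_file.parent.mkdir(parents=True, exist_ok=True)
        self._key_file.write_text(base64.b64encode(key).decode())
        return key

    def has_master_key(self) -> bool:
        return self._key_file.exists()

    def delete_master_key(self) -> None:
        self._key_file.unlink(missing_ok=True)
```

Note the validation order: decode errors are reported before size errors, which matches the two distinct `RuntimeError` messages the tests match on.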
@@ -2,112 +2,19 @@

from __future__ import annotations

import time
from collections.abc import AsyncGenerator
from importlib import import_module
from typing import TYPE_CHECKING
from urllib.parse import quote

import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

if TYPE_CHECKING:
    from collections.abc import Self

    from noteflow.infrastructure.persistence.models import Base


# Store container reference at module level to reuse
class PgTestContainer:
    """Minimal Postgres testcontainer wrapper with custom readiness wait."""

    def __init__(
        self,
        image: str = "pgvector/pgvector:pg16",
        username: str = "test",
        password: str = "test",
        dbname: str = "noteflow_test",
        port: int = 5432,
    ) -> None:
        self.username = username
        self.password = password
        self.dbname = dbname
        self.port = port

        container_module = import_module("testcontainers.core.container")
        docker_container_cls = container_module.DockerContainer
        self._container = (
            docker_container_cls(image)
            .with_env("POSTGRES_USER", username)
            .with_env("POSTGRES_PASSWORD", password)
            .with_env("POSTGRES_DB", dbname)
            .with_exposed_ports(port)
        )

    def start(self) -> Self:
        """Start the container."""
        self._container.start()
        self._wait_until_ready()
        return self

    def stop(self) -> None:
        """Stop the container."""
        self._container.stop()

    def get_connection_url(self) -> str:
        """Return a SQLAlchemy-style connection URL."""
        host = self._container.get_container_host_ip()
        port = self._container._get_exposed_port(self.port)
        quoted_password = quote(self.password, safe=" +")
        return f"postgresql+psycopg2://{self.username}:{quoted_password}@{host}:{port}/{self.dbname}"

    def _wait_until_ready(self, timeout: float = 30.0, interval: float = 0.5) -> None:
        """Wait for Postgres to accept connections by running a simple query."""
        start_time = time.time()
        escaped_password = self.password.replace("'", "'\"'\"'")
        cmd = [
            "sh",
            "-c",
            (
                f"PGPASSWORD='{escaped_password}' "
                f"psql --username {self.username} --dbname {self.dbname} --host 127.0.0.1 "
                "-c 'select 1;'"
            ),
        ]
        last_error: str | None = None

        while True:
            result = self._container.exec(cmd)
            if result.exit_code == 0:
                return
            if result.output:
                last_error = result.output.decode(errors="ignore")
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    "Postgres container did not become ready in time"
                    + (f": {last_error}" if last_error else "")
                )
            time.sleep(interval)


_container: PgTestContainer | None = None
_database_url: str | None = None


def get_or_create_container() -> tuple[PgTestContainer, str]:
    """Get or create the PostgreSQL container."""
    global _container, _database_url

    if _container is None:
        container = PgTestContainer().start()
        _container = container
        url = container.get_connection_url()
        _database_url = url.replace("postgresql+psycopg2://", "postgresql+asyncpg://")

    assert _container is not None, "Container should be initialized"
    assert _database_url is not None, "Database URL should be initialized"
    return _container, _database_url
from support.db_utils import (
    cleanup_test_schema,
    create_test_engine,
    create_test_session_factory,
    get_or_create_container,
    initialize_test_schema,
    stop_container,
)


@pytest.fixture
@@ -115,26 +22,16 @@ async def session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession],
    """Create a session factory and initialize the database schema."""
    _, database_url = get_or_create_container()

    engine = create_async_engine(database_url, echo=False)
    engine = create_test_engine(database_url)

    async with engine.begin() as conn:
        # Create pgvector extension and schema
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
        await conn.execute(text("CREATE SCHEMA noteflow"))
        # Create all tables
        await conn.run_sync(Base.metadata.create_all)
        await initialize_test_schema(conn)

    yield create_test_session_factory(engine)

    yield async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
    )
    # Cleanup - drop schema to reset for next test
    async with engine.begin() as conn:
        await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
        await cleanup_test_schema(conn)

    await engine.dispose()

@@ -152,7 +49,4 @@ async def session(

def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
    """Cleanup container after all tests complete."""
    global _container
    if _container is not None:
        _container.stop()
        _container = None
    stop_container()
@@ -2,12 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import quote
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
@@ -16,10 +13,17 @@ import pytest
 from noteflow.grpc.service import NoteFlowServicer
 from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
 from noteflow.infrastructure.security.keystore import InMemoryKeyStore
+from support.db_utils import (
+    cleanup_test_schema,
+    create_test_engine,
+    create_test_session_factory,
+    get_or_create_container,
+    initialize_test_schema,
+    stop_container,
+)

 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
-    from collections.abc import Self

     from numpy.typing import NDArray
     from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -38,98 +42,6 @@ class MockAsrResult:
     no_speech_prob: float = 0.01


-# Store container reference at module level to reuse in stress tests
-class PgTestContainer:
-    """Minimal Postgres testcontainer wrapper with custom readiness wait."""
-
-    def __init__(
-        self,
-        image: str = "pgvector/pgvector:pg16",
-        username: str = "test",
-        password: str = "test",
-        dbname: str = "noteflow_test",
-        port: int = 5432,
-    ) -> None:
-        self.username = username
-        self.password = password
-        self.dbname = dbname
-        self.port = port
-
-        container_module = import_module("testcontainers.core.container")
-        docker_container_cls = container_module.DockerContainer
-        self._container = (
-            docker_container_cls(image)
-            .with_env("POSTGRES_USER", username)
-            .with_env("POSTGRES_PASSWORD", password)
-            .with_env("POSTGRES_DB", dbname)
-            .with_exposed_ports(port)
-        )
-
-    def start(self) -> Self:
-        """Start the container."""
-        self._container.start()
-        self._wait_until_ready()
-        return self
-
-    def stop(self) -> None:
-        """Stop the container."""
-        self._container.stop()
-
-    def get_connection_url(self) -> str:
-        """Return a SQLAlchemy-style connection URL."""
-        host = self._container.get_container_host_ip()
-        port = self._container._get_exposed_port(self.port)
-        quoted_password = quote(self.password, safe=" +")
-        return f"postgresql+psycopg2://{self.username}:{quoted_password}@{host}:{port}/{self.dbname}"
-
-    def _wait_until_ready(self, timeout: float = 30.0, interval: float = 0.5) -> None:
-        """Wait for Postgres to accept connections by running a simple query."""
-        start_time = time.time()
-        escaped_password = self.password.replace("'", "'\"'\"'")
-        cmd = [
-            "sh",
-            "-c",
-            (
-                f"PGPASSWORD='{escaped_password}' "
-                f"psql --username {self.username} --dbname {self.dbname} --host 127.0.0.1 "
-                "-c 'select 1;'"
-            ),
-        ]
-        last_error: str | None = None
-
-        while True:
-            result = self._container.exec(cmd)
-            if result.exit_code == 0:
-                return
-            if result.output:
-                last_error = result.output.decode(errors="ignore")
-            if time.time() - start_time > timeout:
-                raise TimeoutError(
-                    "Postgres container did not become ready in time"
-                    + (f": {last_error}" if last_error else "")
-                )
-            time.sleep(interval)
-
-
-_container: PgTestContainer | None = None
-_database_url: str | None = None
-
-
-def get_or_create_container() -> tuple[PgTestContainer, str]:
-    """Get or create the PostgreSQL container for stress tests."""
-    global _container, _database_url
-
-    if _container is None:
-        container = PgTestContainer().start()
-        _container = container
-        url = container.get_connection_url()
-        _database_url = url.replace("postgresql+psycopg2://", "postgresql+asyncpg://")
-
-    assert _container is not None, "Container should be initialized"
-    assert _database_url is not None, "Database URL should be initialized"
-    return _container, _database_url
-
-
 def create_mock_asr_engine(transcribe_results: list[str] | None = None) -> MagicMock:
     """Create mock ASR engine with configurable transcription results.
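The removed `_wait_until_ready` is an instance of a generic poll-until-ready loop, which the extracted helper presumably preserves. A minimal standalone sketch of the pattern (names hypothetical, with a fake probe in place of the container's `psql -c 'select 1;'` exec):

```python
import time


def wait_until(probe, timeout: float = 30.0, interval: float = 0.5) -> None:
    """Poll `probe` until it returns True, or raise after `timeout` seconds."""
    start = time.monotonic()
    while not probe():
        if time.monotonic() - start > timeout:
            raise TimeoutError("condition not met in time")
        time.sleep(interval)


# A probe that succeeds on its third call, standing in for the
# readiness query against the container.
calls = {"count": 0}


def ready() -> bool:
    calls["count"] += 1
    return calls["count"] >= 3


wait_until(ready, timeout=5.0, interval=0.01)
```

Using `time.monotonic()` rather than `time.time()` makes the deadline immune to wall-clock adjustments, a small improvement over the removed code.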
@@ -199,47 +111,27 @@ def memory_servicer(mock_asr_engine: MagicMock, tmp_path: Path) -> NoteFlowServi
     )


 # Import session_factory from integration tests for PostgreSQL backend
 # This is lazily imported to avoid requiring testcontainers for non-integration tests
 @pytest.fixture
 async def postgres_session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
     """Create PostgreSQL session factory using testcontainers.

     Uses a local container helper to avoid importing test modules.
     """
-    # Import here to avoid requiring testcontainers for all stress tests
-    from sqlalchemy import text
-    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
-
-    from noteflow.infrastructure.persistence.models import Base
-
     _, database_url = get_or_create_container()

-    engine = create_async_engine(database_url, echo=False)
+    engine = create_test_engine(database_url)

     async with engine.begin() as conn:
-        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
-        await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
-        await conn.execute(text("CREATE SCHEMA noteflow"))
-        await conn.run_sync(Base.metadata.create_all)
+        await initialize_test_schema(conn)

-    yield async_sessionmaker(
-        engine,
-        class_=AsyncSession,
-        expire_on_commit=False,
-        autocommit=False,
-        autoflush=False,
-    )
+    yield create_test_session_factory(engine)

     async with engine.begin() as conn:
-        await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
+        await cleanup_test_schema(conn)

     await engine.dispose()


 def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
     """Cleanup container after stress tests complete."""
-    global _container
-    if _container is not None:
-        _container.stop()
-        _container = None
+    stop_container()
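Both versions of the fixture keep pytest's yield-fixture shape: everything before `yield` is setup, everything after is teardown. The lifecycle can be sketched with a plain generator (no pytest required; the event names are illustrative):

```python
events: list[str] = []


def fixture_like():
    events.append("setup")       # engine + schema creation would happen here
    yield "session_factory"      # value handed to the test body
    events.append("teardown")    # schema cleanup + engine.dispose()


gen = fixture_like()
value = next(gen)                # runs setup, receives the yielded value
events.append(f"test uses {value}")
next(gen, None)                  # resumes past the yield to run teardown
```

The refactor only shrinks the setup and teardown halves to helper calls; the ordering guarantee is unchanged.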
@@ -409,7 +409,11 @@ class TestWriterReaderRoundTrip:

     @pytest.mark.stress
     def test_multiple_chunks_roundtrip(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
-        """Multiple chunk write and read preserves data."""
+        """Multiple chunk write and read preserves data.
+
+        Note: Due to buffering, the number of encrypted chunks may differ from
+        the number of writes. This test verifies content integrity, not chunk count.
+        """
         meeting_id = str(uuid4())
         dek = crypto.generate_dek()
         wrapped_dek = crypto.wrap_dek(dek)
@@ -425,14 +429,22 @@ class TestWriterReaderRoundTrip:
         reader = MeetingAudioReader(crypto, meetings_dir)
         loaded_chunks = reader.load_meeting_audio(meeting_id)

-        assert len(loaded_chunks) == len(original_chunks)
-        for original, loaded in zip(original_chunks, loaded_chunks, strict=True):
-            np.testing.assert_array_almost_equal(loaded.frames, original, decimal=4)
+        # Concatenate original and loaded audio for comparison
+        # Buffering may merge chunks, so we compare total content
+        original_audio = np.concatenate(original_chunks)
+        loaded_audio = np.concatenate([chunk.frames for chunk in loaded_chunks])
+
+        assert len(loaded_audio) == len(original_audio)
+        np.testing.assert_array_almost_equal(loaded_audio, original_audio, decimal=4)

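The switch from a per-chunk `zip` comparison to a concatenated comparison is what makes the test robust to re-chunking. A minimal illustration of why, with pure NumPy and no encryption involved (the 4-sample/6-sample split is an invented example of buffering):

```python
import numpy as np

# Three writes of 4 samples each...
written = [np.arange(i * 4, (i + 1) * 4, dtype=np.float32) for i in range(3)]
# ...which a buffered writer happens to flush as two chunks of 6.
flat = np.concatenate(written)
stored = [flat[:6], flat[6:]]

# A per-chunk zip comparison cannot even start: the chunk counts differ.
assert len(stored) != len(written)

# Comparing the concatenated content still proves nothing was lost or reordered.
np.testing.assert_array_equal(np.concatenate(stored), flat)
```

The test keeps `assert_array_almost_equal` with `decimal=4` because the float32 round trip through encryption and decoding may not be bit-exact.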
     @pytest.mark.stress
     @pytest.mark.slow
     def test_large_audio_roundtrip(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
-        """Large audio file (1000 chunks) write and read succeeds."""
+        """Large audio file (1000 chunks) write and read succeeds.
+
+        Note: Due to buffering, the number of encrypted chunks may differ from
+        the number of writes. This test verifies total duration and sample count.
+        """
         meeting_id = str(uuid4())
         dek = crypto.generate_dek()
         wrapped_dek = crypto.wrap_dek(dek)
@@ -449,11 +461,15 @@ class TestWriterReaderRoundTrip:
         reader = MeetingAudioReader(crypto, meetings_dir)
         chunks = reader.load_meeting_audio(meeting_id)

-        assert len(chunks) == chunk_count
+        # Verify total duration and sample count (not chunk count)
+        total_duration = sum(c.duration for c in chunks)
+        expected_duration = chunk_count * (1600 / 16000)
+        assert abs(total_duration - expected_duration) < 0.01
+
+        total_samples = sum(len(c.frames) for c in chunks)
+        expected_samples = chunk_count * 1600
+        assert total_samples == expected_samples


 class TestFileVersionHandling:
     """Test file version validation."""
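Assuming 16 kHz audio and 1600-sample writes (0.1 s each) — the ratio the hunk above uses — the expected totals work out as:

```python
chunk_count = 1000
samples_per_write = 1600
sample_rate = 16_000          # Hz, inferred from the 1600 / 16000 ratio in the test

expected_samples = chunk_count * samples_per_write
expected_duration = expected_samples / sample_rate

assert expected_samples == 1_600_000   # total samples across all chunks
assert expected_duration == 100.0      # seconds: 1000 writes of 0.1 s each
```

These totals hold no matter how the writer re-chunks the stream, which is exactly why the test asserts on them instead of the chunk count.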
@@ -318,9 +318,7 @@ class TestEdgeCaseConfigurations:

         list(segmenter.process_audio(silence, is_speech=False))
         list(segmenter.process_audio(speech, is_speech=True))
-        if segments := list(
-            segmenter.process_audio(more_silence, is_speech=False)
-        ):
+        if segments := list(segmenter.process_audio(more_silence, is_speech=False)):
             seg = segments[0]
             assert seg.duration > 0
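The collapsed conditional relies on the walrus operator to bind the result and branch on its truthiness in one line. A standalone illustration of the pattern, with a trivial stand-in for `segmenter.process_audio` (all names here are invented):

```python
def emit(values: list[int]) -> list[int]:
    """Stand-in for a segmenter step: returns completed items, possibly none."""
    return [v for v in values if v > 0]


# `result := ...` assigns and evaluates in the condition itself; an empty
# list is falsy, so the body only runs when something was produced.
if result := emit([0, 3, 5]):
    first = result[0]
else:
    first = None
```

This keeps `result` available inside the branch without a separate assignment statement, which is why the three-line form could be folded onto one line once it fit the line-length limit.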