Update configuration and enhance transcript component functionality

- Added `repomix-output.md` to `.gitignore` to exclude output files from version control.
- Updated `repomix.config.json` to include the `support/` directory in the include paths.
- Enhanced `TranscriptComponent` to improve search functionality with a debounce mechanism for input.
- Refactored search logic to toggle segment visibility based on the search query instead of rebuilding the list, improving responsiveness.
- Updated `Meeting` and `Annotation` entities to use `utc_now()` for consistent UTC-aware timestamps.
- Introduced new repository protocols for `DiarizationJob` and `Preferences` to support additional functionalities.
- Implemented in-memory persistence for repositories to facilitate testing and development without a database.
- Added error handling and utility functions to streamline gRPC service responses and improve code maintainability.
This commit is contained in:
2025-12-20 20:45:58 +00:00
parent a1fc7edeea
commit d66aa6b958
36 changed files with 2347 additions and 905 deletions

1
.gitignore vendored
View File

@@ -5,3 +5,4 @@ spikes/
__pycache__/
.env
logs/status_line.json
repomix-output.md

View File

@@ -1,196 +1,382 @@
# Triage Review (Validated)
According to a document from December 19, 2025 (your repomix-output.md snapshot), here's a targeted review focused on duplication/redundancy, bug & race risks, optimization, and DRY/robustness, with concrete excerpts and “what to do next”.
Validated: 2025-12-19
Legend: Status = Confirmed, Partially confirmed, Not observed, Already implemented
Citations use [path:line] format.
## 2. Architecture & State Management
Highest-impact bug & race-condition risks
### Issue 2.1: Server-Side State Volatility
1) Streaming diarization looks globally stateful (cross-meeting interference risk)
Status: Confirmed
Severity: High
Location: src/noteflow/grpc/service.py, src/noteflow/grpc/_mixins/diarization.py, src/noteflow/grpc/server.py
In your streaming state init, you reset diarization streaming state for every meeting:
Evidence:
- In-memory stream state lives on the servicer (_active_streams, _audio_writers, _partial_buffers, _diarization_jobs). [src/noteflow/grpc/service.py:95] [src/noteflow/grpc/service.py:106] [src/noteflow/grpc/service.py:109] [src/noteflow/grpc/service.py:122]
- Background diarization jobs are stored only in a dict and status reads from it. [src/noteflow/grpc/_mixins/diarization.py:241] [src/noteflow/grpc/_mixins/diarization.py:480]
- Server shutdown only stops gRPC; no servicer cleanup hook is invoked. [src/noteflow/grpc/server.py:132]
- Crash recovery marks meetings ERROR but does not validate audio assets. [src/noteflow/application/services/recovery_service.py:21] [src/noteflow/application/services/recovery_service.py:71]
if self._diarization_engine is not None:
self._diarization_engine.reset_streaming()
Example:
- If job state is lost (for example, after a restart), polling will fail and surface a fetch error in the UI. [src/noteflow/grpc/_mixins/diarization.py:479] [src/noteflow/grpc/client.py:885] [src/noteflow/client/components/meeting_library.py:534]
Reusable code locations:
- `_close_audio_writer` already centralizes writer cleanup. [src/noteflow/grpc/service.py:247]
- Migration patterns for new tables exist (annotations). [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]
And reset_streaming() resets the underlying streaming pipeline:
Actions:
- Persist diarization jobs (table or cache) and query from `GetDiarizationJobStatus`.
- Add a shutdown hook to close all `_audio_writers` and flush buffers.
- Optional: add asset integrity checks after RecoveryService marks a meeting ERROR.
def reset_streaming(self) -> None:
if self._streaming_pipeline is not None:
self._streaming_pipeline.reset()
### Issue 2.2: Implicit Meeting Asset Paths
Status: Confirmed
Severity: Medium
Location: src/noteflow/infrastructure/audio/reader.py, src/noteflow/infrastructure/audio/writer.py, src/noteflow/infrastructure/persistence/models.py
Why this is risky
• If two meeting streams happen at once (multiple clients, or a reconnect edge-case), meeting B can reset meeting A's diarization pipeline mid-stream.
• You also have a global diarization lock in streaming paths (serialization), which suggests the diarization streaming pipeline is not intended to be shared concurrently across meetings. (That's fine, but then enforce that constraint explicitly.)
Evidence:
- Audio assets are read/written under `meetings_dir / meeting_id`. [src/noteflow/infrastructure/audio/reader.py:72] [src/noteflow/infrastructure/audio/writer.py:94]
- MeetingModel defines meeting fields (id/title/state/metadata/wrapped_dek) but no asset path. [src/noteflow/infrastructure/persistence/models.py:38] [src/noteflow/infrastructure/persistence/models.py:64]
- Delete logic also assumes `meetings_dir / meeting_id`. [src/noteflow/application/services/meeting_service.py:195]
Recommendation
Pick one of these models and make it explicit:
• Model A: Only one active diarization stream allowed
• Enforce a single-stream invariant when diarization streaming is enabled.
• If a second stream begins, abort with FAILED_PRECONDITION (“diarization streaming is single-session”).
• Model B: Diarization pipeline is per meeting
• Store per-meeting diarization pipelines/state: self._diarization_streams[meeting_id] = DiarizationEngine(...) or engine.create_stream_session().
• Remove global .reset_streaming() and instead reset only the per-meeting session.
Example:
- Record a meeting with `meetings_dir` set to `~/.noteflow/meetings`, then change it to `/mnt/noteflow`. Playback will look in the new base and fail to find older audio.
Testing idea: add an async test that starts two streams and ensures diarization state for meeting A doesn't get reset when meeting B starts.
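A minimal sketch of Model B under the assumptions above: the engine grows a per-meeting session factory (the `create_streaming_session` name is illustrative, not a confirmed API), and the servicer keys sessions by meeting ID so resetting one stream never touches another.

```python
# Sketch only: per-meeting streaming diarization sessions. The engine API
# (`create_streaming_session`, `session.reset`) is assumed for illustration.
import asyncio


class PerMeetingDiarizationSketch:
    def __init__(self, engine) -> None:
        self._engine = engine
        self._sessions = {}  # meeting_id -> streaming session
        self._lock = asyncio.Lock()

    async def _get_session(self, meeting_id: str):
        # One isolated session per meeting: starting meeting B never resets
        # meeting A's pipeline state.
        async with self._lock:
            session = self._sessions.get(meeting_id)
            if session is None:
                session = self._engine.create_streaming_session(meeting_id)
                self._sessions[meeting_id] = session
            return session

    async def _end_session(self, meeting_id: str) -> None:
        # Reset only this meeting's session on stream teardown.
        async with self._lock:
            session = self._sessions.pop(meeting_id, None)
        if session is not None:
            session.reset()
```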
Reusable code locations:
- `MeetingAudioWriter.open` and `MeetingAudioReader.load_meeting_audio` are the path entry points. [src/noteflow/infrastructure/audio/writer.py:70] [src/noteflow/infrastructure/audio/reader.py:60]
- Migration templates live in `infrastructure/persistence/migrations/versions`. [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]
Actions:
- Add `asset_path` (or `storage_path`) column to meetings.
- Store the relative path at creation time and use it on read/delete.
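A sketch of what those actions could look like, assuming an SQLAlchemy 2.0 model; the `asset_path` column name and the fallback helper are illustrative, not the project's confirmed schema.

```python
# Sketch: store the relative asset path on the meeting row and resolve it
# explicitly at read/delete time. Names are illustrative.
from pathlib import Path

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class MeetingRowSketch(Base):
    __tablename__ = "meetings_sketch"

    id: Mapped[str] = mapped_column(primary_key=True)
    # Relative path under meetings_dir, recorded once at creation time.
    asset_path: Mapped[str | None] = mapped_column(default=None)


def resolve_audio_dir(meetings_dir: Path, meeting: MeetingRowSketch) -> Path:
    """Prefer the stored path; fall back to the legacy id-derived layout."""
    if meeting.asset_path:
        return meetings_dir / meeting.asset_path
    return meetings_dir / meeting.id
```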
2) StopMeeting closes the audio writer even if the stream might still be writing
## 3. Concurrency & Performance
In StopMeeting, you close the writer immediately:
### Issue 3.1: Synchronous Blocking in Async gRPC (Streaming Diarization)
meeting_id = request.meeting_id
if meeting_id in self._audio_writers:
self._close_audio_writer(meeting_id)
Status: Confirmed
Severity: Medium
Location: src/noteflow/grpc/_mixins/diarization.py, src/noteflow/grpc/_mixins/streaming.py
Evidence:
- `_process_streaming_diarization` calls `process_chunk` synchronously. [src/noteflow/grpc/_mixins/diarization.py:63] [src/noteflow/grpc/_mixins/diarization.py:92]
- It is invoked in the async streaming loop on every chunk. [src/noteflow/grpc/_mixins/streaming.py:379]
- Diarization engine uses pyannote/diart pipelines. [src/noteflow/infrastructure/diarization/engine.py:1]
Why this is risky
• Your streaming loop writes audio chunks while streaming is active. If StopMeeting can be called while StreamTranscription is still mid-loop, you can get:
• “file not open” style exceptions,
• partial/corrupted encrypted audio artifacts,
• or silent loss depending on how _write_audio_chunk_safe is implemented.
Example:
- With diarization enabled on CPU, heavy `process_chunk` calls can stall the event loop, delaying transcript updates and heartbeats.
Recommendation
Make writer closure occur in exactly one place, ideally “stream teardown,” and make StopMeeting signal rather than force:
• Track active streams and gate closure:
• If meeting_id is currently streaming, set a stop_requested[meeting_id] = True and let the stream finally: close the writer.
• Or: in StopMeeting, call the same cleanup routine that a stream uses, but only after ensuring the stream is stopped/cancelled.
Reusable code locations:
- ASR already offloads blocking work via `run_in_executor`. [src/noteflow/infrastructure/asr/engine.py:156]
- Offline diarization uses `asyncio.to_thread`. [src/noteflow/grpc/_mixins/diarization.py:305]
Testing idea: a concurrency test that runs StreamTranscription and calls StopMeeting mid-stream, asserting no exceptions + writer is closed exactly once.
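A sketch of the "signal rather than force" shape, assuming illustrative names (`_stop_requested`, `_close_audio_writer`, `_write_audio_chunk`); the real servicer attributes may differ.

```python
# Sketch: StopMeeting requests shutdown; the streaming loop owns writer
# closure in its finally block, so closure happens in exactly one place.
import asyncio


class StreamTeardownSketch:
    def __init__(self) -> None:
        self._active_streams: set[str] = set()
        self._stop_requested: dict[str, asyncio.Event] = {}

    async def stop_meeting(self, meeting_id: str) -> None:
        if meeting_id in self._active_streams:
            # Stream is live: ask it to stop; it will close the writer itself.
            self._stop_requested.setdefault(meeting_id, asyncio.Event()).set()
        else:
            # No stream running: safe to close the writer directly.
            self._close_audio_writer(meeting_id)

    async def stream_transcription(self, meeting_id: str, chunks) -> None:
        self._active_streams.add(meeting_id)
        stop = self._stop_requested.setdefault(meeting_id, asyncio.Event())
        try:
            async for chunk in chunks:
                if stop.is_set():
                    break
                self._write_audio_chunk(meeting_id, chunk)
        finally:
            # Exactly one closure point, regardless of how the stream ends.
            self._active_streams.discard(meeting_id)
            self._stop_requested.pop(meeting_id, None)
            self._close_audio_writer(meeting_id)

    def _write_audio_chunk(self, meeting_id: str, chunk: bytes) -> None: ...

    def _close_audio_writer(self, meeting_id: str) -> None: ...
```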
Actions:
- Offload streaming diarization to a thread/process pool similar to ASR.
- Consider a bounded queue so diarization lag does not backpressure streaming.
### Issue 3.2: VU Meter UI Updates on Every Audio Chunk
3) Naive datetimes + timezone-aware DB columns = subtle bugs and wrong timestamps
Status: Confirmed
Severity: Medium
Location: src/noteflow/client/app.py, src/noteflow/client/components/vu_meter.py
Your domain Meeting uses naive datetimes:
Evidence:
- Audio capture uses 100ms chunks. [src/noteflow/client/app.py:307]
- Each chunk triggers `VuMeterComponent.on_audio_frames`, which schedules a UI update. [src/noteflow/client/app.py:552] [src/noteflow/client/components/vu_meter.py:58]
created_at: datetime = field(default_factory=datetime.now)
Example:
- At 100ms chunks, the UI updates about 10 times per second. If chunk duration is lowered (e.g., 20ms), that becomes about 50 updates per second and can cause stutter.
Reusable code locations:
- Recording timer throttles updates with a fixed interval and background worker. [src/noteflow/client/components/recording_timer.py:14]
But your DB models use DateTime(timezone=True) while still defaulting to naive datetime.now:
Actions:
- Throttle VU updates (for example, 20 fps) or update only when delta exceeds a threshold.
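A minimal throttle sketch (about 20 fps), assuming the component exposes a UI-update callback; attribute names are illustrative, not the real `VuMeterComponent` API.

```python
# Sketch: skip UI work unless at least min_interval_s has elapsed since the
# last update, decoupling update rate from audio chunk duration.
import time


class ThrottledVuMeterSketch:
    def __init__(self, min_interval_s: float = 0.05) -> None:
        self._min_interval_s = min_interval_s
        self._last_update = 0.0

    def on_audio_frames(self, level: float) -> None:
        now = time.monotonic()
        if now - self._last_update < self._min_interval_s:
            return  # drop this update; the next chunk will refresh soon enough
        self._last_update = now
        self._update_ui(level)

    def _update_ui(self, level: float) -> None:
        ...  # schedule the Flet control update here
```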
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=datetime.now,
)
## 4. Domain Logic & Reliability
### Issue 4.1: Summarization Consent Persistence
And you convert to proto timestamps using .timestamp():
Status: Confirmed
Severity: Low (UX)
Location: src/noteflow/application/services/summarization_service.py
created_at=meeting.created_at.timestamp(),
Evidence:
- Consent is stored on `SummarizationServiceSettings` and defaults to False. [src/noteflow/application/services/summarization_service.py:56]
- `grant_cloud_consent` only mutates the in-memory settings. [src/noteflow/application/services/summarization_service.py:150]
Example:
- Users who grant cloud consent must re-consent after every server/client restart.
Why this is risky
• Naive/aware comparisons can explode at runtime (depending on code paths).
• .timestamp() on naive datetimes depends on local timezone assumptions.
• You already use UTC-aware time elsewhere (example: retention test uses datetime.now(UTC)), so the intent seems to be “UTC everywhere.”
Reusable code locations:
- Existing persistence patterns for per-meeting data live in `infrastructure/persistence`. [src/noteflow/infrastructure/persistence/models.py:32]
Recommendation
• Standardize: store and operate on UTC-aware datetimes everywhere.
• Domain: default_factory=lambda: datetime.now(UTC)
• DB: default=datetime.now(UTC) or server_default=func.now() (and be consistent about timezone)
• Add a test that ensures all exported/serialized timestamps are monotonic and consistent (no “timezone drift”).
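A sketch of the "UTC everywhere" convention; `utc_now` matches the helper added in this commit, while the ORM column shown here is illustrative rather than the full MeetingModel.

```python
# Sketch: timezone-aware defaults in both the domain dataclass and the ORM
# column, so .timestamp() conversions no longer depend on the local zone.
from dataclasses import dataclass, field
from datetime import UTC, datetime

from sqlalchemy import DateTime
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


def utc_now() -> datetime:
    return datetime.now(UTC)


@dataclass
class MeetingSketch:
    created_at: datetime = field(default_factory=utc_now)


class Base(DeclarativeBase):
    pass


class MeetingTimestampRow(Base):
    __tablename__ = "meetings_timestamp_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=utc_now,  # or server_default=sqlalchemy.func.now() for a DB-side default
    )
```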
Actions:
- Persist consent in a preferences table or config file and hydrate on startup.
### Issue 4.2: Annotation Validation / Point-in-Time Annotations
4) Keyring keystore bypasses your own validation logic
Status: Not observed (current code supports point annotations)
Severity: Low
Location: src/noteflow/domain/entities/annotation.py, src/noteflow/client/components/annotation_toolbar.py, src/noteflow/client/components/annotation_display.py
You have a proper validator:
Evidence:
- Validation allows `end_time == start_time`. [src/noteflow/domain/entities/annotation.py:37]
- UI creates point annotations by setting `start_time == end_time`. [src/noteflow/client/components/annotation_toolbar.py:189]
- Display uses `start_time` only, not duration. [src/noteflow/client/components/annotation_display.py:164]
def _decode_and_validate_key(self, encoded: str) -> bytes:
...
if len(decoded) != MASTER_KEY_SIZE:
raise WrongKeySizeError(...)
Example:
- Clicking a point annotation seeks to the exact timestamp (no duration needed).
Reusable code locations:
- If range annotations are introduced later, `AnnotationDisplayComponent` is where a start-end range would be rendered. [src/noteflow/client/components/annotation_display.py:148]
…but KeyringKeyStore skips it:
Action:
- None required now; revisit if range annotations are added to the UI or exports.
if stored:
return base64.b64decode(stored)
## 5. Suggested Git Issues (Validated)
Issue A: Persist Diarization Jobs to Database
Status: Confirmed
Evidence: Jobs live in-memory and `GetDiarizationJobStatus` reads from the dict. [src/noteflow/grpc/_mixins/diarization.py:241] [src/noteflow/grpc/_mixins/diarization.py:480]
Reusable code locations:
- SQLAlchemy models/migrations patterns. [src/noteflow/infrastructure/persistence/models.py:32] [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]
Tasks:
- Add JobModel with status fields.
- Update mixin to persist and query DB.
Why this is risky
• Invalid base64 or wrong key sizes could silently produce bad keys, later causing encryption/decryption failures that look like “random corruption.”
Issue B: Implement Audio Writer Buffering
Status: Already implemented (close or re-scope)
Evidence: `MeetingAudioWriter` buffers and flushes based on `buffer_size`. [src/noteflow/infrastructure/audio/writer.py:126]
Reusable code locations:
- `AUDIO_BUFFER_SIZE_BYTES` constant. [src/noteflow/config/constants.py:26]
Tasks:
- None, unless you want to tune buffer size.
Recommendation
• Use _decode_and_validate_key() in all keystore implementations (file, env, keyring).
• Add a test that injects an invalid stored keyring value and asserts a clear, typed failure.
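A sketch of routing keyring-stored values through the shared validator; `MASTER_KEY_SIZE`, `WrongKeySizeError`, and `_decode_and_validate_key` follow the excerpt above, and the lookup uses the standard `keyring` package API.

```python
# Sketch: KeyringKeyStore validates stored values instead of returning
# base64.b64decode(stored) directly, so bad keys fail loudly and typed.
import base64
import binascii

import keyring

MASTER_KEY_SIZE = 32


class WrongKeySizeError(ValueError):
    pass


class KeyringKeyStoreSketch:
    def __init__(self, service: str = "noteflow", username: str = "master-key") -> None:
        self._service = service
        self._username = username

    def _decode_and_validate_key(self, encoded: str) -> bytes:
        try:
            decoded = base64.b64decode(encoded, validate=True)
        except (binascii.Error, ValueError) as exc:
            raise WrongKeySizeError("Stored key is not valid base64") from exc
        if len(decoded) != MASTER_KEY_SIZE:
            raise WrongKeySizeError(
                f"Expected {MASTER_KEY_SIZE} bytes, got {len(decoded)}"
            )
        return decoded

    def load(self) -> bytes | None:
        stored = keyring.get_password(self._service, self._username)
        if stored is None:
            return None
        return self._decode_and_validate_key(stored)
```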
Issue C: Fallback for Headless Keyring
Status: Confirmed
Evidence: `KeyringKeyStore` only falls back to env var, not file storage. [src/noteflow/infrastructure/security/keystore.py:49]
Reusable code locations:
- `KeyringKeyStore` and `InMemoryKeyStore` live in the same module. [src/noteflow/infrastructure/security/keystore.py:35]
Tasks:
- Add `FileKeyStore` and wire fallback in server/service initialization.
Issue D: Throttled VU Meter in Client
Status: Confirmed
Evidence: Each chunk schedules a UI update with no throttle. [src/noteflow/client/components/vu_meter.py:58]
Reusable code locations:
- Background worker/throttle pattern in `RecordingTimerComponent`. [src/noteflow/client/components/recording_timer.py:14]
Tasks:
- Add a `last_update_time` and update at fixed cadence.
Biggest duplication & DRY pain points (and how to eliminate them)
Issue E: Explicit Asset Path Storage
Status: Confirmed
Evidence: Meeting paths derived from `meetings_dir / meeting_id`. [src/noteflow/infrastructure/audio/reader.py:72]
Reusable code locations:
- Meeting model + migrations. [src/noteflow/infrastructure/persistence/models.py:32]
Tasks:
- Add `asset_path` column and persist at create time.
1) gRPC mixins repeat “DB path vs memory path” everywhere
Issue F: PGVector Index Creation
Status: Confirmed (requires product decision)
Evidence: Migration uses `ivfflat` index created immediately. [src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py:95]
Reusable code locations:
- Same migration file for index changes. [src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py:95]
Tasks:
- Consider switching to HNSW or defer index creation until data exists.
Example in ListMeetings:
## 6. Code Quality & Nitpicks (Validated)
if self._use_database():
async with self._create_uow() as uow:
meetings = await uow.meetings.list_all()
else:
store = self._get_memory_store()
meetings = list(store._meetings.values())
return ListMeetingsResponse(meetings=[meeting_to_proto(m) for m in meetings])
- HTML export template is minimal inline CSS (no external framework). [src/noteflow/infrastructure/export/html.py:33]
Example: optionally add a lightweight stylesheet (with offline fallback) for nicer exports.
- Transcript partial row is updated in place (no remove/re-add), so flicker risk is already mitigated. [src/noteflow/client/components/transcript.py:182]
- `Segment.word_count` uses `text.split()` when no words are present. [src/noteflow/domain/entities/segment.py:72]
Example: for very large transcripts, a streaming count (for example, regex iterator) avoids allocating a full list.
Same pattern appears in summarization:
if self._use_database():
summary = await self._generate_summary_db(...)
else:
summary = self._generate_summary_memory(...)
And streaming segment persistence is duplicated too (and commits per result; more on perf below).
Why this matters
• This is a structural duplication multiplier.
• Code assistants will mirror the existing pattern and duplicate more.
• You'll fix a bug in one branch and forget the other.
Recommendation (best ROI refactor)
Create a single abstraction layer and make the gRPC layer depend on it:
• Define Ports/Protocols (async-friendly):
• MeetingRepository, SegmentRepository, SummaryRepository, etc.
• Provide two adapters:
• SqlAlchemyMeetingRepository
• InMemoryMeetingRepository
• In the servicer, inject one concrete implementation behind a unified interface:
• self.meetings_repo, self.segments_repo, self.summaries_repo
Then ListMeetings becomes:
meetings = await self.meetings_repo.list_all()
return ListMeetingsResponse(meetings=[meeting_to_proto(m) for m in meetings])
No branching in every RPC = dramatically less duplication, fewer assistant copy/pastes, and fewer inconsistencies.
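A compact sketch of that port/adapter split, with illustrative names; the commit's own protocols live in the repositories module.

```python
# Sketch: one Protocol, two adapters, one injected dependency. RPC handlers
# call the port and never branch on the storage backend.
from typing import Protocol


class MeetingRepositoryPort(Protocol):
    async def list_all(self) -> list[object]: ...


class InMemoryMeetingRepository:
    def __init__(self) -> None:
        self._meetings: dict[str, object] = {}

    async def list_all(self) -> list[object]:
        return list(self._meetings.values())


class SqlAlchemyMeetingRepository:
    def __init__(self, session) -> None:
        self._session = session

    async def list_all(self) -> list[object]:
        ...  # SELECT via self._session, mapped to domain objects


class ServicerSketch:
    def __init__(self, meetings_repo: MeetingRepositoryPort) -> None:
        # One injected implementation chosen at startup.
        self.meetings_repo = meetings_repo

    async def list_meetings(self) -> list[object]:
        return await self.meetings_repo.list_all()
```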
Performance / optimization hotspots
1) DB add_batch is not actually batching
Your segment repos add_batch loops and calls add for each segment:
async def add_batch(...):
for seg in segments:
await self.add(meeting_id, seg)
Why this hurts
• If add() flushes/commits or even just issues INSERTs repeatedly, you get N roundtrips / flushes.
Recommendation
• Build ORM models for all segments and call session.add_all([...]), then flush once.
• If you need IDs, flush once and read them back.
• Keep commit decision at the UoW/service layer (not inside repo).
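A sketch of a genuinely batched `add_batch`, assuming an async SQLAlchemy session and an illustrative `SegmentModel`; the commit decision stays with the UoW/service layer.

```python
# Sketch: build all ORM rows, add_all once, flush once. No per-segment add().
from collections.abc import Sequence

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class SegmentModel(Base):
    __tablename__ = "segments_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    meeting_id: Mapped[str]
    text: Mapped[str]
    start_time: Mapped[float]
    end_time: Mapped[float]


class SqlAlchemySegmentRepositorySketch:
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def add_batch(self, meeting_id: str, segments: Sequence[object]) -> None:
        models = [
            SegmentModel(
                meeting_id=meeting_id,
                text=seg.text,
                start_time=seg.start_time,
                end_time=seg.end_time,
            )
            for seg in segments
        ]
        self._session.add_all(models)  # one unit of work, not N add() calls
        await self._session.flush()    # single flush if generated IDs are needed
```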
2) Streaming persistence commits inside a tight loop
Inside _process_audio_segment (DB branch):
for result in results:
...
await uow.segments.add(meeting.id, segment)
await uow.commit()
yield segment_to_proto_update(meeting_id, segment)
Why this hurts
• If ASR returns multiple results for one audio segment, you commit multiple times.
• Commits are expensive and can become the bottleneck.
Recommendation
• Commit once per processed audio segment (or once per gRPC request chunk batch), not once per ASR result.
• You can still yield after staging; if you must ensure durability before yielding, commit once after staging all results for that audio segment.
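A sketch of that batched-commit shape; `uow`, `uow.segments.add`, and `segment_to_proto_update` mirror the excerpt, while `to_domain_segment` is a hypothetical conversion helper.

```python
# Sketch: stage every ASR result for one audio segment, commit once, then emit.
async def persist_audio_segment(uow, meeting, results, segment_to_proto_update):
    segments = []
    async with uow:
        for result in results:
            segment = result.to_domain_segment()  # hypothetical conversion
            await uow.segments.add(meeting.id, segment)
            segments.append(segment)
        await uow.commit()  # one commit per audio segment, not per ASR result
    # Yield updates only after the batch is durable.
    for segment in segments:
        yield segment_to_proto_update(str(meeting.id), segment)
```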
3) Transcript search re-renders the entire list on every change (likely OK now, painful later)
**RESOLVED**: Added 200ms debounce timer and visibility toggling instead of full rebuild.
- Search input now uses `threading.Timer` with 200ms debounce
- All segment rows are created once and visibility is toggled via `container.visible`
- No more clearing and rebuilding on each keystroke
~~On search changes, you do a full rebuild:~~
~~self._segment_rows.clear()~~
~~self._list_view.controls.clear()~~
~~for idx, seg in enumerate(self._state.transcript_segments):~~
~~...~~
~~self._list_view.controls.append(row)~~
~~And it is triggered directly by search input changes:~~
~~self._search_query = (value or "").strip().lower()~~
~~...~~
~~self._rerender_all_segments()~~
~~Recommendation~~
~~• Add a debounce (150-250 ms) to search updates.~~
~~• Consider incremental filtering (track which rows match and only hide/show).~~
~~• If transcripts can become large: virtualized list rendering.~~
4) Client streaming stop/join can leak threads
stop_streaming joins with a timeout and clears the thread reference:
self._stream_thread.join(timeout=2.0)
self._stream_thread = None
If the thread doesn't exit in time, you can lose the handle while it's still running.
Recommendation
• After join timeout, check is_alive() and log + keep the reference (or escalate).
• Better: make the worker reliably cancellable and join without timeout in normal shutdown paths.
How to write tests that catch more of these bugs
You already have a good base: there are stress tests around streaming cleanup / leak prevention (nice).
What's missing tends to be concurrency + contract + regression tests that mirror real failure modes.
1) Add “race tests” for the exact dangerous interleavings
Concrete targets from this review:
• StopMeeting vs StreamTranscription write
• Start a stream, send some chunks, call StopMeeting mid-flight.
• Assert: no exception, audio writer closed once, meeting state consistent.
• Two streams with diarization enabled
• Start stream A, then start stream B.
• Assert: A's diarization state isn't reset (or assert the server rejects the second stream if you enforce single-stream).
Implementation approach:
• Use grpc.aio and asyncio.gather() to force overlap.
• Keep tests deterministic by using fixed chunk data and controlled scheduling (await asyncio.sleep(0) strategically).
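A sketch of such a race test, assuming pytest-asyncio and suite-provided fixtures (`servicer`, `make_chunk_iterator`, `make_stop_request`, `context`); the close-count attribute is likewise assumed.

```python
# Race-test sketch: overlap StreamTranscription with StopMeeting and assert
# clean teardown. Fixture names are illustrative.
import asyncio

import pytest


@pytest.mark.asyncio
async def test_stop_meeting_during_stream(
    servicer, make_chunk_iterator, make_stop_request, context
):
    async def run_stream() -> None:
        async for _ in servicer.StreamTranscription(make_chunk_iterator(), context):
            await asyncio.sleep(0)  # yield control so StopMeeting can interleave

    async def run_stop() -> None:
        await asyncio.sleep(0.01)  # let the stream start first
        await servicer.StopMeeting(make_stop_request(), context)

    # Force the overlap; neither task should raise.
    await asyncio.gather(run_stream(), run_stop())
    assert servicer.audio_writer_close_count == 1  # writer closed exactly once
```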
2) Contract tests for DB vs memory parity
As long as you support both backends, enforce “same behavior”:
• For each operation (create meeting, add segments, summary generation, deletion):
• Run the same test suite twice: once with in-memory store, once with DB store.
• Assert responses and side effects match.
This directly prevents the “fixed in DB path, broken in memory path” class of bugs.
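One way to enforce parity is a parametrized fixture that yields either backend, so the same test body runs twice; fixture and factory names here are illustrative.

```python
# Parity-test sketch: the same behavioural test runs against both the
# in-memory and the database-backed unit of work.
import pytest


@pytest.fixture(params=["memory", "database"])
def uow(request, make_memory_uow, make_db_uow):
    if request.param == "memory":
        return make_memory_uow()
    return make_db_uow()


@pytest.mark.asyncio
async def test_create_and_list_meeting(uow, make_meeting):
    meeting = make_meeting(title="Parity check")
    async with uow:
        await uow.meetings.add(meeting)
        await uow.commit()
        listed = await uow.meetings.list_all()
    assert [m.title for m in listed] == ["Parity check"]
```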
3) Property-based tests for invariants (catches “weird” edge cases)
Use Hypothesis (or similar) for:
• Random sequences of MeetingState transitions: ensure only allowed transitions occur.
• Random segments + word timings:
• word start/end are within segment range,
• segment durations are non-negative,
• transcript ordering invariants.
These tests are excellent at flushing out hidden assumptions.
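A Hypothesis sketch of the word-timing invariant; the generated tuples are simplified stand-ins for the real Segment/WordTiming entities.

```python
# Property-test sketch: every word timing falls inside its segment and
# durations are non-negative, across randomly generated segments.
from hypothesis import given, strategies as st


@st.composite
def segments(draw):
    start = draw(st.floats(min_value=0, max_value=1_000, allow_nan=False))
    length = draw(st.floats(min_value=0, max_value=60, allow_nan=False))
    end = start + length
    words = draw(
        st.lists(
            st.tuples(
                st.floats(min_value=start, max_value=end, allow_nan=False),
                st.floats(min_value=0, max_value=5, allow_nan=False),
            ),
            max_size=20,
        )
    )
    # Clamp word ends to the segment boundary.
    return start, end, [(w_start, min(w_start + dur, end)) for w_start, dur in words]


@given(segments())
def test_word_timings_within_segment(data):
    start, end, words = data
    assert end >= start
    for w_start, w_end in words:
        assert start <= w_start <= w_end <= end
```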
4) Mutation testing (biggest upgrade if “bugs slip through”)
If you feel like “tests pass but stuff breaks,” mutation testing will tell you where your suite is weak.
• Run mutation testing on core modules:
• gRPC mixins (streaming/summarization)
• repositories
• crypto keystore/unwrap logic
• You'll quickly find untested branches and “assert-less” code.
5) Make regressions mandatory: “bug → test → fix”
A simple process change that works:
• When you fix a bug, first write a failing test that reproduces it.
• Only then fix the code.
This steadily raises your “bugs can't sneak back in” floor.
6) CI guardrails
• Enforce coverage thresholds on:
• domain + application + grpc layers (UI excluded if unstable)
• Run stress/race tests nightly if they're expensive.
How to get code assistants to stop duplicating code
1) Remove the architectural duplication magnets
The #1 driver is your repeated DB-vs-memory branching in every RPC/method.
If you refactor to a single repository/service interface, the assistant literally has fewer “duplicate-shaped” surfaces to copy.
2) Add an “assistant contract” file to your repo
Create a short, explicit guide your assistants must follow, e.g. ASSISTANT_GUIDELINES.md:
Include rules like:
• “Before writing code, search for existing helpers/converters/repositories and reuse them.”
• “No new _db / _memory functions. Add to the repository interface instead.”
• “If adding logic similar to existing code, refactor to shared helper instead of copy/paste.”
• “Prefer editing an existing module over creating a new one.”
3) Prompt pattern that reduces duplication dramatically
When you ask an assistant to implement something, use a structure like:
• Step 1: “List existing functions/files that already do part of this.”
• Step 2: “Propose the minimal diff touching the fewest files.”
• Step 3: “If code would be duplicated, refactor first (create helper + update callers).”
• Step 4: “Show me the final patch.”
It forces discovery + reuse before generation.
4) Add duplication detection to CI
Make duplication a failing signal, not a suggestion:
• Add a copy/paste detector (language-agnostic ones exist) and fail above a threshold.
• Pair it with lint rules that push toward DRY patterns.
5) Code review rubric (even if you're solo)
A quick checklist:
• “Is this new code already present elsewhere?”
• “Did we add a second way to do the same thing?”
• “Is there now both a DB and memory version of the same logic?”
If you only do 3 things next
1. Refactor gRPC mixins to a unified repository/service interface (kills most duplication + reduces assistant copy/paste).
2. Fix diarization streaming state scoping (global reset is the scariest race).
3. Add race tests for StopMeeting vs streaming + enforce UTC-aware datetimes.
If you want, I can also propose a concrete refactor sketch (new interfaces + how to wire the servicer) that eliminates the DB/memory branching with minimal disruption—using the patterns you already have (UoW + repositories).

View File

@@ -26,7 +26,7 @@
"includeLogsCount": 50
}
},
"include": ["src/", "tests/"],
"include": ["src/", "tests/", "support/"],
"ignore": {
"useGitignore": true,
"useDefaultPatterns": true,

View File

@@ -509,7 +509,13 @@ class NoteFlowClientApp(TriggerMixin):
self._audio_consumer_stop.set()
if self._audio_consumer_thread is not None:
self._audio_consumer_thread.join(timeout=1.0)
self._audio_consumer_thread = None
# Only clear reference if thread exited cleanly
if self._audio_consumer_thread.is_alive():
logger.warning(
"Audio consumer thread did not exit within timeout, keeping reference"
)
else:
self._audio_consumer_thread = None
# Drain remaining frames
while not self._audio_frame_queue.empty():
try:

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import hashlib
from collections.abc import Callable
from threading import Timer
from typing import TYPE_CHECKING
import flet as ft
@@ -15,6 +16,9 @@ import flet as ft
# REUSE existing formatting - do not recreate
from noteflow.infrastructure.export._formatting import format_timestamp
# Debounce delay for search input (milliseconds)
_SEARCH_DEBOUNCE_MS = 200
if TYPE_CHECKING:
from noteflow.client.state import AppState
@@ -42,10 +46,11 @@ class TranscriptComponent:
self._state = state
self._on_segment_click = on_segment_click
self._list_view: ft.ListView | None = None
self._segment_rows: list[ft.Container | None] = [] # Track rows for highlighting
self._segment_rows: list[ft.Container] = [] # All rows, use visible property to filter
self._search_field: ft.TextField | None = None
self._search_query: str = ""
self._partial_row: ft.Container | None = None # Live partial at bottom
self._search_timer: Timer | None = None # Debounce timer for search
def build(self) -> ft.Column:
"""Build transcript list view with search.
@@ -110,6 +115,11 @@ class TranscriptComponent:
def clear(self) -> None:
"""Clear all transcript segments and partials."""
# Cancel pending search timer
if self._search_timer is not None:
self._search_timer.cancel()
self._search_timer = None
self._state.clear_transcript()
self._segment_rows.clear()
self._partial_row = None
@@ -121,13 +131,42 @@ class TranscriptComponent:
self._state.request_update()
def _on_search_change(self, e: ft.ControlEvent) -> None:
"""Handle search field change.
"""Handle search field change with debounce.
Args:
e: Control event with new search value.
"""
self._search_query = (e.control.value or "").lower()
self._rerender_all_segments()
# Cancel pending timer
if self._search_timer is not None:
self._search_timer.cancel()
# Start new debounce timer
self._search_timer = Timer(
_SEARCH_DEBOUNCE_MS / 1000.0,
self._apply_search_filter,
)
self._search_timer.start()
def _apply_search_filter(self) -> None:
"""Apply search filter to existing rows via visibility toggle."""
self._state.run_on_ui_thread(self._toggle_row_visibility)
def _toggle_row_visibility(self) -> None:
"""Toggle visibility of rows based on search query (UI thread only)."""
if not self._list_view:
return
query = self._search_query
for idx, container in enumerate(self._segment_rows):
if idx >= len(self._state.transcript_segments):
continue
segment = self._state.transcript_segments[idx]
matches = not query or query in segment.text.lower()
container.visible = matches
self._state.request_update()
def _rerender_all_segments(self) -> None:
"""Re-render all segments with current search filter."""
@@ -137,15 +176,11 @@ class TranscriptComponent:
self._list_view.controls.clear()
self._segment_rows.clear()
query = self._search_query
for idx, segment in enumerate(self._state.transcript_segments):
# Filter by search query
if self._search_query and self._search_query not in segment.text.lower():
# Add placeholder to maintain index alignment
self._segment_rows.append(None)
continue
# Use original index for click handling
container = self._create_segment_row(segment, idx)
# Set visibility based on search query
container.visible = not query or query in segment.text.lower()
self._segment_rows.append(container)
self._list_view.controls.append(container)
@@ -167,14 +202,12 @@ class TranscriptComponent:
# Use the actual index from state (segments are appended before rendering)
segment_index = len(self._state.transcript_segments) - 1
# Filter by search query during live rendering
if self._search_query and self._search_query not in segment.text.lower():
self._segment_rows.append(None)
return
container = self._create_segment_row(segment, segment_index)
# Set visibility based on search query during live rendering
query = self._search_query
container.visible = not query or query in segment.text.lower()
self._segment_rows.append(container)
self._list_view.controls.append(container)
self._state.request_update()
@@ -347,8 +380,6 @@ class TranscriptComponent:
highlighted_index: Index of segment to highlight, or None to clear.
"""
for idx, container in enumerate(self._segment_rows):
if container is None:
continue
if idx == highlighted_index:
container.bgcolor = ft.Colors.YELLOW_100
container.border = ft.border.all(1, ft.Colors.YELLOW_700)
@@ -371,10 +402,6 @@ class TranscriptComponent:
if not self._list_view or segment_index >= len(self._segment_rows):
return
container = self._segment_rows[segment_index]
if container is None:
return
# Estimate row height for scroll calculation
estimated_row_height = 50
offset = segment_index * estimated_row_height

View File

@@ -9,6 +9,8 @@ from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING
from noteflow.domain.utils.time import utc_now
if TYPE_CHECKING:
from noteflow.domain.value_objects import AnnotationId, AnnotationType, MeetingId
@@ -29,7 +31,7 @@ class Annotation:
start_time: float
end_time: float
segment_ids: list[int] = field(default_factory=list)
created_at: datetime = field(default_factory=datetime.now)
created_at: datetime = field(default_factory=utc_now)
# Database primary key (set after persistence)
db_id: int | None = None

View File

@@ -7,6 +7,7 @@ from datetime import datetime
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from noteflow.domain.utils.time import utc_now
from noteflow.domain.value_objects import MeetingId, MeetingState
if TYPE_CHECKING:
@@ -25,7 +26,7 @@ class Meeting:
id: MeetingId
title: str
state: MeetingState = MeetingState.CREATED
created_at: datetime = field(default_factory=datetime.now)
created_at: datetime = field(default_factory=utc_now)
started_at: datetime | None = None
ended_at: datetime | None = None
segments: list[Segment] = field(default_factory=list)
@@ -50,7 +51,7 @@ class Meeting:
New Meeting instance.
"""
meeting_id = MeetingId(uuid4())
now = datetime.now()
now = utc_now()
if not title:
title = f"Meeting {now.strftime('%Y-%m-%d %H:%M')}"
@@ -98,7 +99,7 @@ class Meeting:
id=meeting_id,
title=title,
state=state,
created_at=created_at or datetime.now(),
created_at=created_at or utc_now(),
started_at=started_at,
ended_at=ended_at,
metadata=metadata or {},
@@ -115,7 +116,7 @@ class Meeting:
if not self.state.can_transition_to(MeetingState.RECORDING):
raise ValueError(f"Cannot start recording from state {self.state.name}")
self.state = MeetingState.RECORDING
self.started_at = datetime.now()
self.started_at = utc_now()
def begin_stopping(self) -> None:
"""Transition to stopping state for graceful shutdown.
@@ -140,7 +141,7 @@ class Meeting:
raise ValueError(f"Cannot stop recording from state {self.state.name}")
self.state = MeetingState.STOPPED
if self.ended_at is None:
self.ended_at = datetime.now()
self.ended_at = utc_now()
def complete(self) -> None:
"""Transition to completed state.
@@ -178,7 +179,7 @@ class Meeting:
if self.ended_at and self.started_at:
return (self.ended_at - self.started_at).total_seconds()
if self.started_at:
return (datetime.now() - self.started_at).total_seconds()
return (utc_now() - self.started_at).total_seconds()
return 0.0
@property

View File

@@ -9,6 +9,10 @@ from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from noteflow.domain.entities import Annotation, Meeting, Segment, Summary
from noteflow.domain.value_objects import AnnotationId, MeetingId, MeetingState
from noteflow.infrastructure.persistence.repositories import (
DiarizationJob,
StreamingTurn,
)
class MeetingRepository(Protocol):
@@ -310,3 +314,154 @@ class AnnotationRepository(Protocol):
True if deleted, False if not found.
"""
...
class DiarizationJobRepository(Protocol):
"""Repository protocol for DiarizationJob operations.
Tracks background speaker diarization jobs for crash resilience
and client visibility.
"""
async def create(self, job: DiarizationJob) -> DiarizationJob:
"""Persist a new diarization job.
Args:
job: DiarizationJob data object.
Returns:
Created job.
"""
...
async def get(self, job_id: str) -> DiarizationJob | None:
"""Retrieve a job by ID.
Args:
job_id: Job identifier.
Returns:
Job if found, None otherwise.
"""
...
async def update_status(
self,
job_id: str,
status: int,
*,
segments_updated: int | None = None,
speaker_ids: list[str] | None = None,
error_message: str | None = None,
) -> bool:
"""Update job status and optional fields.
Args:
job_id: Job identifier.
status: New status value.
segments_updated: Optional segments count.
speaker_ids: Optional speaker IDs list.
error_message: Optional error message.
Returns:
True if job was updated, False if not found.
"""
...
async def prune_completed(self, ttl_seconds: float) -> int:
"""Delete completed/failed jobs older than TTL.
Args:
ttl_seconds: Time-to-live in seconds.
Returns:
Number of jobs deleted.
"""
...
async def add_streaming_turns(
self,
meeting_id: str,
turns: Sequence[StreamingTurn],
) -> int:
"""Persist streaming diarization turns for crash resilience.
Args:
meeting_id: Meeting identifier.
turns: Speaker turns to persist.
Returns:
Number of turns added.
"""
...
async def get_streaming_turns(self, meeting_id: str) -> list[StreamingTurn]:
"""Retrieve streaming diarization turns for a meeting.
Args:
meeting_id: Meeting identifier.
Returns:
List of streaming turns.
"""
...
async def clear_streaming_turns(self, meeting_id: str) -> int:
"""Delete streaming diarization turns for a meeting.
Args:
meeting_id: Meeting identifier.
Returns:
Number of turns deleted.
"""
...
async def mark_running_as_failed(self, error_message: str = "Server restarted") -> int:
"""Mark queued/running jobs as failed.
Args:
error_message: Error message to set on failed jobs.
Returns:
Number of jobs marked as failed.
"""
...
class PreferencesRepository(Protocol):
"""Repository protocol for user preferences operations.
Stores key-value user preferences for persistence across server restarts.
"""
async def get(self, key: str) -> object | None:
"""Get preference value by key.
Args:
key: Preference key.
Returns:
Preference value or None if not found.
"""
...
async def set(self, key: str, value: object) -> None:
"""Set preference value.
Args:
key: Preference key.
value: Preference value.
"""
...
async def delete(self, key: str) -> bool:
"""Delete a preference.
Args:
key: Preference key.
Returns:
True if deleted, False if not found.
"""
...

View File

@@ -2,23 +2,30 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol, Self
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
if TYPE_CHECKING:
from .repositories import (
AnnotationRepository,
DiarizationJobRepository,
MeetingRepository,
PreferencesRepository,
SegmentRepository,
SummaryRepository,
)
@runtime_checkable
class UnitOfWork(Protocol):
"""Unit of Work protocol for managing transactions across repositories.
Provides transactional consistency when operating on multiple
aggregates. Use as a context manager for automatic commit/rollback.
Implementations may be backed by either a database (SqlAlchemyUnitOfWork)
or in-memory storage (MemoryUnitOfWork). The `supports_*` properties
indicate which features are available in the current implementation.
Example:
async with uow:
meeting = await uow.meetings.get(meeting_id)
@@ -26,10 +33,62 @@ class UnitOfWork(Protocol):
await uow.commit()
"""
annotations: AnnotationRepository
meetings: MeetingRepository
segments: SegmentRepository
summaries: SummaryRepository
# Core repositories (always available)
@property
def meetings(self) -> MeetingRepository:
"""Access the meetings repository."""
...
@property
def segments(self) -> SegmentRepository:
"""Access the segments repository."""
...
@property
def summaries(self) -> SummaryRepository:
"""Access the summaries repository."""
...
# Optional repositories (check supports_* before use)
@property
def annotations(self) -> AnnotationRepository:
"""Access the annotations repository."""
...
@property
def diarization_jobs(self) -> DiarizationJobRepository:
"""Access the diarization jobs repository."""
...
@property
def preferences(self) -> PreferencesRepository:
"""Access the preferences repository."""
...
# Feature flags for DB-only capabilities
@property
def supports_annotations(self) -> bool:
"""Check if annotation operations are supported.
Returns False for memory-only implementations.
"""
...
@property
def supports_diarization_jobs(self) -> bool:
"""Check if diarization job persistence is supported.
Returns False for memory-only implementations.
"""
...
@property
def supports_preferences(self) -> bool:
"""Check if user preferences persistence is supported.
Returns False for memory-only implementations.
"""
...
async def __aenter__(self) -> Self:
"""Enter the unit of work context.

View File

@@ -0,0 +1,5 @@
"""Domain utility functions."""
from noteflow.domain.utils.time import utc_now
__all__ = ["utc_now"]

View File

@@ -0,0 +1,21 @@
"""Time utilities for consistent timezone handling.
All datetime values in the domain should be UTC-aware to prevent:
- Naive/aware comparison errors
- Incorrect .timestamp() conversions
- Timezone drift between components
"""
from datetime import UTC, datetime
def utc_now() -> datetime:
"""Return current UTC datetime (timezone-aware).
Use this instead of datetime.now() throughout the codebase to ensure
consistent timezone-aware datetime handling.
Returns:
Current datetime in UTC timezone.
"""
return datetime.now(UTC)

View File

@@ -3,15 +3,21 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from uuid import uuid4
import grpc.aio
from noteflow.domain.entities import Annotation
from noteflow.domain.value_objects import AnnotationId, MeetingId
from noteflow.domain.value_objects import AnnotationId
from ..proto import noteflow_pb2
from .converters import annotation_to_proto, proto_to_annotation_type
from .converters import (
annotation_to_proto,
parse_annotation_id,
parse_meeting_id,
proto_to_annotation_type,
)
from .errors import abort_database_required, abort_invalid_argument, abort_not_found
if TYPE_CHECKING:
from .protocols import ServicerHost
@@ -30,27 +36,28 @@ class AnnotationMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Annotation:
"""Add an annotation to a meeting."""
if not self._use_database():
await context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"Annotations require database persistence",
async with self._create_repository_provider() as repo:
if not repo.supports_annotations:
await abort_database_required(context, "Annotations")
try:
meeting_id = parse_meeting_id(request.meeting_id)
except ValueError:
await abort_invalid_argument(context, "Invalid meeting_id")
annotation_type = proto_to_annotation_type(request.annotation_type)
annotation = Annotation(
id=AnnotationId(uuid4()),
meeting_id=meeting_id,
annotation_type=annotation_type,
text=request.text,
start_time=request.start_time,
end_time=request.end_time,
segment_ids=list(request.segment_ids),
)
annotation_type = proto_to_annotation_type(request.annotation_type)
annotation = Annotation(
id=AnnotationId(uuid4()),
meeting_id=MeetingId(UUID(request.meeting_id)),
annotation_type=annotation_type,
text=request.text,
start_time=request.start_time,
end_time=request.end_time,
segment_ids=list(request.segment_ids),
)
async with self._create_uow() as uow:
saved = await uow.annotations.add(annotation)
await uow.commit()
saved = await repo.annotations.add(annotation)
await repo.commit()
return annotation_to_proto(saved)
async def GetAnnotation(
@@ -59,19 +66,18 @@ class AnnotationMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Annotation:
"""Get an annotation by ID."""
if not self._use_database():
await context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"Annotations require database persistence",
)
async with self._create_repository_provider() as repo:
if not repo.supports_annotations:
await abort_database_required(context, "Annotations")
async with self._create_uow() as uow:
annotation = await uow.annotations.get(AnnotationId(UUID(request.annotation_id)))
try:
annotation_id = parse_annotation_id(request.annotation_id)
except ValueError:
await abort_invalid_argument(context, "Invalid annotation_id")
annotation = await repo.annotations.get(annotation_id)
if annotation is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Annotation {request.annotation_id} not found",
)
await abort_not_found(context, "Annotation", request.annotation_id)
return annotation_to_proto(annotation)
async def ListAnnotations(
@@ -80,23 +86,23 @@ class AnnotationMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.ListAnnotationsResponse:
"""List annotations for a meeting."""
if not self._use_database():
await context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"Annotations require database persistence",
)
async with self._create_repository_provider() as repo:
if not repo.supports_annotations:
await abort_database_required(context, "Annotations")
async with self._create_uow() as uow:
meeting_id = MeetingId(UUID(request.meeting_id))
try:
meeting_id = parse_meeting_id(request.meeting_id)
except ValueError:
await abort_invalid_argument(context, "Invalid meeting_id")
# Check if time range filter is specified
if request.start_time > 0 or request.end_time > 0:
annotations = await uow.annotations.get_by_time_range(
annotations = await repo.annotations.get_by_time_range(
meeting_id,
request.start_time,
request.end_time,
)
else:
annotations = await uow.annotations.get_by_meeting(meeting_id)
annotations = await repo.annotations.get_by_meeting(meeting_id)
return noteflow_pb2.ListAnnotationsResponse(
annotations=[annotation_to_proto(a) for a in annotations]
@@ -108,19 +114,18 @@ class AnnotationMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Annotation:
"""Update an existing annotation."""
if not self._use_database():
await context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"Annotations require database persistence",
)
async with self._create_repository_provider() as repo:
if not repo.supports_annotations:
await abort_database_required(context, "Annotations")
async with self._create_uow() as uow:
annotation = await uow.annotations.get(AnnotationId(UUID(request.annotation_id)))
try:
annotation_id = parse_annotation_id(request.annotation_id)
except ValueError:
await abort_invalid_argument(context, "Invalid annotation_id")
annotation = await repo.annotations.get(annotation_id)
if annotation is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Annotation {request.annotation_id} not found",
)
await abort_not_found(context, "Annotation", request.annotation_id)
# Update fields if provided
if request.annotation_type != noteflow_pb2.ANNOTATION_TYPE_UNSPECIFIED:
@@ -134,8 +139,8 @@ class AnnotationMixin:
if request.segment_ids:
annotation.segment_ids = list(request.segment_ids)
updated = await uow.annotations.update(annotation)
await uow.commit()
updated = await repo.annotations.update(annotation)
await repo.commit()
return annotation_to_proto(updated)
async def DeleteAnnotation(
@@ -144,18 +149,17 @@ class AnnotationMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.DeleteAnnotationResponse:
"""Delete an annotation."""
if not self._use_database():
await context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"Annotations require database persistence",
)
async with self._create_repository_provider() as repo:
if not repo.supports_annotations:
await abort_database_required(context, "Annotations")
async with self._create_uow() as uow:
success = await uow.annotations.delete(AnnotationId(UUID(request.annotation_id)))
try:
annotation_id = parse_annotation_id(request.annotation_id)
except ValueError:
await abort_invalid_argument(context, "Invalid annotation_id")
success = await repo.annotations.delete(annotation_id)
if success:
await uow.commit()
await repo.commit()
return noteflow_pb2.DeleteAnnotationResponse(success=True)
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Annotation {request.annotation_id} not found",
)
await abort_not_found(context, "Annotation", request.annotation_id)

View File

@@ -4,10 +4,11 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING
from uuid import UUID
from noteflow.application.services.export_service import ExportFormat
from noteflow.domain.entities import Annotation, Meeting, Segment, Summary
from noteflow.domain.value_objects import AnnotationType, MeetingId
from noteflow.domain.entities import Annotation, Meeting, Segment, Summary, WordTiming
from noteflow.domain.value_objects import AnnotationId, AnnotationType, MeetingId
from noteflow.infrastructure.converters import AsrConverter
from ..proto import noteflow_pb2
@@ -16,6 +17,93 @@ if TYPE_CHECKING:
from noteflow.infrastructure.asr.dto import AsrResult
# -----------------------------------------------------------------------------
# ID Parsing Helpers (eliminate 11+ duplications)
# -----------------------------------------------------------------------------
def parse_meeting_id(meeting_id_str: str) -> MeetingId:
"""Parse string to MeetingId.
Consolidates the repeated `MeetingId(UUID(request.meeting_id))` pattern.
Args:
meeting_id_str: Meeting ID as string (UUID format).
Returns:
MeetingId value object.
"""
return MeetingId(UUID(meeting_id_str))
def parse_annotation_id(annotation_id_str: str) -> AnnotationId:
"""Parse string to AnnotationId.
Args:
annotation_id_str: Annotation ID as string (UUID format).
Returns:
AnnotationId value object.
"""
return AnnotationId(UUID(annotation_id_str))
# -----------------------------------------------------------------------------
# Proto Construction Helpers (eliminate duplicate word/segment building)
# -----------------------------------------------------------------------------
def word_to_proto(word: WordTiming) -> noteflow_pb2.WordTiming:
"""Convert domain WordTiming to protobuf.
Consolidates the repeated WordTiming construction pattern.
Args:
word: Domain WordTiming entity.
Returns:
Protobuf WordTiming message.
"""
return noteflow_pb2.WordTiming(
word=word.word,
start_time=word.start_time,
end_time=word.end_time,
probability=word.probability,
)
def segment_to_final_segment_proto(segment: Segment) -> noteflow_pb2.FinalSegment:
"""Convert domain Segment to FinalSegment protobuf.
Consolidates the repeated FinalSegment construction pattern.
Args:
segment: Domain Segment entity.
Returns:
Protobuf FinalSegment message.
"""
words = [word_to_proto(w) for w in segment.words]
return noteflow_pb2.FinalSegment(
segment_id=segment.segment_id,
text=segment.text,
start_time=segment.start_time,
end_time=segment.end_time,
words=words,
language=segment.language,
language_confidence=segment.language_confidence,
avg_logprob=segment.avg_logprob,
no_speech_prob=segment.no_speech_prob,
speaker_id=segment.speaker_id or "",
speaker_confidence=segment.speaker_confidence,
)
# -----------------------------------------------------------------------------
# Main Converter Functions
# -----------------------------------------------------------------------------
def meeting_to_proto(
meeting: Meeting,
include_segments: bool = True,
@@ -24,31 +112,7 @@ def meeting_to_proto(
"""Convert domain Meeting to protobuf."""
segments = []
if include_segments:
for seg in meeting.segments:
words = [
noteflow_pb2.WordTiming(
word=w.word,
start_time=w.start_time,
end_time=w.end_time,
probability=w.probability,
)
for w in seg.words
]
segments.append(
noteflow_pb2.FinalSegment(
segment_id=seg.segment_id,
text=seg.text,
start_time=seg.start_time,
end_time=seg.end_time,
words=words,
language=seg.language,
language_confidence=seg.language_confidence,
avg_logprob=seg.avg_logprob,
no_speech_prob=seg.no_speech_prob,
speaker_id=seg.speaker_id or "",
speaker_confidence=seg.speaker_confidence,
)
)
segments = [segment_to_final_segment_proto(seg) for seg in meeting.segments]
summary = None
if include_summary and meeting.summary:
@@ -104,32 +168,10 @@ def segment_to_proto_update(
segment: Segment,
) -> noteflow_pb2.TranscriptUpdate:
"""Convert domain Segment to protobuf TranscriptUpdate."""
words = [
noteflow_pb2.WordTiming(
word=w.word,
start_time=w.start_time,
end_time=w.end_time,
probability=w.probability,
)
for w in segment.words
]
final_segment = noteflow_pb2.FinalSegment(
segment_id=segment.segment_id,
text=segment.text,
start_time=segment.start_time,
end_time=segment.end_time,
words=words,
language=segment.language,
language_confidence=segment.language_confidence,
avg_logprob=segment.avg_logprob,
no_speech_prob=segment.no_speech_prob,
speaker_id=segment.speaker_id or "",
speaker_confidence=segment.speaker_confidence,
)
return noteflow_pb2.TranscriptUpdate(
meeting_id=meeting_id,
update_type=noteflow_pb2.UPDATE_TYPE_FINAL,
segment=final_segment,
segment=segment_to_final_segment_proto(segment),
server_timestamp=time.time(),
)

View File

@@ -14,7 +14,7 @@ import numpy as np
from numpy.typing import NDArray
from noteflow.domain.entities import Segment
from noteflow.domain.value_objects import MeetingId, MeetingState
from noteflow.domain.value_objects import MeetingState
from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.diarization import SpeakerTurn, assign_speaker
from noteflow.infrastructure.persistence.repositories import (
@@ -23,13 +23,73 @@ from noteflow.infrastructure.persistence.repositories import (
)
from ..proto import noteflow_pb2
from .converters import parse_meeting_id
from .errors import abort_invalid_argument, abort_not_found
if TYPE_CHECKING:
from collections.abc import Sequence
from .protocols import ServicerHost
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Helper Functions (Eliminate Duplication)
# -----------------------------------------------------------------------------
def _create_diarization_error_response(
error_message: str,
status: int = noteflow_pb2.JOB_STATUS_FAILED,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
"""Create error response for RefineSpeakerDiarization.
Consolidates the 6+ duplicated response construction patterns.
Args:
error_message: Error message describing the failure.
status: Job status code (default: JOB_STATUS_FAILED).
Returns:
Populated RefineSpeakerDiarizationResponse with error state.
"""
response = noteflow_pb2.RefineSpeakerDiarizationResponse()
response.segments_updated = 0
response.speaker_ids[:] = []
response.error_message = error_message
response.job_id = ""
response.status = status
return response
def _apply_speaker_to_segment(
segment: Segment,
turns: Sequence[SpeakerTurn],
) -> bool:
"""Assign speaker to segment from diarization turns.
Consolidates the 3 duplicated speaker assignment patterns.
Args:
segment: Domain segment to update.
turns: Sequence of speaker turns from diarization.
Returns:
True if speaker was assigned, False otherwise.
"""
speaker_id, confidence = assign_speaker(
segment.start_time,
segment.end_time,
turns,
)
if speaker_id is None:
return False
segment.speaker_id = speaker_id
segment.speaker_confidence = confidence
return True
class DiarizationMixin:
"""Mixin providing speaker diarization functionality.
@@ -46,6 +106,9 @@ class DiarizationMixin:
) -> None:
"""Process an audio chunk for streaming diarization (best-effort).
Uses per-meeting sessions to enable concurrent diarization without
race conditions. Each meeting has its own pipeline state.
Offloads heavy ML inference to thread pool to avoid blocking the event loop.
"""
if self._diarization_engine is None:
@@ -57,13 +120,21 @@ class DiarizationMixin:
loop = asyncio.get_running_loop()
# Get or create per-meeting session under lock
async with self._diarization_lock:
if not self._diarization_engine.is_streaming_loaded:
session = self._diarization_sessions.get(meeting_id)
if session is None:
try:
await loop.run_in_executor(
session = await loop.run_in_executor(
None,
self._diarization_engine.load_streaming_model,
self._diarization_engine.create_streaming_session,
meeting_id,
)
prior_turns = self._diarization_turns.get(meeting_id, [])
prior_stream_time = self._diarization_stream_time.get(meeting_id, 0.0)
if prior_turns or prior_stream_time:
session.restore(prior_turns, stream_time=prior_stream_time)
self._diarization_sessions[meeting_id] = session
except (RuntimeError, ValueError) as exc:
logger.warning(
"Streaming diarization disabled for meeting %s: %s",
@@ -73,56 +144,48 @@ class DiarizationMixin:
self._diarization_streaming_failed.add(meeting_id)
return
stream_time = self._diarization_stream_time.get(meeting_id, 0.0)
duration = len(audio) / self.DEFAULT_SAMPLE_RATE
try:
turns = await loop.run_in_executor(
None,
partial(
self._diarization_engine.process_chunk,
audio,
sample_rate=self.DEFAULT_SAMPLE_RATE,
),
)
except Exception as exc:
logger.warning(
"Streaming diarization failed for meeting %s: %s",
meeting_id,
exc,
)
self._diarization_streaming_failed.add(meeting_id)
return
diarization_turns = self._diarization_turns.setdefault(meeting_id, [])
adjusted_turns: list[SpeakerTurn] = []
for turn in turns:
adjusted = SpeakerTurn(
speaker=turn.speaker,
start=turn.start + stream_time,
end=turn.end + stream_time,
confidence=turn.confidence,
# Process chunk in thread pool (outside lock for parallelism)
try:
new_turns = await loop.run_in_executor(
None,
partial(
session.process_chunk,
audio,
sample_rate=self.DEFAULT_SAMPLE_RATE,
),
)
diarization_turns.append(adjusted)
adjusted_turns.append(adjusted)
except Exception as exc:
logger.warning(
"Streaming diarization failed for meeting %s: %s",
meeting_id,
exc,
)
self._diarization_streaming_failed.add(meeting_id)
return
self._diarization_stream_time[meeting_id] = stream_time + duration
# Populate _diarization_turns for compatibility with _maybe_assign_speaker
if new_turns:
diarization_turns = self._diarization_turns.setdefault(meeting_id, [])
diarization_turns.extend(new_turns)
# Persist turns immediately for crash resilience
if adjusted_turns and self._use_database():
# Update stream time for legacy compatibility
self._diarization_stream_time[meeting_id] = session.stream_time
# Persist turns immediately for crash resilience (DB only)
try:
async with self._create_uow() as uow:
repo_turns = [
StreamingTurn(
speaker=t.speaker,
start_time=t.start,
end_time=t.end,
confidence=t.confidence,
)
for t in adjusted_turns
]
await uow.diarization_jobs.add_streaming_turns(meeting_id, repo_turns)
await uow.commit()
async with self._create_repository_provider() as repo:
if repo.supports_diarization_jobs:
repo_turns = [
StreamingTurn(
speaker=t.speaker,
start_time=t.start,
end_time=t.end,
confidence=t.confidence,
)
for t in new_turns
]
await repo.diarization_jobs.add_streaming_turns(meeting_id, repo_turns)
await repo.commit()
except Exception:
logger.exception("Failed to persist streaming turns for %s", meeting_id)
@@ -136,21 +199,11 @@ class DiarizationMixin:
return
if meeting_id in self._diarization_streaming_failed:
return
turns = self._diarization_turns.get(meeting_id)
if not turns:
if turns := self._diarization_turns.get(meeting_id):
_apply_speaker_to_segment(segment, turns)
else:
return
speaker_id, confidence = assign_speaker(
segment.start_time,
segment.end_time,
turns,
)
if speaker_id is None:
return
segment.speaker_id = speaker_id
segment.speaker_confidence = confidence
async def _prune_diarization_jobs(self: ServicerHost) -> None:
"""Remove completed diarization jobs older than retention window.
@@ -168,24 +221,25 @@ class DiarizationMixin:
noteflow_pb2.JOB_STATUS_FAILED,
}
# Prune old completed jobs from database
if self._use_database():
async with self._create_uow() as uow:
pruned = await uow.diarization_jobs.prune_completed(
# Prune old completed jobs from database or in-memory store
async with self._create_repository_provider() as repo:
if repo.supports_diarization_jobs:
pruned = await repo.diarization_jobs.prune_completed(
self.DIARIZATION_JOB_TTL_SECONDS
)
await uow.commit()
await repo.commit()
if pruned > 0:
logger.debug("Pruned %d completed diarization jobs", pruned)
else:
cutoff = datetime.now() - timedelta(seconds=self.DIARIZATION_JOB_TTL_SECONDS)
expired = [
job_id
for job_id, job in self._diarization_jobs.items()
if job.status in terminal_statuses and job.updated_at < cutoff
]
for job_id in expired:
self._diarization_jobs.pop(job_id, None)
else:
# In-memory fallback: prune from local dict
cutoff = datetime.now() - timedelta(seconds=self.DIARIZATION_JOB_TTL_SECONDS)
expired = [
job_id
for job_id, job in self._diarization_jobs.items()
if job.status in terminal_statuses and job.updated_at < cutoff
]
for job_id in expired:
self._diarization_jobs.pop(job_id, None)
async def RefineSpeakerDiarization(
self: ServicerHost,
@@ -195,86 +249,52 @@ class DiarizationMixin:
"""Run post-meeting speaker diarization refinement.
Load the full meeting audio, run offline diarization, and update
segment speaker assignments. Job state is persisted to database.
segment speaker assignments. Job state is persisted when DB available.
"""
await self._prune_diarization_jobs()
if not self._diarization_refinement_enabled:
response = noteflow_pb2.RefineSpeakerDiarizationResponse()
response.segments_updated = 0
response.speaker_ids[:] = []
response.error_message = "Diarization refinement disabled on server"
response.job_id = ""
response.status = noteflow_pb2.JOB_STATUS_FAILED
return response
return _create_diarization_error_response("Diarization refinement disabled on server")
if self._diarization_engine is None:
response = noteflow_pb2.RefineSpeakerDiarizationResponse()
response.segments_updated = 0
response.speaker_ids[:] = []
response.error_message = "Diarization not enabled on server"
response.job_id = ""
response.status = noteflow_pb2.JOB_STATUS_FAILED
return response
return _create_diarization_error_response("Diarization not enabled on server")
try:
meeting_uuid = UUID(request.meeting_id)
UUID(request.meeting_id) # Validate UUID format
except ValueError:
response = noteflow_pb2.RefineSpeakerDiarizationResponse()
response.segments_updated = 0
response.speaker_ids[:] = []
response.error_message = "Invalid meeting_id"
response.job_id = ""
response.status = noteflow_pb2.JOB_STATUS_FAILED
return response
return _create_diarization_error_response("Invalid meeting_id")
if self._use_database():
async with self._create_uow() as uow:
meeting = await uow.meetings.get(MeetingId(meeting_uuid))
else:
store = self._get_memory_store()
meeting = store.get(request.meeting_id)
if meeting is None:
response = noteflow_pb2.RefineSpeakerDiarizationResponse()
response.segments_updated = 0
response.speaker_ids[:] = []
response.error_message = "Meeting not found"
response.job_id = ""
response.status = noteflow_pb2.JOB_STATUS_FAILED
return response
meeting_state = meeting.state
if meeting_state in (
MeetingState.UNSPECIFIED,
MeetingState.CREATED,
MeetingState.RECORDING,
MeetingState.STOPPING,
):
response = noteflow_pb2.RefineSpeakerDiarizationResponse()
response.segments_updated = 0
response.speaker_ids[:] = []
response.error_message = (
f"Meeting must be stopped before refinement (state: {meeting_state.name.lower()})"
async with self._create_repository_provider() as repo:
meeting = await repo.meetings.get(parse_meeting_id(request.meeting_id))
if meeting is None:
return _create_diarization_error_response("Meeting not found")
meeting_state = meeting.state
if meeting_state in (
MeetingState.UNSPECIFIED,
MeetingState.CREATED,
MeetingState.RECORDING,
MeetingState.STOPPING,
):
return _create_diarization_error_response(
f"Meeting must be stopped before refinement (state: {meeting_state.name.lower()})"
)
num_speakers = request.num_speakers if request.num_speakers > 0 else None
job_id = str(uuid4())
job = DiarizationJob(
job_id=job_id,
meeting_id=request.meeting_id,
status=noteflow_pb2.JOB_STATUS_QUEUED,
)
response.job_id = ""
response.status = noteflow_pb2.JOB_STATUS_FAILED
return response
num_speakers = request.num_speakers if request.num_speakers > 0 else None
job_id = str(uuid4())
job = DiarizationJob(
job_id=job_id,
meeting_id=request.meeting_id,
status=noteflow_pb2.JOB_STATUS_QUEUED,
)
# Persist job to database
if self._use_database():
async with self._create_uow() as uow:
await uow.diarization_jobs.create(job)
await uow.commit()
else:
self._diarization_jobs[job_id] = job
# Persist job if DB supports it, otherwise use in-memory dict
if repo.supports_diarization_jobs:
await repo.diarization_jobs.create(job)
await repo.commit()
else:
self._diarization_jobs[job_id] = job
# Create background task and store reference for potential cancellation
task = asyncio.create_task(self._run_diarization_job(job_id, num_speakers))
@@ -295,32 +315,31 @@ class DiarizationMixin:
) -> None:
"""Run background diarization job.
Updates job status in database as the job progresses.
Updates job status in repository as the job progresses.
"""
# Get meeting_id from database
# Get meeting_id and update status to RUNNING
meeting_id: str | None = None
job: DiarizationJob | None = None
if self._use_database():
async with self._create_uow() as uow:
job = await uow.diarization_jobs.get(job_id)
async with self._create_repository_provider() as repo:
if repo.supports_diarization_jobs:
job = await repo.diarization_jobs.get(job_id)
if job is None:
logger.warning("Diarization job %s not found in database", job_id)
return
meeting_id = job.meeting_id
# Update status to RUNNING
await uow.diarization_jobs.update_status(
await repo.diarization_jobs.update_status(
job_id,
noteflow_pb2.JOB_STATUS_RUNNING,
)
await uow.commit()
else:
job = self._diarization_jobs.get(job_id)
if job is None:
logger.warning("Diarization job %s not found in memory", job_id)
return
meeting_id = job.meeting_id
job.status = noteflow_pb2.JOB_STATUS_RUNNING
job.updated_at = datetime.now()
await repo.commit()
else:
job = self._diarization_jobs.get(job_id)
if job is None:
logger.warning("Diarization job %s not found in memory", job_id)
return
meeting_id = job.meeting_id
job.status = noteflow_pb2.JOB_STATUS_RUNNING
job.updated_at = datetime.now()
try:
updated_count = await self.refine_speaker_diarization(
@@ -330,17 +349,16 @@ class DiarizationMixin:
speaker_ids = await self._collect_speaker_ids(meeting_id)
# Update status to COMPLETED
if self._use_database():
async with self._create_uow() as uow:
await uow.diarization_jobs.update_status(
async with self._create_repository_provider() as repo:
if repo.supports_diarization_jobs:
await repo.diarization_jobs.update_status(
job_id,
noteflow_pb2.JOB_STATUS_COMPLETED,
segments_updated=updated_count,
speaker_ids=speaker_ids,
)
await uow.commit()
else:
if job is not None:
await repo.commit()
elif job is not None:
job.status = noteflow_pb2.JOB_STATUS_COMPLETED
job.segments_updated = updated_count
job.speaker_ids = speaker_ids
@@ -349,16 +367,15 @@ class DiarizationMixin:
except Exception as exc:
logger.exception("Diarization failed for meeting %s", meeting_id)
# Update status to FAILED
if self._use_database():
async with self._create_uow() as uow:
await uow.diarization_jobs.update_status(
async with self._create_repository_provider() as repo:
if repo.supports_diarization_jobs:
await repo.diarization_jobs.update_status(
job_id,
noteflow_pb2.JOB_STATUS_FAILED,
error_message=str(exc),
)
await uow.commit()
else:
if job is not None:
await repo.commit()
elif job is not None:
job.status = noteflow_pb2.JOB_STATUS_FAILED
job.error_message = str(exc)
job.updated_at = datetime.now()
@@ -452,53 +469,39 @@ class DiarizationMixin:
"""Apply diarization turns to segments and return updated count."""
updated_count = 0
if self._use_database():
async with self._create_uow() as uow:
segments = await uow.segments.get_by_meeting(MeetingId(UUID(meeting_id)))
for segment in segments:
if segment.db_id is None:
continue
speaker_id, confidence = assign_speaker(
segment.start_time,
segment.end_time,
turns,
)
if speaker_id is None:
continue
await uow.segments.update_speaker(
segment.db_id,
speaker_id,
confidence,
)
updated_count += 1
await uow.commit()
else:
store = self._get_memory_store()
if meeting := store.get(meeting_id):
for segment in meeting.segments:
speaker_id, confidence = assign_speaker(
segment.start_time,
segment.end_time,
turns,
)
if speaker_id is None:
continue
segment.speaker_id = speaker_id
segment.speaker_confidence = confidence
async with self._create_repository_provider() as repo:
try:
parsed_meeting_id = parse_meeting_id(meeting_id)
except ValueError:
logger.warning("Invalid meeting_id %s while applying diarization turns", meeting_id)
return 0
segments = await repo.segments.get_by_meeting(parsed_meeting_id)
for segment in segments:
if _apply_speaker_to_segment(segment, turns):
# For DB segments with db_id, use update_speaker
if segment.db_id is not None:
await repo.segments.update_speaker(
segment.db_id,
segment.speaker_id,
segment.speaker_confidence,
)
updated_count += 1
await repo.commit()
return updated_count
async def _collect_speaker_ids(self: ServicerHost, meeting_id: str) -> list[str]:
"""Collect distinct speaker IDs for a meeting."""
if self._use_database():
async with self._create_uow() as uow:
segments = await uow.segments.get_by_meeting(MeetingId(UUID(meeting_id)))
return sorted({s.speaker_id for s in segments if s.speaker_id})
store = self._get_memory_store()
if meeting := store.get(meeting_id):
return sorted({s.speaker_id for s in meeting.segments if s.speaker_id})
return []
async with self._create_repository_provider() as repo:
try:
parsed_meeting_id = parse_meeting_id(meeting_id)
except ValueError:
logger.warning("Invalid meeting_id %s while collecting speaker ids", meeting_id)
return []
segments = await repo.segments.get_by_meeting(parsed_meeting_id)
return sorted({s.speaker_id for s in segments if s.speaker_id})
async def RenameSpeaker(
self: ServicerHost,
@@ -511,42 +514,35 @@ class DiarizationMixin:
to use new_speaker_name instead.
"""
if not request.old_speaker_id or not request.new_speaker_name:
await context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"old_speaker_id and new_speaker_name are required",
await abort_invalid_argument(
context, "old_speaker_id and new_speaker_name are required"
)
try:
meeting_uuid = UUID(request.meeting_id)
meeting_id = parse_meeting_id(request.meeting_id)
except ValueError:
await context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"Invalid meeting_id",
)
await abort_invalid_argument(context, "Invalid meeting_id")
updated_count = 0
if self._use_database():
async with self._create_uow() as uow:
segments = await uow.segments.get_by_meeting(MeetingId(meeting_uuid))
async with self._create_repository_provider() as repo:
segments = await repo.segments.get_by_meeting(meeting_id)
for segment in segments:
if segment.speaker_id == request.old_speaker_id and segment.db_id:
await uow.segments.update_speaker(
for segment in segments:
if segment.speaker_id == request.old_speaker_id:
# For DB segments with db_id, use update_speaker
if segment.db_id is not None:
await repo.segments.update_speaker(
segment.db_id,
request.new_speaker_name,
segment.speaker_confidence,
)
updated_count += 1
await uow.commit()
else:
store = self._get_memory_store()
if meeting := store.get(request.meeting_id):
for segment in meeting.segments:
if segment.speaker_id == request.old_speaker_id:
else:
# Memory segments: update directly
segment.speaker_id = request.new_speaker_name
updated_count += 1
updated_count += 1
await repo.commit()
return noteflow_pb2.RenameSpeakerResponse(
segments_updated=updated_count,
@@ -560,35 +556,23 @@ class DiarizationMixin:
) -> noteflow_pb2.DiarizationJobStatus:
"""Return current status for a diarization job.
Queries job state from database for persistence across restarts.
Queries job state from repository for persistence across restarts.
"""
await self._prune_diarization_jobs()
if self._use_database():
async with self._create_uow() as uow:
job = await uow.diarization_jobs.get(request.job_id)
if job is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
"Diarization job not found",
)
return noteflow_pb2.DiarizationJobStatus(
job_id=job.job_id,
status=job.status,
segments_updated=job.segments_updated,
speaker_ids=job.speaker_ids,
error_message=job.error_message,
)
job = self._diarization_jobs.get(request.job_id)
if job is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
"Diarization job not found",
async with self._create_repository_provider() as repo:
if repo.supports_diarization_jobs:
job = await repo.diarization_jobs.get(request.job_id)
else:
job = self._diarization_jobs.get(request.job_id)
if job is None:
await abort_not_found(context, "Diarization job", request.job_id)
return noteflow_pb2.DiarizationJobStatus(
job_id=job.job_id,
status=job.status,
segments_updated=job.segments_updated,
speaker_ids=job.speaker_ids,
error_message=job.error_message,
)
return noteflow_pb2.DiarizationJobStatus(
job_id=job.job_id,
status=job.status,
segments_updated=job.segments_updated,
speaker_ids=job.speaker_ids,
error_message=job.error_message,
)
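
The per-meeting session handling above follows a lazy, lock-guarded get-or-create pattern: the first chunk for a meeting builds the session in a thread pool, later chunks reuse it, and the lock prevents two concurrent chunks from building two sessions. Below is a minimal standalone sketch of that pattern; `SessionRegistry` and the `Session` stub are illustrative names, not part of the codebase.

```python
# Minimal sketch of the lazy, lock-guarded per-meeting session pattern used
# by _process_streaming_diarization above. Session/SessionRegistry are
# illustrative stand-ins, not the project's real API.
import asyncio


class Session:
    """Stand-in for a per-meeting streaming session (expensive to build)."""

    def __init__(self, meeting_id: str) -> None:
        self.meeting_id = meeting_id


class SessionRegistry:
    def __init__(self) -> None:
        self._sessions: dict[str, Session] = {}
        self._lock = asyncio.Lock()

    async def get_or_create(self, meeting_id: str) -> Session:
        # Creation happens under the lock so two chunks arriving at the same
        # time for the same meeting cannot build two sessions.
        async with self._lock:
            session = self._sessions.get(meeting_id)
            if session is None:
                loop = asyncio.get_running_loop()
                # Expensive construction (model load) runs in a thread pool so
                # the event loop keeps serving other meetings.
                session = await loop.run_in_executor(None, Session, meeting_id)
                self._sessions[meeting_id] = session
            return session


async def main() -> None:
    registry = SessionRegistry()
    first, second = await asyncio.gather(
        registry.get_or_create("meeting-1"),
        registry.get_or_create("meeting-1"),
    )
    assert first is second  # the same session is reused for the same meeting


asyncio.run(main())
```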


@@ -0,0 +1,116 @@
"""Error response helpers for gRPC service mixins.
Consolidates the 34+ duplicated `await context.abort()` patterns
into reusable helper functions.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, NoReturn
import grpc
if TYPE_CHECKING:
import grpc.aio
async def abort_not_found(
context: grpc.aio.ServicerContext,
entity_type: str,
entity_id: str,
) -> NoReturn:
"""Abort with NOT_FOUND status.
Consolidates the repeated "entity not found" error pattern.
Args:
context: gRPC servicer context.
entity_type: Type of entity (e.g., "Meeting", "Annotation").
entity_id: ID of the missing entity.
Raises:
grpc.RpcError: Always raises NOT_FOUND.
"""
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"{entity_type} {entity_id} not found",
)
# This line is unreachable but helps type checkers
raise AssertionError("Unreachable")
async def abort_database_required(
context: grpc.aio.ServicerContext,
feature: str,
) -> NoReturn:
"""Abort with UNIMPLEMENTED for DB-only features.
Consolidates the repeated "requires database persistence" pattern.
Args:
context: gRPC servicer context.
feature: Feature that requires database (e.g., "Annotations").
Raises:
grpc.RpcError: Always raises UNIMPLEMENTED.
"""
await context.abort(
grpc.StatusCode.UNIMPLEMENTED,
f"{feature} require database persistence",
)
raise AssertionError("Unreachable")
async def abort_invalid_argument(
context: grpc.aio.ServicerContext,
message: str,
) -> NoReturn:
"""Abort with INVALID_ARGUMENT status.
Args:
context: gRPC servicer context.
message: Error message describing the invalid argument.
Raises:
grpc.RpcError: Always raises INVALID_ARGUMENT.
"""
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, message)
raise AssertionError("Unreachable")
async def abort_failed_precondition(
context: grpc.aio.ServicerContext,
message: str,
) -> NoReturn:
"""Abort with FAILED_PRECONDITION status.
Use when operation cannot proceed due to system state.
Args:
context: gRPC servicer context.
message: Error message describing the precondition failure.
Raises:
grpc.RpcError: Always raises FAILED_PRECONDITION.
"""
await context.abort(grpc.StatusCode.FAILED_PRECONDITION, message)
raise AssertionError("Unreachable")
async def abort_internal(
context: grpc.aio.ServicerContext,
message: str,
) -> NoReturn:
"""Abort with INTERNAL status.
Use for unexpected server errors.
Args:
context: gRPC servicer context.
message: Error message describing the internal error.
Raises:
grpc.RpcError: Always raises INTERNAL.
"""
await context.abort(grpc.StatusCode.INTERNAL, message)
raise AssertionError("Unreachable")
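
For reference, a hypothetical servicer method using these helpers; the RPC shape, request fields, and `_store` lookup are invented for illustration, and the import path is assumed from the relative `from .errors import ...` imports used by the other mixins.

```python
import grpc.aio

from noteflow.grpc._mixins.errors import (  # module path assumed from the mixins' relative imports
    abort_invalid_argument,
    abort_not_found,
)


class AnnotationMixin:
    """Hypothetical mixin; the dict store and RPC shape are illustrative only."""

    _store: dict[str, object] = {}

    async def GetAnnotation(self, request, context: grpc.aio.ServicerContext) -> object:
        if not request.annotation_id:
            # abort_* raises grpc.RpcError and never returns, so no explicit
            # "return" is needed after the call.
            await abort_invalid_argument(context, "annotation_id is required")
        annotation = self._store.get(request.annotation_id)
        if annotation is None:
            await abort_not_found(context, "Annotation", request.annotation_id)
        return annotation
```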


@@ -3,15 +3,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID
import grpc.aio
from noteflow.application.services.export_service import ExportFormat, ExportService
from noteflow.domain.value_objects import MeetingId
from ..proto import noteflow_pb2
from .converters import proto_to_export_format
from .converters import parse_meeting_id, proto_to_export_format
from .errors import abort_invalid_argument, abort_not_found
if TYPE_CHECKING:
from .protocols import ServicerHost
@@ -21,7 +20,7 @@ class ExportMixin:
"""Mixin providing export functionality.
Requires host to implement ServicerHost protocol.
Export requires database persistence.
Works with both database and memory backends via RepositoryProvider.
"""
async def ExportTranscript(
@@ -30,19 +29,19 @@ class ExportMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.ExportTranscriptResponse:
"""Export meeting transcript to specified format."""
if not self._use_database():
await context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"Export requires database persistence",
)
# Map proto format to ExportFormat
fmt = proto_to_export_format(request.format)
export_service = ExportService(self._create_uow())
# Use unified repository provider - works with both DB and memory
try:
meeting_id = parse_meeting_id(request.meeting_id)
except ValueError:
await abort_invalid_argument(context, "Invalid meeting_id")
export_service = ExportService(self._create_repository_provider())
try:
content = await export_service.export_transcript(
MeetingId(UUID(request.meeting_id)),
meeting_id,
fmt,
)
exporter_info = export_service.get_supported_formats()
@@ -61,8 +60,5 @@ class ExportMixin:
format_name=fmt_name,
file_extension=fmt_ext,
)
except ValueError as e:
await context.abort(
grpc.StatusCode.NOT_FOUND,
str(e),
)
except ValueError:
await abort_not_found(context, "Meeting", request.meeting_id)
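
The mixins above repeatedly call `parse_meeting_id()` from `converters` and translate its `ValueError` into INVALID_ARGUMENT. Its body is not part of this diff; a plausible minimal implementation, given that `MeetingId` wraps a `UUID` elsewhere in the code, would be:

```python
from uuid import UUID

from noteflow.domain.value_objects import MeetingId


def parse_meeting_id(raw: str) -> MeetingId:
    """Parse a raw meeting id string; UUID() raises ValueError on bad input."""
    return MeetingId(UUID(raw))


# Callers translate the ValueError into a gRPC INVALID_ARGUMENT abort:
try:
    parse_meeting_id("not-a-uuid")
except ValueError:
    print("would call abort_invalid_argument(context, 'Invalid meeting_id')")
```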


@@ -2,25 +2,30 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from uuid import UUID
import grpc.aio
from noteflow.domain.entities import Meeting
from noteflow.domain.value_objects import MeetingId, MeetingState
from noteflow.domain.value_objects import MeetingState
from ..proto import noteflow_pb2
from .converters import meeting_to_proto
from .converters import meeting_to_proto, parse_meeting_id
from .errors import abort_invalid_argument, abort_not_found
if TYPE_CHECKING:
from .protocols import ServicerHost
# Timeout for waiting for stream to exit gracefully
STOP_WAIT_TIMEOUT_SECONDS: float = 2.0
class MeetingMixin:
"""Mixin providing meeting CRUD functionality.
Requires host to implement ServicerHost protocol.
Works with both database and memory backends via RepositoryProvider.
"""
async def CreateMeeting(
@@ -31,63 +36,61 @@ class MeetingMixin:
"""Create a new meeting."""
metadata = dict(request.metadata) if request.metadata else {}
if self._use_database():
async with self._create_uow() as uow:
meeting = Meeting.create(title=request.title, metadata=metadata)
saved = await uow.meetings.create(meeting)
await uow.commit()
return meeting_to_proto(saved)
else:
store = self._get_memory_store()
meeting = store.create(title=request.title, metadata=metadata)
return meeting_to_proto(meeting)
async with self._create_repository_provider() as repo:
meeting = Meeting.create(title=request.title, metadata=metadata)
saved = await repo.meetings.create(meeting)
await repo.commit()
return meeting_to_proto(saved)
async def StopMeeting(
self: ServicerHost,
request: noteflow_pb2.StopMeetingRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Meeting:
"""Stop a meeting using graceful STOPPING -> STOPPED transition."""
"""Stop a meeting using graceful STOPPING -> STOPPED transition.
If the meeting is actively streaming, signals the stream to stop
and waits briefly for it to exit before closing resources.
"""
meeting_id = request.meeting_id
# Close audio writer if open
# Signal stop to active stream and wait for graceful exit
if meeting_id in self._active_streams:
self._stop_requested.add(meeting_id)
# Wait briefly for stream to detect stop request and exit
wait_iterations = int(STOP_WAIT_TIMEOUT_SECONDS * 10) # 100ms intervals
for _ in range(wait_iterations):
if meeting_id not in self._active_streams:
break
await asyncio.sleep(0.1)
# Clean up stop request even if stream didn't exit
self._stop_requested.discard(meeting_id)
# Close audio writer if open (stream cleanup may have done this)
if meeting_id in self._audio_writers:
self._close_audio_writer(meeting_id)
if self._use_database():
async with self._create_uow() as uow:
meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
if meeting is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {meeting_id} not found",
)
try:
# Graceful shutdown: RECORDING -> STOPPING -> STOPPED
meeting.begin_stopping()
meeting.stop_recording()
except ValueError as e:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
await uow.meetings.update(meeting)
# Clean up streaming diarization turns (no longer needed)
await uow.diarization_jobs.clear_streaming_turns(meeting_id)
await uow.commit()
return meeting_to_proto(meeting)
store = self._get_memory_store()
meeting = store.get(meeting_id)
if meeting is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {meeting_id} not found",
)
try:
# Graceful shutdown: RECORDING -> STOPPING -> STOPPED
meeting.begin_stopping()
meeting.stop_recording()
except ValueError as e:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
store.update(meeting)
return meeting_to_proto(meeting)
async with self._create_repository_provider() as repo:
try:
parsed_meeting_id = parse_meeting_id(meeting_id)
except ValueError:
await abort_invalid_argument(context, "Invalid meeting_id")
meeting = await repo.meetings.get(parsed_meeting_id)
if meeting is None:
await abort_not_found(context, "Meeting", meeting_id)
try:
# Graceful shutdown: RECORDING -> STOPPING -> STOPPED
meeting.begin_stopping()
meeting.stop_recording()
except ValueError as e:
await abort_invalid_argument(context, str(e))
await repo.meetings.update(meeting)
# Clean up streaming diarization turns if DB supports it
if repo.supports_diarization_jobs:
await repo.diarization_jobs.clear_streaming_turns(meeting_id)
await repo.commit()
return meeting_to_proto(meeting)
async def ListMeetings(
self: ServicerHost,
@@ -98,24 +101,10 @@ class MeetingMixin:
limit = request.limit or 100
offset = request.offset or 0
sort_desc = request.sort_order != noteflow_pb2.SORT_ORDER_CREATED_ASC
states = [MeetingState(s) for s in request.states] if request.states else None
if self._use_database():
states = [MeetingState(s) for s in request.states] if request.states else None
async with self._create_uow() as uow:
meetings, total = await uow.meetings.list_all(
states=states,
limit=limit,
offset=offset,
sort_desc=sort_desc,
)
return noteflow_pb2.ListMeetingsResponse(
meetings=[meeting_to_proto(m, include_segments=False) for m in meetings],
total_count=total,
)
else:
store = self._get_memory_store()
states = [MeetingState(s) for s in request.states] if request.states else None
meetings, total = store.list_all(
async with self._create_repository_provider() as repo:
meetings, total = await repo.meetings.list_all(
states=states,
limit=limit,
offset=offset,
@@ -132,39 +121,28 @@ class MeetingMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Meeting:
"""Get meeting details."""
if self._use_database():
async with self._create_uow() as uow:
meeting = await uow.meetings.get(MeetingId(UUID(request.meeting_id)))
if meeting is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {request.meeting_id} not found",
)
# Load segments if requested
if request.include_segments:
segments = await uow.segments.get_by_meeting(meeting.id)
meeting.segments = list(segments)
# Load summary if requested
if request.include_summary:
summary = await uow.summaries.get_by_meeting(meeting.id)
meeting.summary = summary
return meeting_to_proto(
meeting,
include_segments=request.include_segments,
include_summary=request.include_summary,
)
store = self._get_memory_store()
meeting = store.get(request.meeting_id)
if meeting is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {request.meeting_id} not found",
async with self._create_repository_provider() as repo:
try:
meeting_id = parse_meeting_id(request.meeting_id)
except ValueError:
await abort_invalid_argument(context, "Invalid meeting_id")
meeting = await repo.meetings.get(meeting_id)
if meeting is None:
await abort_not_found(context, "Meeting", request.meeting_id)
# Load segments if requested
if request.include_segments:
segments = await repo.segments.get_by_meeting(meeting.id)
meeting.segments = list(segments)
# Load summary if requested
if request.include_summary:
summary = await repo.summaries.get_by_meeting(meeting.id)
meeting.summary = summary
return meeting_to_proto(
meeting,
include_segments=request.include_segments,
include_summary=request.include_summary,
)
return meeting_to_proto(
meeting,
include_segments=request.include_segments,
include_summary=request.include_summary,
)
async def DeleteMeeting(
self: ServicerHost,
@@ -172,21 +150,14 @@ class MeetingMixin:
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.DeleteMeetingResponse:
"""Delete a meeting."""
if self._use_database():
async with self._create_uow() as uow:
success = await uow.meetings.delete(MeetingId(UUID(request.meeting_id)))
if success:
await uow.commit()
return noteflow_pb2.DeleteMeetingResponse(success=True)
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {request.meeting_id} not found",
)
store = self._get_memory_store()
success = store.delete(request.meeting_id)
if not success:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {request.meeting_id} not found",
)
return noteflow_pb2.DeleteMeetingResponse(success=True)
async with self._create_repository_provider() as repo:
try:
meeting_id = parse_meeting_id(request.meeting_id)
except ValueError:
await abort_invalid_argument(context, "Invalid meeting_id")
success = await repo.meetings.delete(meeting_id)
if success:
await repo.commit()
return noteflow_pb2.DeleteMeetingResponse(success=True)
await abort_not_found(context, "Meeting", request.meeting_id)
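
The StopMeeting handshake above relies on two shared sets: the stopper adds the meeting id to a stop-request set and polls, while the streaming loop checks the flag and removes itself from the active set on exit. A self-contained sketch of that handshake, with module-level sets standing in for the servicer state:

```python
# Standalone sketch of the stop handshake used by StopMeeting and the
# streaming loop above. Names are illustrative, not the servicer's fields.
import asyncio

STOP_WAIT_TIMEOUT_SECONDS = 2.0

active_streams: set[str] = set()
stop_requested: set[str] = set()


async def stream_loop(meeting_id: str) -> None:
    active_streams.add(meeting_id)
    try:
        while meeting_id not in stop_requested:  # streaming side checks the flag
            await asyncio.sleep(0.05)            # stand-in for processing a chunk
    finally:
        active_streams.discard(meeting_id)


async def stop_meeting(meeting_id: str) -> None:
    stop_requested.add(meeting_id)
    for _ in range(int(STOP_WAIT_TIMEOUT_SECONDS * 10)):  # poll at 100 ms intervals
        if meeting_id not in active_streams:
            break
        await asyncio.sleep(0.1)
    stop_requested.discard(meeting_id)  # clean up even if the stream lagged


async def main() -> None:
    stream = asyncio.create_task(stream_loop("m1"))
    await asyncio.sleep(0.2)
    await stop_meeting("m1")
    await stream  # exits promptly once the stop flag is observed


asyncio.run(main())
```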


@@ -13,9 +13,14 @@ if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from noteflow.domain.entities import Meeting
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.infrastructure.asr import FasterWhisperEngine, Segmenter, StreamingVad
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
from noteflow.infrastructure.diarization import DiarizationEngine, SpeakerTurn
from noteflow.infrastructure.diarization import (
DiarizationEngine,
DiarizationSession,
SpeakerTurn,
)
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
@@ -53,6 +58,7 @@ class ServicerHost(Protocol):
_segment_counters: dict[str, int]
_stream_formats: dict[str, tuple[int, int]]
_active_streams: set[str]
_stop_requested: set[str] # Meeting IDs with pending stop requests
# Partial transcription state per meeting
_partial_buffers: dict[str, list[NDArray[np.float32]]]
@@ -63,6 +69,7 @@ class ServicerHost(Protocol):
_diarization_turns: dict[str, list[SpeakerTurn]]
_diarization_stream_time: dict[str, float]
_diarization_streaming_failed: set[str]
_diarization_sessions: dict[str, DiarizationSession]
# Background diarization task references (for cancellation)
_diarization_jobs: dict[str, DiarizationJob]
@@ -84,7 +91,19 @@ class ServicerHost(Protocol):
...
def _create_uow(self) -> SqlAlchemyUnitOfWork:
"""Create a new Unit of Work."""
"""Create a new Unit of Work (database-backed)."""
...
def _create_repository_provider(self) -> UnitOfWork:
"""Create a repository provider (database or memory backed).
Returns a UnitOfWork implementation appropriate for the current
configuration. Use this for operations that can work with either
backend, eliminating the need for if/else branching.
Returns:
SqlAlchemyUnitOfWork if database configured, MemoryUnitOfWork otherwise.
"""
...
def _next_segment_id(self, meeting_id: str, fallback: int = 0) -> int:


@@ -8,19 +8,25 @@ import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID
import grpc.aio
import numpy as np
from numpy.typing import NDArray
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.diarization import SpeakerTurn
from ..proto import noteflow_pb2
from .converters import create_segment_from_asr, create_vad_update, segment_to_proto_update
from .converters import (
create_segment_from_asr,
create_vad_update,
parse_meeting_id,
segment_to_proto_update,
)
from .errors import abort_failed_precondition, abort_invalid_argument
if TYPE_CHECKING:
from noteflow.domain.entities import Segment
from .protocols import ServicerHost
logger = logging.getLogger(__name__)
@@ -55,12 +61,10 @@ class StreamingMixin:
Receive audio chunks from client, process through ASR,
persist segments, and yield transcript updates.
Works with both database and memory backends via RepositoryProvider.
"""
if self._asr_engine is None or not self._asr_engine.is_loaded:
await context.abort(
grpc.StatusCode.FAILED_PRECONDITION,
"ASR engine not loaded",
)
await abort_failed_precondition(context, "ASR engine not loaded")
current_meeting_id: str | None = None
@@ -68,10 +72,7 @@ class StreamingMixin:
async for chunk in request_iterator:
meeting_id = chunk.meeting_id
if not meeting_id:
await context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"meeting_id required",
)
await abort_invalid_argument(context, "meeting_id required")
# Initialize stream on first chunk
if current_meeting_id is None:
@@ -80,11 +81,18 @@ class StreamingMixin:
return # Error already sent via context.abort
current_meeting_id = meeting_id
elif meeting_id != current_meeting_id:
await context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"Stream may only contain a single meeting_id",
await abort_invalid_argument(
context, "Stream may only contain a single meeting_id"
)
# Check for stop request (graceful shutdown from StopMeeting)
if current_meeting_id in self._stop_requested:
logger.info(
"Stop requested for meeting %s, exiting stream gracefully",
current_meeting_id,
)
break
# Process audio chunk
async for update in self._process_stream_chunk(current_meeting_id, chunk, context):
yield update
@@ -112,6 +120,8 @@ class StreamingMixin:
) -> _StreamSessionInit | None:
"""Initialize streaming for a meeting.
Uses unified repository provider for both DB and memory backends.
Args:
meeting_id: Meeting ID string.
context: gRPC context for error handling.
@@ -120,17 +130,11 @@ class StreamingMixin:
Initialization result, or None if error was sent.
"""
if meeting_id in self._active_streams:
await context.abort(
grpc.StatusCode.FAILED_PRECONDITION,
f"Meeting {meeting_id} already streaming",
)
await abort_failed_precondition(context, f"Meeting {meeting_id} already streaming")
self._active_streams.add(meeting_id)
if self._use_database():
init_result = await self._init_stream_session_db(meeting_id)
else:
init_result = self._init_stream_session_memory(meeting_id)
init_result = await self._init_stream_session(meeting_id)
if not init_result.success:
self._active_streams.discard(meeting_id)
@@ -138,11 +142,11 @@ class StreamingMixin:
return init_result
async def _init_stream_session_db(
async def _init_stream_session(
self: ServicerHost,
meeting_id: str,
) -> _StreamSessionInit:
"""Initialize stream session using database persistence.
"""Initialize stream session using unified repository provider.
Args:
meeting_id: Meeting ID string.
@@ -150,8 +154,17 @@ class StreamingMixin:
Returns:
Stream session initialization result.
"""
async with self._create_uow() as uow:
meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
async with self._create_repository_provider() as repo:
try:
parsed_meeting_id = parse_meeting_id(meeting_id)
except ValueError:
return _StreamSessionInit(
next_segment_id=0,
error_code=grpc.StatusCode.INVALID_ARGUMENT,
error_message="Invalid meeting_id",
)
meeting = await repo.meetings.get(parsed_meeting_id)
if meeting is None:
return _StreamSessionInit(
next_segment_id=0,
@@ -170,82 +183,43 @@ class StreamingMixin:
)
if dek_updated or recording_updated:
await uow.meetings.update(meeting)
await uow.commit()
await repo.meetings.update(meeting)
await repo.commit()
next_segment_id = await uow.segments.get_next_segment_id(meeting.id)
next_segment_id = await repo.segments.get_next_segment_id(meeting.id)
self._open_meeting_audio_writer(
meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path
)
self._init_streaming_state(meeting_id, next_segment_id)
# Load any persisted streaming turns (crash recovery)
persisted_turns = await uow.diarization_jobs.get_streaming_turns(meeting_id)
if persisted_turns:
domain_turns = [
SpeakerTurn(
speaker=t.speaker,
start=t.start_time,
end=t.end_time,
confidence=t.confidence,
# Load any persisted streaming turns (crash recovery) - DB only
if repo.supports_diarization_jobs:
persisted_turns = await repo.diarization_jobs.get_streaming_turns(meeting_id)
if persisted_turns:
domain_turns = [
SpeakerTurn(
speaker=t.speaker,
start=t.start_time,
end=t.end_time,
confidence=t.confidence,
)
for t in persisted_turns
]
self._diarization_turns[meeting_id] = domain_turns
# Advance stream time to avoid overlapping recovered turns
last_end = max(t.end_time for t in persisted_turns)
self._diarization_stream_time[meeting_id] = max(
self._diarization_stream_time.get(meeting_id, 0.0),
last_end,
)
logger.info(
"Loaded %d streaming diarization turns for meeting %s",
len(domain_turns),
meeting_id,
)
for t in persisted_turns
]
self._diarization_turns[meeting_id] = domain_turns
# Advance stream time to avoid overlapping recovered turns
last_end = max(t.end_time for t in persisted_turns)
self._diarization_stream_time[meeting_id] = max(
self._diarization_stream_time.get(meeting_id, 0.0),
last_end,
)
logger.info(
"Loaded %d streaming diarization turns for meeting %s",
len(domain_turns),
meeting_id,
)
return _StreamSessionInit(next_segment_id=next_segment_id)
def _init_stream_session_memory(
self: ServicerHost,
meeting_id: str,
) -> _StreamSessionInit:
"""Initialize stream session using in-memory store.
Args:
meeting_id: Meeting ID string.
Returns:
Stream session initialization result.
"""
store = self._get_memory_store()
meeting = store.get(meeting_id)
if meeting is None:
return _StreamSessionInit(
next_segment_id=0,
error_code=grpc.StatusCode.NOT_FOUND,
error_message=f"Meeting {meeting_id} not found",
)
dek, wrapped_dek, dek_updated = self._ensure_meeting_dek(meeting)
recording_updated, error_msg = self._start_meeting_if_needed(meeting)
if error_msg:
return _StreamSessionInit(
next_segment_id=0,
error_code=grpc.StatusCode.INVALID_ARGUMENT,
error_message=error_msg,
)
if dek_updated or recording_updated:
store.update(meeting)
next_segment_id = meeting.next_segment_id
self._open_meeting_audio_writer(meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path)
self._init_streaming_state(meeting_id, next_segment_id)
return _StreamSessionInit(next_segment_id=next_segment_id)
async def _process_stream_chunk(
self: ServicerHost,
meeting_id: str,
@@ -269,7 +243,7 @@ class StreamingMixin:
chunk.channels,
)
except ValueError as e:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
await abort_invalid_argument(context, str(e))
audio = self._decode_audio_chunk(chunk)
if audio is None:
@@ -278,7 +252,7 @@ class StreamingMixin:
try:
audio = self._convert_audio_format(audio, sample_rate, channels)
except ValueError as e:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
await abort_invalid_argument(context, str(e))
# Write to encrypted audio file
self._write_audio_chunk_safe(meeting_id, audio)
@@ -550,6 +524,9 @@ class StreamingMixin:
) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
"""Process a complete audio segment through ASR.
Uses unified repository provider for both DB and memory backends.
Batches all segments from ASR results and commits once per audio chunk.
Args:
meeting_id: Meeting identifier.
audio: Complete audio segment.
@@ -561,37 +538,21 @@ class StreamingMixin:
if len(audio) == 0 or self._asr_engine is None:
return
if self._use_database():
async with self._create_uow() as uow:
meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
if meeting is None:
return
async with self._create_repository_provider() as repo:
try:
parsed_meeting_id = parse_meeting_id(meeting_id)
except ValueError:
logger.warning("Invalid meeting_id %s in streaming segment", meeting_id)
return
results = await self._asr_engine.transcribe_async(audio)
for result in results:
segment_id = self._next_segment_id(
meeting_id,
fallback=meeting.next_segment_id,
)
segment = create_segment_from_asr(
meeting.id,
segment_id,
result,
segment_start_time,
)
# Call diarization mixin method if available
if hasattr(self, "_maybe_assign_speaker"):
self._maybe_assign_speaker(meeting_id, segment)
meeting.add_segment(segment)
await uow.segments.add(meeting.id, segment)
await uow.commit()
yield segment_to_proto_update(meeting_id, segment)
else:
store = self._get_memory_store()
meeting = store.get(meeting_id)
meeting = await repo.meetings.get(parsed_meeting_id)
if meeting is None:
return
results = await self._asr_engine.transcribe_async(audio)
# Build all segments first
segments_to_add: list[tuple[Segment, noteflow_pb2.TranscriptUpdate]] = []
for result in results:
segment_id = self._next_segment_id(
meeting_id,
@@ -606,5 +567,13 @@ class StreamingMixin:
# Call diarization mixin method if available
if hasattr(self, "_maybe_assign_speaker"):
self._maybe_assign_speaker(meeting_id, segment)
store.add_segment(meeting_id, segment)
yield segment_to_proto_update(meeting_id, segment)
await repo.segments.add(meeting.id, segment)
segments_to_add.append((segment, segment_to_proto_update(meeting_id, segment)))
# Single commit for all segments in this audio chunk
if segments_to_add:
await repo.commit()
# Yield updates after commit
for _, update in segments_to_add:
yield update
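
The segment handling above stages every ASR result, commits once per audio chunk, and only then yields updates to the stream. A generic, runnable sketch of that batch-then-yield shape; the `Repo` stub and `repository_provider` here are stand-ins, not the project's classes.

```python
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class Repo:
    """Stand-in repository that records staged rows and a single commit."""

    def __init__(self) -> None:
        self.rows: list[str] = []
        self.committed = False

    async def add(self, row: str) -> None:
        self.rows.append(row)

    async def commit(self) -> None:
        self.committed = True


@asynccontextmanager
async def repository_provider():
    yield Repo()


async def process(results: list[str]) -> AsyncIterator[str]:
    async with repository_provider() as repo:
        staged: list[str] = []
        for text in results:
            await repo.add(text)
            staged.append(f"update:{text}")
        if staged:
            await repo.commit()  # one commit per audio chunk, not per segment
        for update in staged:
            yield update  # emit only after the batch is committed


async def main() -> None:
    async for update in process(["hello", "world"]):
        print(update)


asyncio.run(main())
```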


@@ -4,7 +4,6 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from uuid import UUID
import grpc.aio
@@ -13,7 +12,8 @@ from noteflow.domain.summarization import ProviderUnavailableError
from noteflow.domain.value_objects import MeetingId
from ..proto import noteflow_pb2
from .converters import summary_to_proto
from .converters import parse_meeting_id, summary_to_proto
from .errors import abort_invalid_argument, abort_not_found
if TYPE_CHECKING:
from noteflow.application.services.summarization_service import SummarizationService
@@ -27,6 +27,7 @@ class SummarizationMixin:
"""Mixin providing summarization functionality.
Requires host to implement ServicerHost protocol.
Works with both database and memory backends via RepositoryProvider.
"""
_summarization_service: SummarizationService | None
@@ -36,70 +37,38 @@ class SummarizationMixin:
request: noteflow_pb2.GenerateSummaryRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Summary:
"""Generate meeting summary using SummarizationService with fallback."""
if self._use_database():
return await self._generate_summary_db(request, context)
"""Generate meeting summary using SummarizationService with fallback.
return await self._generate_summary_memory(request, context)
async def _generate_summary_db(
self: ServicerHost,
request: noteflow_pb2.GenerateSummaryRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Summary:
"""Generate summary for a meeting stored in the database.
The potentially slow summarization step is executed outside the UoW to
avoid holding database connections while waiting on LLMs.
The potentially slow summarization step is executed outside the repository
context to avoid holding connections while waiting on LLMs.
"""
meeting_id = MeetingId(UUID(request.meeting_id))
try:
meeting_id = parse_meeting_id(request.meeting_id)
except ValueError:
await abort_invalid_argument(context, "Invalid meeting_id")
# 1) Load meeting, existing summary, and segments inside a short UoW
async with self._create_uow() as uow:
meeting = await uow.meetings.get(meeting_id)
# 1) Load meeting, existing summary, and segments in a short transaction
async with self._create_repository_provider() as repo:
meeting = await repo.meetings.get(meeting_id)
if meeting is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {request.meeting_id} not found",
)
await abort_not_found(context, "Meeting", request.meeting_id)
existing = await uow.summaries.get_by_meeting(meeting.id)
existing = await repo.summaries.get_by_meeting(meeting.id)
if existing and not request.force_regenerate:
return summary_to_proto(existing)
segments = list(await uow.segments.get_by_meeting(meeting.id))
segments = list(await repo.segments.get_by_meeting(meeting.id))
# 2) Run summarization outside DB transaction
# 2) Run summarization outside repository context (slow LLM call)
summary = await self._summarize_or_placeholder(meeting_id, segments)
# 3) Persist in a fresh UoW
async with self._create_uow() as uow:
saved = await uow.summaries.save(summary)
await uow.commit()
# 3) Persist in a fresh transaction
async with self._create_repository_provider() as repo:
saved = await repo.summaries.save(summary)
await repo.commit()
return summary_to_proto(saved)
async def _generate_summary_memory(
self: ServicerHost,
request: noteflow_pb2.GenerateSummaryRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.Summary:
"""Generate summary for meetings held in the in-memory store."""
store = self._get_memory_store()
meeting = store.get(request.meeting_id)
if meeting is None:
await context.abort(
grpc.StatusCode.NOT_FOUND,
f"Meeting {request.meeting_id} not found",
)
if meeting.summary and not request.force_regenerate:
return summary_to_proto(meeting.summary)
summary = await self._summarize_or_placeholder(meeting.id, meeting.segments)
store.set_summary(request.meeting_id, summary)
return summary_to_proto(summary)
async def _summarize_or_placeholder(
self: ServicerHost,
meeting_id: MeetingId,
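
GenerateSummary above deliberately splits the work into a short load transaction, a slow summarization call with no repository open, and a fresh persist transaction. A minimal runnable sketch of that sequencing, with a dict standing in for the repository and a sleep standing in for the LLM call:

```python
import asyncio
from contextlib import asynccontextmanager

DB = {"segments": ["intro", "decisions", "action items"], "summary": None}


@asynccontextmanager
async def repository_provider():
    # Stand-in for SqlAlchemyUnitOfWork / MemoryUnitOfWork; yields the "database".
    yield DB


async def summarize(segments: list[str]) -> str:
    await asyncio.sleep(0.1)  # stand-in for a slow LLM call
    return " / ".join(segments)


async def generate_summary() -> str:
    # 1) Load inside a short-lived repository context.
    async with repository_provider() as repo:
        segments = list(repo["segments"])
    # 2) Run the slow call with no repository context open, so no connection
    #    is held while waiting on the model.
    summary = await summarize(segments)
    # 3) Persist in a fresh context.
    async with repository_provider() as repo:
        repo["summary"] = summary
    return summary


print(asyncio.run(generate_summary()))
```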


@@ -420,7 +420,11 @@ class NoteFlowClient:
if self._stream_thread:
self._stream_thread.join(timeout=2.0)
self._stream_thread = None
# Only clear reference if thread exited cleanly
if self._stream_thread.is_alive():
logger.warning("Stream thread did not exit within timeout, keeping reference")
else:
self._stream_thread = None
self._current_meeting_id = None
logger.info("Stopped streaming")


@@ -46,6 +46,19 @@ class MeetingStore:
return meeting
def insert(self, meeting: Meeting) -> Meeting:
"""Insert an existing meeting into the store.
Args:
meeting: Meeting entity to store.
Returns:
Stored meeting.
"""
with self._lock:
self._meetings[str(meeting.id)] = meeting
return meeting
def get(self, meeting_id: str) -> Meeting | None:
"""Get a meeting by ID.
@@ -58,6 +71,18 @@ class MeetingStore:
with self._lock:
return self._meetings.get(meeting_id)
def count_by_state(self, state: MeetingState) -> int:
"""Count meetings in a specific state.
Args:
state: Meeting state to count.
Returns:
Number of meetings in the specified state.
"""
with self._lock:
return sum(m.state == state for m in self._meetings.values())
def list_all(
self,
states: Sequence[MeetingState] | None = None,
@@ -94,6 +119,22 @@ class MeetingStore:
return meetings, total
def find_older_than(self, cutoff: datetime) -> list[Meeting]:
"""Find completed meetings older than cutoff date.
Args:
cutoff: Cutoff datetime; meetings ended before this are returned.
Returns:
List of meetings with ended_at before cutoff.
"""
with self._lock:
return [
m
for m in self._meetings.values()
if m.ended_at is not None and m.ended_at < cutoff
]
def update(self, meeting: Meeting) -> Meeting:
"""Update a meeting in the store.
@@ -125,6 +166,19 @@ class MeetingStore:
meeting.add_segment(segment)
return meeting
def get_segments(self, meeting_id: str) -> list[Segment]:
"""Get a copy of segments for a meeting.
Args:
meeting_id: Meeting ID.
Returns:
List of segments (copy), or empty list if meeting not found.
"""
with self._lock:
meeting = self._meetings.get(meeting_id)
return [] if meeting is None else list(meeting.segments)
def set_summary(self, meeting_id: str, summary: Summary) -> Meeting | None:
"""Set meeting summary.
@@ -143,6 +197,35 @@ class MeetingStore:
meeting.summary = summary
return meeting
def get_summary(self, meeting_id: str) -> Summary | None:
"""Get meeting summary.
Args:
meeting_id: Meeting ID.
Returns:
Summary or None if missing.
"""
with self._lock:
meeting = self._meetings.get(meeting_id)
return meeting.summary if meeting else None
def clear_summary(self, meeting_id: str) -> bool:
"""Clear meeting summary.
Args:
meeting_id: Meeting ID.
Returns:
True if cleared, False if meeting not found or no summary set.
"""
with self._lock:
meeting = self._meetings.get(meeting_id)
if meeting is None or meeting.summary is None:
return False
meeting.summary = None
return True
def update_state(self, meeting_id: str, state: MeetingState) -> bool:
"""Atomically update meeting state.
@@ -194,6 +277,21 @@ class MeetingStore:
meeting.end_time = end_time
return True
def get_next_segment_id(self, meeting_id: str) -> int:
"""Get next segment ID for a meeting.
Args:
meeting_id: Meeting ID.
Returns:
Next segment ID (0 if meeting or segments missing).
"""
with self._lock:
meeting = self._meetings.get(meeting_id)
if meeting is None or not meeting.segments:
return 0
return max(s.segment_id for s in meeting.segments) + 1
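
A hypothetical usage of the new MeetingStore helpers, assuming a no-argument constructor and the signatures shown in this diff (`insert`, `get_segments`, `get_next_segment_id`, `count_by_state`):

```python
from noteflow.domain.entities import Meeting
from noteflow.domain.value_objects import MeetingState
from noteflow.grpc.meeting_store import MeetingStore

store = MeetingStore()  # assumed no-argument constructor
meeting = store.insert(Meeting.create(title="Weekly sync", metadata={}))

# A freshly created meeting has no segments, so the next segment id starts at 0.
assert store.get_next_segment_id(str(meeting.id)) == 0
assert store.get_segments(str(meeting.id)) == []

# count_by_state tallies meetings currently in the given state.
print(store.count_by_state(MeetingState.CREATED), "meeting(s) in CREATED state")
```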
def delete(self, meeting_id: str) -> bool:
"""Delete a meeting.


@@ -14,9 +14,12 @@ import numpy as np
from noteflow.config.constants import DEFAULT_SAMPLE_RATE as _DEFAULT_SAMPLE_RATE
from noteflow.domain.entities import Meeting
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingState
from noteflow.infrastructure.asr import Segmenter, SegmenterConfig, StreamingVad
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
from noteflow.infrastructure.diarization import DiarizationSession
from noteflow.infrastructure.persistence.memory import MemoryUnitOfWork
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
@@ -107,6 +110,7 @@ class NoteFlowServicer(
self._segment_counters: dict[str, int] = {}
self._stream_formats: dict[str, tuple[int, int]] = {}
self._active_streams: set[str] = set()
self._stop_requested: set[str] = set()
# Partial transcription state per meeting
self._partial_buffers: dict[str, list[NDArray[np.float32]]] = {}
@@ -117,6 +121,7 @@ class NoteFlowServicer(
self._diarization_turns: dict[str, list[SpeakerTurn]] = {}
self._diarization_stream_time: dict[str, float] = {}
self._diarization_streaming_failed: set[str] = set()
self._diarization_sessions: dict[str, DiarizationSession] = {}
# Track audio write failures to avoid log spam
self._audio_write_failed: set[str] = set()
@@ -155,11 +160,25 @@ class NoteFlowServicer(
return self._memory_store
def _create_uow(self) -> SqlAlchemyUnitOfWork:
"""Create a new Unit of Work."""
"""Create a new Unit of Work (database-backed)."""
if self._session_factory is None:
raise RuntimeError("Database not configured")
return SqlAlchemyUnitOfWork(self._session_factory)
def _create_repository_provider(self) -> UnitOfWork:
"""Create a repository provider (database or memory backed).
Returns a UnitOfWork implementation appropriate for the current
configuration. Use this for operations that can work with either
backend, eliminating the need for if/else branching.
Returns:
SqlAlchemyUnitOfWork if database configured, MemoryUnitOfWork otherwise.
"""
if self._session_factory is not None:
return SqlAlchemyUnitOfWork(self._session_factory)
return MemoryUnitOfWork(self._get_memory_store())
def _init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None:
"""Initialize VAD, Segmenter, speaking state, and partial buffers for a meeting."""
self._vad_instances[meeting_id] = StreamingVad()
@@ -174,8 +193,8 @@ class NoteFlowServicer(
self._diarization_turns[meeting_id] = []
self._diarization_stream_time[meeting_id] = 0.0
self._diarization_streaming_failed.discard(meeting_id)
if self._diarization_engine is not None:
self._diarization_engine.reset_streaming()
# NOTE: Per-meeting diarization sessions are created lazily in
# _process_streaming_diarization() to avoid blocking on model load
def _cleanup_streaming_state(self, meeting_id: str) -> None:
"""Clean up VAD, Segmenter, speaking state, and partial buffers for a meeting."""
@@ -191,6 +210,10 @@ class NoteFlowServicer(
self._diarization_stream_time.pop(meeting_id, None)
self._diarization_streaming_failed.discard(meeting_id)
# Clean up per-meeting diarization session
if session := self._diarization_sessions.pop(meeting_id, None):
session.close()
def _ensure_meeting_dek(self, meeting: Meeting) -> tuple[bytes, bytes, bool]:
"""Ensure meeting has a DEK, generating one if needed.
@@ -344,6 +367,12 @@ class NoteFlowServicer(
self._diarization_tasks.clear()
# Close all diarization sessions
for meeting_id, session in list(self._diarization_sessions.items()):
logger.debug("Closing diarization session for meeting %s", meeting_id)
session.close()
self._diarization_sessions.clear()
# Close all audio writers
for meeting_id in list(self._audio_writers.keys()):
logger.debug("Closing audio writer for meeting %s", meeting_id)


@@ -9,9 +9,11 @@ from noteflow.infrastructure.diarization.assigner import (
)
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from noteflow.infrastructure.diarization.engine import DiarizationEngine
from noteflow.infrastructure.diarization.session import DiarizationSession
__all__ = [
"DiarizationEngine",
"DiarizationSession",
"SpeakerTurn",
"assign_speaker",
"assign_speakers_batch",


@@ -9,10 +9,12 @@ Requires optional dependencies: pip install noteflow[diarization]
from __future__ import annotations
import logging
import warnings
from typing import TYPE_CHECKING
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from noteflow.infrastructure.diarization.session import DiarizationSession
if TYPE_CHECKING:
from collections.abc import Sequence
@@ -60,6 +62,10 @@ class DiarizationEngine:
self._streaming_pipeline = None
self._offline_pipeline = None
# Shared models for per-session pipelines (loaded once, reused)
self._segmentation_model = None
self._embedding_model = None
def _resolve_device(self) -> str:
"""Resolve the actual device to use based on availability.
@@ -133,6 +139,80 @@ class DiarizationEngine:
except Exception as e:
raise RuntimeError(f"Failed to load streaming diarization model: {e}") from e
def _ensure_streaming_models_loaded(self) -> None:
"""Ensure shared streaming models are loaded.
Loads the segmentation and embedding models that are shared across
all streaming sessions. These models are stateless and can be safely
reused by multiple concurrent sessions.
Raises:
RuntimeError: If model loading fails.
ValueError: If HuggingFace token is not provided.
"""
if self._segmentation_model is not None and self._embedding_model is not None:
return
if not self._hf_token:
raise ValueError("HuggingFace token required for pyannote models")
device = self._resolve_device()
logger.info("Loading shared streaming diarization models on %s...", device)
try:
from diart.models import EmbeddingModel, SegmentationModel
self._segmentation_model = SegmentationModel.from_pretrained(
"pyannote/segmentation-3.0",
use_hf_token=self._hf_token,
)
self._embedding_model = EmbeddingModel.from_pretrained(
"pyannote/wespeaker-voxceleb-resnet34-LM",
use_hf_token=self._hf_token,
)
logger.info("Shared streaming models loaded successfully")
except Exception as e:
raise RuntimeError(f"Failed to load streaming models: {e}") from e
def create_streaming_session(self, meeting_id: str) -> DiarizationSession:
"""Create a new per-meeting streaming diarization session.
Each session maintains its own pipeline state, enabling concurrent
diarization of multiple meetings without interference. The underlying
models are shared across sessions for memory efficiency.
Args:
meeting_id: Unique identifier for the meeting.
Returns:
New DiarizationSession for the meeting.
Raises:
RuntimeError: If model loading fails.
ValueError: If HuggingFace token is not provided.
"""
self._ensure_streaming_models_loaded()
from diart import SpeakerDiarization, SpeakerDiarizationConfig
config = SpeakerDiarizationConfig(
segmentation=self._segmentation_model,
embedding=self._embedding_model,
step=self._streaming_latency,
latency=self._streaming_latency,
device=self._resolve_device(),
)
pipeline = SpeakerDiarization(config)
logger.info("Created streaming session for meeting %s", meeting_id)
return DiarizationSession(
meeting_id=meeting_id,
_pipeline=pipeline,
_sample_rate=DEFAULT_SAMPLE_RATE,
)
def load_offline_model(self) -> None:
"""Load the offline diarization model (pyannote.audio).
@@ -287,7 +367,19 @@ class DiarizationEngine:
return turns
def reset_streaming(self) -> None:
"""Reset streaming pipeline state for a new recording."""
"""Reset streaming pipeline state for a new recording.
.. deprecated::
Use create_streaming_session() for per-meeting sessions instead.
This method resets global state and will cause race conditions
with concurrent meetings.
"""
warnings.warn(
"reset_streaming() is deprecated. Use create_streaming_session() for "
"per-meeting sessions to avoid race conditions with concurrent meetings.",
DeprecationWarning,
stacklevel=2,
)
if self._streaming_pipeline is not None:
self._streaming_pipeline.reset()
logger.debug("Streaming pipeline state reset")
@@ -296,6 +388,8 @@ class DiarizationEngine:
"""Unload all models to free memory."""
self._streaming_pipeline = None
self._offline_pipeline = None
self._segmentation_model = None
self._embedding_model = None
self._device = None
logger.info("Diarization models unloaded")


@@ -0,0 +1,170 @@
"""Per-meeting diarization session for streaming speaker identification.
Each session maintains its own pipeline state, enabling concurrent meetings
without cross-talk. Shared models are loaded once and reused across sessions.
"""
from __future__ import annotations
import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.dto import SpeakerTurn
if TYPE_CHECKING:
import numpy as np
from diart import SpeakerDiarization
from numpy.typing import NDArray
logger = logging.getLogger(__name__)
@dataclass
class DiarizationSession:
"""Per-meeting streaming diarization session.
Maintains independent pipeline state for a single meeting, enabling
concurrent diarization of multiple meetings without interference.
The session owns its own SpeakerDiarization pipeline instance but
shares the underlying segmentation and embedding models with other
sessions for memory efficiency.
"""
meeting_id: str
_pipeline: SpeakerDiarization
_sample_rate: int = DEFAULT_SAMPLE_RATE
_stream_time: float = field(default=0.0, init=False)
_turns: list[SpeakerTurn] = field(default_factory=list, init=False)
_closed: bool = field(default=False, init=False)
def process_chunk(
self,
audio: NDArray[np.float32],
sample_rate: int | None = None,
) -> Sequence[SpeakerTurn]:
"""Process an audio chunk and return new speaker turns.
Args:
audio: Audio samples as float32 array (mono).
sample_rate: Audio sample rate (defaults to session's configured rate).
Returns:
Sequence of speaker turns detected in this chunk,
with times adjusted to absolute stream position.
Raises:
RuntimeError: If session is closed.
"""
if self._closed:
raise RuntimeError(f"Session {self.meeting_id} is closed")
if audio.size == 0:
return []
rate = sample_rate or self._sample_rate
duration = len(audio) / rate
# Import here to avoid import errors when diart not installed
from pyannote.core import SlidingWindow, SlidingWindowFeature
# Reshape audio for diart: (samples,) -> (1, samples)
if audio.ndim == 1:
audio = audio.reshape(1, -1)
# Create SlidingWindowFeature for diart
window = SlidingWindow(start=0.0, duration=duration, step=duration)
waveform = SlidingWindowFeature(audio, window)
# Process through pipeline
results = self._pipeline([waveform])
# Convert results to turns with absolute time offsets
new_turns: list[SpeakerTurn] = []
for annotation, _ in results:
for track in annotation.itertracks(yield_label=True):
if len(track) == 3:
segment, _, speaker = track
turn = SpeakerTurn(
speaker=str(speaker),
start=segment.start + self._stream_time,
end=segment.end + self._stream_time,
)
new_turns.append(turn)
self._turns.append(turn)
self._stream_time += duration
return new_turns
def reset(self) -> None:
"""Reset session state for restarting diarization.
Clears accumulated turns and resets stream time to zero.
The underlying pipeline is also reset.
"""
if self._closed:
return
self._pipeline.reset()
self._stream_time = 0.0
self._turns.clear()
logger.debug("Session %s reset", self.meeting_id)
def restore(
self,
turns: Sequence[SpeakerTurn],
*,
stream_time: float | None = None,
) -> None:
"""Restore session state from prior streaming turns.
Use during crash recovery to continue diarization timelines.
Args:
turns: Previously collected speaker turns.
stream_time: Optional stream time to restore. If not provided,
uses the max end time from turns.
"""
if self._closed:
return
self._turns = list(turns)
if stream_time is None:
stream_time = max((t.end for t in turns), default=0.0)
self._stream_time = max(self._stream_time, stream_time)
logger.debug(
"Session %s restored stream_time=%.3f with %d turns",
self.meeting_id,
self._stream_time,
len(self._turns),
)
def close(self) -> None:
"""Close the session and release resources.
After closing, the session cannot be used for further processing.
"""
if self._closed:
return
self._closed = True
self._turns.clear()
logger.debug("Session %s closed", self.meeting_id)
@property
def stream_time(self) -> float:
"""Current stream time position in seconds."""
return self._stream_time
@property
def turns(self) -> list[SpeakerTurn]:
"""All accumulated speaker turns for this session."""
return list(self._turns)
@property
def is_closed(self) -> bool:
"""Check if session is closed."""
return self._closed
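A minimal usage sketch for the session API above, assuming `create_streaming_session()` accepts the meeting id and that the engine's streaming models are already loaded; the helper name and chunk source are illustrative only.

```python
import numpy as np


def stream_meeting_audio(engine, meeting_id: str, chunks) -> list:
    """Sketch only: `engine` is a DiarizationEngine with streaming models loaded."""
    session = engine.create_streaming_session(meeting_id)  # one independent session per meeting
    turns = []
    for chunk in chunks:
        # Chunks are mono float32 arrays at the session's default sample rate.
        turns.extend(session.process_chunk(np.asarray(chunk, dtype=np.float32)))
    session.close()  # releases per-meeting state; concurrent sessions are unaffected
    return turns
```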

View File

@@ -0,0 +1,10 @@
"""In-memory persistence implementations.
Provides repository-pattern wrappers around the MeetingStore for use
when no database is configured. Implements the UnitOfWork protocol
for uniform access across database and memory backends.
"""
from .unit_of_work import MemoryUnitOfWork
__all__ = ["MemoryUnitOfWork"]

View File

@@ -0,0 +1,259 @@
"""Memory repository implementations wrapping MeetingStore.
Provides repository interfaces over the in-memory MeetingStore for
uniform access pattern with database implementations.
"""
from __future__ import annotations
from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING
from noteflow.domain.entities import Meeting, Segment, Summary
from noteflow.domain.value_objects import MeetingState
if TYPE_CHECKING:
from noteflow.domain.entities import Annotation
from noteflow.domain.value_objects import AnnotationId, MeetingId
from noteflow.grpc.meeting_store import MeetingStore
from noteflow.infrastructure.persistence.repositories import (
DiarizationJob,
StreamingTurn,
)
class MemoryMeetingRepository:
"""Meeting repository backed by MeetingStore."""
def __init__(self, store: MeetingStore) -> None:
"""Initialize with meeting store.
Args:
store: In-memory meeting store.
"""
self._store = store
async def create(self, meeting: Meeting) -> Meeting:
"""Persist a new meeting."""
return self._store.insert(meeting)
async def get(self, meeting_id: MeetingId) -> Meeting | None:
"""Retrieve a meeting by ID."""
return self._store.get(str(meeting_id))
async def update(self, meeting: Meeting) -> Meeting:
"""Update an existing meeting."""
return self._store.update(meeting)
async def delete(self, meeting_id: MeetingId) -> bool:
"""Delete a meeting and all associated data."""
return self._store.delete(str(meeting_id))
async def list_all(
self,
states: list[MeetingState] | None = None,
limit: int = 100,
offset: int = 0,
sort_desc: bool = True,
) -> tuple[Sequence[Meeting], int]:
"""List meetings with optional filtering."""
return self._store.list_all(states, limit, offset, sort_desc)
async def count_by_state(self, state: MeetingState) -> int:
"""Count meetings in a specific state."""
return self._store.count_by_state(state)
async def find_older_than(self, cutoff: datetime) -> Sequence[Meeting]:
"""Find completed meetings older than cutoff date."""
return self._store.find_older_than(cutoff)
class MemorySegmentRepository:
"""Segment repository backed by MeetingStore."""
def __init__(self, store: MeetingStore) -> None:
"""Initialize with meeting store."""
self._store = store
async def add(self, meeting_id: MeetingId, segment: Segment) -> Segment:
"""Add a segment to a meeting."""
self._store.add_segment(str(meeting_id), segment)
return segment
async def add_batch(
self,
meeting_id: MeetingId,
segments: Sequence[Segment],
) -> Sequence[Segment]:
"""Add multiple segments to a meeting in batch."""
for segment in segments:
self._store.add_segment(str(meeting_id), segment)
return segments
async def get_by_meeting(
self,
meeting_id: MeetingId,
include_words: bool = True,
) -> Sequence[Segment]:
"""Get all segments for a meeting."""
return self._store.get_segments(str(meeting_id))
async def search_semantic(
self,
query_embedding: list[float],
limit: int = 10,
meeting_id: MeetingId | None = None,
) -> Sequence[tuple[Segment, float]]:
"""Semantic search not supported in memory mode."""
return []
async def update_embedding(
self,
segment_db_id: int,
embedding: list[float],
) -> None:
"""Embeddings not supported in memory mode."""
async def update_speaker(
self,
segment_db_id: int,
speaker_id: str,
speaker_confidence: float,
) -> None:
"""Update speaker for segment - not applicable in memory mode.
In memory mode, segments are updated directly on the entity.
This method exists for interface compatibility.
"""
async def get_next_segment_id(self, meeting_id: MeetingId) -> int:
"""Get next segment ID for a meeting."""
return self._store.get_next_segment_id(str(meeting_id))
class MemorySummaryRepository:
"""Summary repository backed by MeetingStore."""
def __init__(self, store: MeetingStore) -> None:
"""Initialize with meeting store."""
self._store = store
async def save(self, summary: Summary) -> Summary:
"""Save or update a meeting summary."""
self._store.set_summary(str(summary.meeting_id), summary)
return summary
async def get_by_meeting(self, meeting_id: MeetingId) -> Summary | None:
"""Get summary for a meeting."""
return self._store.get_summary(str(meeting_id))
async def delete_by_meeting(self, meeting_id: MeetingId) -> bool:
"""Delete summary for a meeting."""
return self._store.clear_summary(str(meeting_id))
class UnsupportedAnnotationRepository:
"""Annotation repository that raises for unsupported operations.
Used in memory mode where annotations require database persistence.
"""
async def add(self, annotation: Annotation) -> Annotation:
"""Not supported in memory mode."""
raise NotImplementedError("Annotations require database persistence")
async def get(self, annotation_id: AnnotationId) -> Annotation | None:
"""Not supported in memory mode."""
raise NotImplementedError("Annotations require database persistence")
async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[Annotation]:
"""Not supported in memory mode."""
raise NotImplementedError("Annotations require database persistence")
async def get_by_time_range(
self,
meeting_id: MeetingId,
start_time: float,
end_time: float,
) -> Sequence[Annotation]:
"""Not supported in memory mode."""
raise NotImplementedError("Annotations require database persistence")
async def update(self, annotation: Annotation) -> Annotation:
"""Not supported in memory mode."""
raise NotImplementedError("Annotations require database persistence")
async def delete(self, annotation_id: AnnotationId) -> bool:
"""Not supported in memory mode."""
raise NotImplementedError("Annotations require database persistence")
class UnsupportedDiarizationJobRepository:
"""Diarization job repository that raises for unsupported operations.
Used in memory mode where jobs require database persistence.
"""
async def create(self, job: DiarizationJob) -> DiarizationJob:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
async def get(self, job_id: str) -> DiarizationJob | None:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
async def update_status(
self,
job_id: str,
status: int,
*,
segments_updated: int | None = None,
speaker_ids: list[str] | None = None,
error_message: str | None = None,
) -> bool:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
async def prune_completed(self, ttl_seconds: float) -> int:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
async def add_streaming_turns(
self,
meeting_id: str,
turns: Sequence[StreamingTurn],
) -> int:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
async def get_streaming_turns(self, meeting_id: str) -> list[StreamingTurn]:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
async def clear_streaming_turns(self, meeting_id: str) -> int:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
async def mark_running_as_failed(self, error_message: str = "Server restarted") -> int:
"""Not supported in memory mode."""
raise NotImplementedError("Diarization jobs require database persistence")
class UnsupportedPreferencesRepository:
"""Preferences repository that raises for unsupported operations.
Used in memory mode where preferences require database persistence.
"""
async def get(self, key: str) -> object | None:
"""Not supported in memory mode."""
raise NotImplementedError("Preferences require database persistence")
async def set(self, key: str, value: object) -> None:
"""Not supported in memory mode."""
raise NotImplementedError("Preferences require database persistence")
async def delete(self, key: str) -> bool:
"""Not supported in memory mode."""
raise NotImplementedError("Preferences require database persistence")

View File

@@ -0,0 +1,133 @@
"""In-memory Unit of Work implementation.
Provides a UnitOfWork implementation backed by the MeetingStore for use
when no database is configured. Implements the same interface as
SqlAlchemyUnitOfWork for uniform access.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Self
from .repositories import (
MemoryMeetingRepository,
MemorySegmentRepository,
MemorySummaryRepository,
UnsupportedAnnotationRepository,
UnsupportedDiarizationJobRepository,
UnsupportedPreferencesRepository,
)
if TYPE_CHECKING:
from noteflow.grpc.meeting_store import MeetingStore
class MemoryUnitOfWork:
"""In-memory Unit of Work backed by MeetingStore.
Implements the same interface as SqlAlchemyUnitOfWork for uniform
access across database and memory backends.
Commit and rollback are no-ops since changes are applied directly
to the in-memory store.
Example:
async with MemoryUnitOfWork(store) as uow:
meeting = await uow.meetings.get(meeting_id)
await uow.segments.add(meeting_id, segment)
await uow.commit() # No-op, changes already applied
"""
def __init__(self, store: MeetingStore) -> None:
"""Initialize unit of work with meeting store.
Args:
store: In-memory meeting storage.
"""
self._store = store
self._meetings = MemoryMeetingRepository(store)
self._segments = MemorySegmentRepository(store)
self._summaries = MemorySummaryRepository(store)
self._annotations = UnsupportedAnnotationRepository()
self._diarization_jobs = UnsupportedDiarizationJobRepository()
self._preferences = UnsupportedPreferencesRepository()
# Core repositories
@property
def meetings(self) -> MemoryMeetingRepository:
"""Get meetings repository."""
return self._meetings
@property
def segments(self) -> MemorySegmentRepository:
"""Get segments repository."""
return self._segments
@property
def summaries(self) -> MemorySummaryRepository:
"""Get summaries repository."""
return self._summaries
# Optional repositories (unsupported in memory mode)
@property
def annotations(self) -> UnsupportedAnnotationRepository:
"""Get annotations repository (unsupported)."""
return self._annotations
@property
def diarization_jobs(self) -> UnsupportedDiarizationJobRepository:
"""Get diarization jobs repository (unsupported)."""
return self._diarization_jobs
@property
def preferences(self) -> UnsupportedPreferencesRepository:
"""Get preferences repository (unsupported)."""
return self._preferences
# Feature flags - limited in memory mode
@property
def supports_annotations(self) -> bool:
"""Annotations not supported in memory mode."""
return False
@property
def supports_diarization_jobs(self) -> bool:
"""Diarization job persistence not supported in memory mode."""
return False
@property
def supports_preferences(self) -> bool:
"""User preferences not supported in memory mode."""
return False
async def __aenter__(self) -> Self:
"""Enter the unit of work context.
Returns:
Self for use in async with statement.
"""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
"""Exit the unit of work context.
No-op for memory implementation since changes are applied directly.
"""
async def commit(self) -> None:
"""Commit the current transaction.
No-op for memory implementation - changes are applied directly.
"""
async def rollback(self) -> None:
"""Rollback the current transaction.
Note: Memory implementation does not support rollback.
Changes are applied directly and cannot be undone.
"""

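To avoid hitting the `NotImplementedError` paths of the unsupported repositories, callers can branch on the `supports_*` flags that both unit-of-work implementations expose; a hedged sketch, with the function name and arguments as placeholders:

```python
async def save_annotation_if_supported(uow, annotation):
    """Sketch only: works with either backend via the supports_* feature flags."""
    async with uow:
        if not uow.supports_annotations:
            # Memory mode: uow.annotations.add() would raise NotImplementedError.
            return None
        saved = await uow.annotations.add(annotation)
        await uow.commit()  # real commit on SQLAlchemy, documented no-op in memory mode
        return saved
```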
View File

@@ -19,6 +19,8 @@ from sqlalchemy import (
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from noteflow.domain.utils.time import utc_now
# Vector dimension for embeddings (OpenAI compatible)
EMBEDDING_DIM = 1536
@@ -45,7 +47,7 @@ class MeetingModel(Base):
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
default=utc_now,
)
started_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
@@ -121,7 +123,7 @@ class SegmentModel(Base):
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
default=utc_now,
)
# Relationships
@@ -178,7 +180,7 @@ class SummaryModel(Base):
generated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
default=utc_now,
)
model_version: Mapped[str | None] = mapped_column(String(50), nullable=True)
@@ -296,7 +298,7 @@ class AnnotationModel(Base):
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
default=utc_now,
)
# Relationships
@@ -322,8 +324,8 @@ class UserPreferencesModel(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
onupdate=datetime.now,
default=utc_now,
onupdate=utc_now,
)
@@ -355,13 +357,13 @@ class DiarizationJobModel(Base):
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
default=utc_now,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
onupdate=datetime.now,
default=utc_now,
onupdate=utc_now,
)
@@ -390,5 +392,5 @@ class StreamingDiarizationTurnModel(Base):
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now,
default=utc_now,
)

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
@@ -78,3 +78,63 @@ class BaseRepository:
"""
await self._session.delete(model)
await self._session.flush()
async def _add_all_and_flush(self, models: list[TModel]) -> list[TModel]:
"""Add multiple models to session and flush once.
Use this for batching inserts to reduce database round-trips.
Args:
models: List of ORM model instances to persist.
Returns:
The persisted models with generated fields populated.
"""
self._session.add_all(models)
await self._session.flush()
return models
async def _execute_count(self, stmt: Select[tuple[int]]) -> int:
"""Execute count query and return result.
Args:
stmt: SQLAlchemy select statement returning a count.
Returns:
Integer count value.
"""
result = await self._session.execute(stmt)
return result.scalar_one()
async def _execute_exists(self, stmt: Select[Any]) -> bool:
"""Check if any rows match the query.
More efficient than fetching all rows - stops at first match.
Args:
stmt: SQLAlchemy select statement.
Returns:
True if at least one row exists.
"""
result = await self._session.execute(stmt.limit(1))
return result.scalar() is not None
async def _update_fields(
self,
model: TModel,
**fields: Any,
) -> TModel:
"""Update model fields and flush.
Args:
model: ORM model instance to update.
**fields: Field name/value pairs to set.
Returns:
The updated model.
"""
for key, value in fields.items():
setattr(model, key, value)
await self._session.flush()
return model
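A short sketch of how a concrete repository might lean on these helpers; the models module path is assumed and the class exists only for illustration:

```python
from sqlalchemy import func, select

from noteflow.infrastructure.persistence.models import SegmentModel  # module path assumed
from noteflow.infrastructure.persistence.repositories._base import BaseRepository


class SegmentQueryExamples(BaseRepository):
    """Illustrative subclass using the new _execute_* helpers."""

    async def count_for_meeting(self, meeting_uuid) -> int:
        stmt = select(func.count()).select_from(SegmentModel).where(
            SegmentModel.meeting_id == meeting_uuid
        )
        return await self._execute_count(stmt)

    async def has_segments(self, meeting_uuid) -> bool:
        stmt = select(SegmentModel.id).where(SegmentModel.meeting_id == meeting_uuid)
        return await self._execute_exists(stmt)
```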

View File

@@ -15,15 +15,15 @@ from noteflow.infrastructure.persistence.repositories._base import BaseRepositor
class SqlAlchemySegmentRepository(BaseRepository):
"""SQLAlchemy implementation of SegmentRepository."""
async def add(self, meeting_id: MeetingId, segment: Segment) -> Segment:
"""Add a segment to a meeting.
def _create_segment_model(self, meeting_id: MeetingId, segment: Segment) -> SegmentModel:
"""Create ORM model for a segment without persisting.
Args:
meeting_id: Meeting identifier.
segment: Segment to add.
segment: Domain segment.
Returns:
Added segment with db_id populated.
SegmentModel ready for persistence.
"""
model = SegmentModel(
meeting_id=UUID(str(meeting_id)),
@@ -46,6 +46,20 @@ class SqlAlchemySegmentRepository(BaseRepository):
word_model = WordTimingModel(**word_kwargs)
model.words.append(word_model)
return model
async def add(self, meeting_id: MeetingId, segment: Segment) -> Segment:
"""Add a segment to a meeting.
Args:
meeting_id: Meeting identifier.
segment: Segment to add.
Returns:
Added segment with db_id populated.
"""
model = self._create_segment_model(meeting_id, segment)
self._session.add(model)
await self._session.flush()
@@ -61,6 +75,8 @@ class SqlAlchemySegmentRepository(BaseRepository):
) -> Sequence[Segment]:
"""Add multiple segments to a meeting in batch.
Uses add_all() with a single flush for better performance.
Args:
meeting_id: Meeting identifier.
segments: Segments to add.
@@ -68,13 +84,25 @@ class SqlAlchemySegmentRepository(BaseRepository):
Returns:
Added segments with db_ids populated.
"""
result_segments: list[Segment] = []
if not segments:
return []
# Build all models upfront
models: list[SegmentModel] = []
for segment in segments:
added = await self.add(meeting_id, segment)
result_segments.append(added)
model = self._create_segment_model(meeting_id, segment)
models.append(model)
return result_segments
# Add all models and flush once
self._session.add_all(models)
await self._session.flush()
# Update segments with db_ids
for segment, model in zip(segments, models, strict=True):
segment.db_id = model.id
segment.meeting_id = meeting_id
return list(segments)
async def get_by_meeting(
self,

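Illustrative call site for the batched insert, assuming a unit-of-work factory; the point is that a single flush populates every `db_id`:

```python
async def persist_transcript(uow_factory, meeting_id, segments):
    # Illustrative only; the unit-of-work factory and entities are assumed.
    async with uow_factory() as uow:
        saved = await uow.segments.add_batch(meeting_id, segments)
        await uow.commit()
        # One flush assigned all primary keys, so db_id is populated in place.
        return [seg.db_id for seg in saved]
```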
View File

@@ -126,6 +126,22 @@ class SqlAlchemyUnitOfWork:
raise RuntimeError("UnitOfWork not in context")
return self._summaries_repo
# Feature flags - all True for database-backed implementation
@property
def supports_annotations(self) -> bool:
"""Annotations are fully supported with database."""
return True
@property
def supports_diarization_jobs(self) -> bool:
"""Diarization job persistence is fully supported with database."""
return True
@property
def supports_preferences(self) -> bool:
"""User preferences persistence is fully supported with database."""
return True
async def __aenter__(self) -> Self:
"""Enter the unit of work context.

View File

@@ -104,7 +104,7 @@ class KeyringKeyStore:
stored = keyring.get_password(self._service_name, self._key_name)
if stored is not None:
logger.debug("Retrieved existing master key from keyring")
return base64.b64decode(stored)
return _decode_and_validate_key(stored, "Keyring storage")
# Generate new key
new_key, encoded = _generate_key()
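The real `_decode_and_validate_key()` is defined elsewhere in this change; a plausible shape, with the expected key length as an assumption, would be:

```python
import base64

EXPECTED_KEY_BYTES = 32  # assumption: 256-bit master key


def _decode_and_validate_key(stored: str, source: str) -> bytes:
    """Plausible shape only; the actual helper ships with this change."""
    try:
        key = base64.b64decode(stored, validate=True)
    except (ValueError, TypeError) as exc:
        raise ValueError(f"{source} returned a value that is not valid base64") from exc
    if len(key) != EXPECTED_KEY_BYTES:
        raise ValueError(f"{source} returned a key of unexpected length {len(key)}")
    return key
```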

View File

@@ -322,7 +322,7 @@ class TestTranscriptSearch:
"""Tests for transcript search functionality."""
def test_search_filters_segments(self) -> None:
"""Search should filter visible segments."""
"""Search should filter visible segments via visibility toggle."""
state = MockAppState()
component = TranscriptComponent(state)
component.build()
@@ -338,8 +338,9 @@ class TestTranscriptSearch:
component._search_query = "world"
component._rerender_all_segments()
# Should only show segments containing "world"
visible_count = sum(row is not None for row in component._segment_rows)
# All rows exist, but only matching ones are visible
assert len(component._segment_rows) == 3
visible_count = sum(row.visible for row in component._segment_rows)
assert visible_count == 2
def test_search_is_case_insensitive(self) -> None:
@@ -356,7 +357,9 @@ class TestTranscriptSearch:
component._search_query = "world"
component._rerender_all_segments()
visible_count = sum(row is not None for row in component._segment_rows)
# All rows exist, but only matching ones are visible
assert len(component._segment_rows) == 2
visible_count = sum(row.visible for row in component._segment_rows)
assert visible_count == 1
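The updated assertions assume search toggles row visibility rather than rebuilding the row list; a minimal sketch of that filtering, with the helper name and segment storage assumed:

```python
def _apply_search_filter(self) -> None:
    # Hypothetical helper: rows stay mounted; only their .visible flag toggles.
    query = self._search_query.lower().strip()
    for row, segment in zip(self._segment_rows, self._segments):
        row.visible = not query or query in segment.text.lower()
```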

View File

@@ -9,6 +9,7 @@ import pytest
from noteflow.domain.entities.meeting import Meeting
from noteflow.domain.entities.segment import Segment
from noteflow.domain.entities.summary import Summary
from noteflow.domain.utils.time import utc_now
from noteflow.domain.value_objects import MeetingState
@@ -197,7 +198,7 @@ class TestMeetingProperties:
def test_duration_seconds_in_progress(self) -> None:
"""Test duration is > 0 when started but not ended."""
meeting = Meeting.create()
meeting.started_at = datetime.now() - timedelta(seconds=5)
meeting.started_at = utc_now() - timedelta(seconds=5)
assert meeting.duration_seconds >= 5.0
def test_is_active_created(self) -> None: