Update configuration and enhance transcript component functionality
- Added `repomix-output.md` to `.gitignore` to exclude generated output files from version control.
- Updated `repomix.config.json` to add the `support/` directory to the include paths.
- Enhanced `TranscriptComponent` search with a debounce mechanism for input.
- Refactored search logic to toggle segment visibility based on the query, for a better user experience.
- Updated `Meeting` and `Annotation` entities to use `utc_now()` for consistent UTC-aware timestamps.
- Introduced repository protocols for `DiarizationJob` and `Preferences` to support additional functionality.
- Implemented in-memory persistence for repositories to facilitate testing and development without a database.
- Added error handling and utility functions to streamline gRPC service responses and improve maintainability.
.gitignore (vendored, 1 change)
@@ -5,3 +5,4 @@ spikes/
 __pycache__/
 .env
 logs/status_line.json
+repomix-output.md
docs/triage.md (480 changes)
@@ -1,196 +1,382 @@
# Triage Review (Validated)

Based on the December 19, 2025 repomix-output.md snapshot, this is a targeted review focused on duplication/redundancy, bug and race risks, optimization, and DRY/robustness, with concrete excerpts and "what to do next".

Validated: 2025-12-19

Legend: Status = Confirmed, Partially confirmed, Not observed, Already implemented

Citations use [path:line] format.

⸻
## 2. Architecture & State Management

Highest-impact bug and race-condition risks.

### Issue 2.1: Server-Side State Volatility

Streaming diarization is globally stateful (cross-meeting interference risk).

Status: Confirmed
Severity: High
Location: src/noteflow/grpc/service.py, src/noteflow/grpc/_mixins/diarization.py, src/noteflow/grpc/server.py

The streaming state init resets diarization streaming state for every meeting:

```python
if self._diarization_engine is not None:
    self._diarization_engine.reset_streaming()
```

And reset_streaming() resets the underlying streaming pipeline:

```python
def reset_streaming(self) -> None:
    if self._streaming_pipeline is not None:
        self._streaming_pipeline.reset()
```

Evidence:
- In-memory stream state lives on the servicer (_active_streams, _audio_writers, _partial_buffers, _diarization_jobs). [src/noteflow/grpc/service.py:95] [src/noteflow/grpc/service.py:106] [src/noteflow/grpc/service.py:109] [src/noteflow/grpc/service.py:122]
- Background diarization jobs are stored only in a dict, and status reads come from it. [src/noteflow/grpc/_mixins/diarization.py:241] [src/noteflow/grpc/_mixins/diarization.py:480]
- Server shutdown only stops gRPC; no servicer cleanup hook is invoked. [src/noteflow/grpc/server.py:132]
- Crash recovery marks meetings ERROR but does not validate audio assets. [src/noteflow/application/services/recovery_service.py:21] [src/noteflow/application/services/recovery_service.py:71]

Why this is risky:
- If two meeting streams happen at once (multiple clients, or a reconnect edge case), meeting B can reset meeting A's diarization pipeline mid-stream.
- There is also a global diarization lock in the streaming paths (serialization), which suggests the diarization streaming pipeline is not intended to be shared concurrently across meetings. That is fine, but then enforce the constraint explicitly.

Example:
- If job state is lost (for example, after a restart), polling will fail and surface a fetch error in the UI. [src/noteflow/grpc/_mixins/diarization.py:479] [src/noteflow/grpc/client.py:885] [src/noteflow/client/components/meeting_library.py:534]

Recommendation: pick one of these models and make it explicit (a sketch of Model B follows):
- Model A: only one active diarization stream allowed. Enforce a single-stream invariant when diarization streaming is enabled; if a second stream begins, abort with FAILED_PRECONDITION ("diarization streaming is single-session").
- Model B: the diarization pipeline is per meeting. Store per-meeting pipelines/state (self._diarization_streams[meeting_id] = DiarizationEngine(...) or engine.create_stream_session()), remove the global .reset_streaming(), and reset only the per-meeting session.

Testing idea: add an async test that starts two streams and ensures diarization state for meeting A does not get reset when meeting B starts.

Reusable code locations:
- `_close_audio_writer` already centralizes writer cleanup. [src/noteflow/grpc/service.py:247]
- Migration patterns for new tables exist (annotations). [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]

Actions:
- Persist diarization jobs (table or cache) and query from `GetDiarizationJobStatus`.
- Add a shutdown hook to close all `_audio_writers` and flush buffers.
- Optional: add asset integrity checks after RecoveryService marks a meeting ERROR.
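A minimal sketch of Model B's per-meeting scoping (the session manager and `create_stream_session()` factory are assumptions, not the current engine API):

```python
# Sketch: per-meeting diarization sessions (hypothetical API).
# Stopping or resetting one meeting never touches another meeting's pipeline.
class DiarizationSessionManager:
    def __init__(self, engine) -> None:
        self._engine = engine
        self._sessions: dict[str, object] = {}  # meeting_id -> session

    def start(self, meeting_id: str):
        # create_stream_session() is an assumed per-session factory.
        session = self._engine.create_stream_session()
        self._sessions[meeting_id] = session
        return session

    def reset(self, meeting_id: str) -> None:
        # Reset only this meeting's session, not global state.
        session = self._sessions.pop(meeting_id, None)
        if session is not None:
            session.reset()
```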
⸻

### Issue 2.2: Implicit Meeting Asset Paths

Status: Confirmed
Severity: Medium
Location: src/noteflow/infrastructure/audio/reader.py, src/noteflow/infrastructure/audio/writer.py, src/noteflow/infrastructure/persistence/models.py

Evidence:
- Audio assets are read/written under `meetings_dir / meeting_id`. [src/noteflow/infrastructure/audio/reader.py:72] [src/noteflow/infrastructure/audio/writer.py:94]
- MeetingModel defines meeting fields (id/title/state/metadata/wrapped_dek) but no asset path. [src/noteflow/infrastructure/persistence/models.py:38] [src/noteflow/infrastructure/persistence/models.py:64]
- Delete logic also assumes `meetings_dir / meeting_id`. [src/noteflow/application/services/meeting_service.py:195]

Example:
- Record a meeting with `meetings_dir` set to `~/.noteflow/meetings`, then change it to `/mnt/noteflow`. Playback will look in the new base and fail to find older audio.

Reusable code locations:
- `MeetingAudioWriter.open` and `MeetingAudioReader.load_meeting_audio` are the path entry points. [src/noteflow/infrastructure/audio/writer.py:70] [src/noteflow/infrastructure/audio/reader.py:60]
- Migration templates live in `infrastructure/persistence/migrations/versions`. [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]

Actions (a model sketch follows):
- Add an `asset_path` (or `storage_path`) column to meetings.
- Store the relative path at creation time and use it on read/delete.
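A minimal sketch of the model change, assuming SQLAlchemy 2.0-style mapped columns (the column name follows the action above; it is nullable so existing rows can be backfilled from `meetings_dir / meeting_id`):

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class MeetingModel(Base):
    __tablename__ = "meetings"

    id: Mapped[str] = mapped_column(primary_key=True)
    # Relative asset path recorded at creation time; None means
    # "legacy row, fall back to meetings_dir / meeting_id".
    asset_path: Mapped[str | None] = mapped_column(default=None)
```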
⸻

### Issue 2.3: StopMeeting Closes the Audio Writer While the Stream May Still Be Writing

In StopMeeting, the writer is closed immediately:

```python
meeting_id = request.meeting_id
if meeting_id in self._audio_writers:
    self._close_audio_writer(meeting_id)
```

Why this is risky:
- The streaming loop writes audio chunks while streaming is active. If StopMeeting can be called while StreamTranscription is still mid-loop, you can get "file not open" style exceptions, partial/corrupted encrypted audio artifacts, or silent loss, depending on how _write_audio_chunk_safe is implemented.

Recommendation: make writer closure happen in exactly one place, ideally stream teardown, and make StopMeeting signal rather than force:
- Track active streams and gate closure: if meeting_id is currently streaming, set stop_requested[meeting_id] = True and let the stream's finally block close the writer.
- Or: in StopMeeting, call the same cleanup routine a stream uses, but only after ensuring the stream is stopped/cancelled.

Testing idea: a concurrency test that runs StreamTranscription and calls StopMeeting mid-stream, asserting no exceptions and that the writer is closed exactly once.

⸻

## 3. Concurrency & Performance

### Issue 3.1: Synchronous Blocking in Async gRPC (Streaming Diarization)

Status: Confirmed
Severity: Medium
Location: src/noteflow/grpc/_mixins/diarization.py, src/noteflow/grpc/_mixins/streaming.py

Evidence:
- `_process_streaming_diarization` calls `process_chunk` synchronously. [src/noteflow/grpc/_mixins/diarization.py:63] [src/noteflow/grpc/_mixins/diarization.py:92]
- It is invoked in the async streaming loop on every chunk. [src/noteflow/grpc/_mixins/streaming.py:379]
- The diarization engine uses pyannote/diart pipelines. [src/noteflow/infrastructure/diarization/engine.py:1]

Example:
- With diarization enabled on CPU, heavy `process_chunk` calls can stall the event loop, delaying transcript updates and heartbeats.

Reusable code locations:
- ASR already offloads blocking work via `run_in_executor`. [src/noteflow/infrastructure/asr/engine.py:156]
- Offline diarization uses `asyncio.to_thread`. [src/noteflow/grpc/_mixins/diarization.py:305]

Actions (a sketch follows):
- Offload streaming diarization to a thread/process pool, as ASR does.
- Consider a bounded queue so diarization lag does not backpressure streaming.
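A minimal sketch of both actions together, assuming `process_chunk` is the blocking call (queue size and names are illustrative):

```python
import asyncio

# Sketch: bounded hand-off plus thread offload for streaming diarization.
# The bound keeps a slow diarizer from backpressuring the audio stream;
# overflow chunks are dropped rather than stalling the event loop.
diar_queue: asyncio.Queue = asyncio.Queue(maxsize=8)


async def diarization_worker(engine) -> None:
    while True:
        chunk = await diar_queue.get()
        # Run the blocking pyannote/diart call off the event loop.
        await asyncio.to_thread(engine.process_chunk, chunk)
        diar_queue.task_done()


def submit_chunk(chunk) -> None:
    try:
        diar_queue.put_nowait(chunk)
    except asyncio.QueueFull:
        pass  # best-effort during overload; transcription is unaffected
```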
⸻

### Issue 3.2: VU Meter UI Updates on Every Audio Chunk

Status: Confirmed
Severity: Medium
Location: src/noteflow/client/app.py, src/noteflow/client/components/vu_meter.py

Evidence:
- Audio capture uses 100ms chunks. [src/noteflow/client/app.py:307]
- Each chunk triggers `VuMeterComponent.on_audio_frames`, which schedules a UI update. [src/noteflow/client/app.py:552] [src/noteflow/client/components/vu_meter.py:58]

Example:
- At 100ms chunks, the UI updates about 10 times per second. If chunk duration is lowered (e.g., 20ms), that becomes about 50 updates per second and can cause stutter.

Reusable code locations:
- The recording timer throttles updates with a fixed interval and a background worker. [src/noteflow/client/components/recording_timer.py:14]

Actions (a sketch follows):
- Throttle VU updates (for example, 20 fps) or update only when the delta exceeds a threshold.
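A minimal sketch of the throttle, with illustrative names (the real component would wrap `on_audio_frames`):

```python
import time

# Sketch: cap VU meter redraws at ~20 fps regardless of chunk rate.
MIN_UPDATE_INTERVAL = 1.0 / 20.0


class ThrottledVuMeter:
    def __init__(self, render) -> None:
        self._render = render        # callable that updates the UI
        self._last_update = 0.0

    def on_audio_frames(self, level: float) -> None:
        now = time.monotonic()
        if now - self._last_update < MIN_UPDATE_INTERVAL:
            return  # skip; the next chunk will refresh soon enough
        self._last_update = now
        self._render(level)
```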
⸻

### Issue 3.3: Naive Datetimes + Timezone-Aware DB Columns

Naive datetimes combined with timezone-aware DB columns produce subtle bugs and wrong timestamps.

The domain Meeting uses naive datetimes:

```python
created_at: datetime = field(default_factory=datetime.now)
```

But the DB models use DateTime(timezone=True) while still defaulting to naive datetime.now:

```python
created_at: Mapped[datetime] = mapped_column(
    DateTime(timezone=True),
    default=datetime.now,
)
```

And conversion to proto timestamps uses .timestamp():

```python
created_at=meeting.created_at.timestamp(),
```

Why this is risky:
- Naive/aware comparisons can explode at runtime (depending on code paths).
- .timestamp() on naive datetimes depends on local timezone assumptions.
- UTC-aware time is already used elsewhere (for example, the retention test uses datetime.now(UTC)), so the intent seems to be "UTC everywhere".

Recommendation:
- Standardize: store and operate on UTC-aware datetimes everywhere.
- Domain: default_factory=lambda: datetime.now(UTC)
- DB: default=datetime.now(UTC) or server_default=func.now() (and be consistent about timezone)
- Add a test that ensures all exported/serialized timestamps are monotonic and consistent (no "timezone drift").

⸻

## 4. Domain Logic & Reliability

### Issue 4.1: Summarization Consent Persistence

Status: Confirmed
Severity: Low (UX)
Location: src/noteflow/application/services/summarization_service.py

Evidence:
- Consent is stored on `SummarizationServiceSettings` and defaults to False. [src/noteflow/application/services/summarization_service.py:56]
- `grant_cloud_consent` only mutates the in-memory settings. [src/noteflow/application/services/summarization_service.py:150]

Example:
- Users who grant cloud consent must re-consent after every server/client restart.

Reusable code locations:
- Existing persistence patterns for per-meeting data live in `infrastructure/persistence`. [src/noteflow/infrastructure/persistence/models.py:32]

Actions (a sketch follows):
- Persist consent in a preferences table or config file and hydrate on startup.
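A minimal sketch of consent persistence via the PreferencesRepository protocol introduced in this commit (the key name and settings shape are assumptions):

```python
CONSENT_KEY = "summarization.cloud_consent"  # assumed preference key


async def grant_cloud_consent(uow, settings) -> None:
    settings.cloud_consent = True
    if uow.supports_preferences:
        async with uow:
            await uow.preferences.set(CONSENT_KEY, True)
            await uow.commit()


async def hydrate_consent(uow, settings) -> None:
    # Run on startup so consent survives restarts.
    if uow.supports_preferences:
        async with uow:
            value = await uow.preferences.get(CONSENT_KEY)
        settings.cloud_consent = bool(value)
```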
⸻

### Issue 4.2: Annotation Validation / Point-in-Time Annotations

Status: Not observed (current code supports point annotations)
Severity: Low
Location: src/noteflow/domain/entities/annotation.py, src/noteflow/client/components/annotation_toolbar.py, src/noteflow/client/components/annotation_display.py

Evidence:
- Validation allows `end_time == start_time`. [src/noteflow/domain/entities/annotation.py:37]
- The UI creates point annotations by setting `start_time == end_time`. [src/noteflow/client/components/annotation_toolbar.py:189]
- Display uses `start_time` only, not duration. [src/noteflow/client/components/annotation_display.py:164]

Example:
- Clicking a point annotation seeks to the exact timestamp (no duration needed).

Reusable code locations:
- If range annotations are introduced later, `AnnotationDisplayComponent` is where a start-end range would be rendered. [src/noteflow/client/components/annotation_display.py:148]

Action:
- None required now; revisit if range annotations are added to the UI or exports.

⸻

### Issue 4.3: Keyring Keystore Bypasses Its Own Validation Logic

There is a proper validator:

```python
def _decode_and_validate_key(self, encoded: str) -> bytes:
    ...
    if len(decoded) != MASTER_KEY_SIZE:
        raise WrongKeySizeError(...)
```

…but KeyringKeyStore skips it:

```python
if stored:
    return base64.b64decode(stored)
```

Why this is risky:
- Invalid base64 or wrong key sizes could silently produce bad keys, later causing encryption/decryption failures that look like "random corruption".

Recommendation (a sketch follows):
- Use _decode_and_validate_key() in all keystore implementations (file, env, keyring).
- Add a test that injects an invalid stored keyring value and asserts a clear, typed failure.
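A minimal sketch of routing the keyring path through the shared validator (service/entry names and MASTER_KEY_SIZE are assumptions; `keyring.get_password` is the standard keyring API):

```python
import base64

import keyring

MASTER_KEY_SIZE = 32  # assumed; match the real constant


class WrongKeySizeError(ValueError):
    pass


class KeyringKeyStore:
    SERVICE = "noteflow"      # assumed service name
    USERNAME = "master_key"   # assumed entry name

    def _decode_and_validate_key(self, encoded: str) -> bytes:
        decoded = base64.b64decode(encoded, validate=True)
        if len(decoded) != MASTER_KEY_SIZE:
            raise WrongKeySizeError(f"expected {MASTER_KEY_SIZE} bytes, got {len(decoded)}")
        return decoded

    def load_key(self) -> bytes | None:
        stored = keyring.get_password(self.SERVICE, self.USERNAME)
        if stored is None:
            return None
        # Validate instead of returning base64.b64decode(stored) directly,
        # so bad values fail loudly with a typed error.
        return self._decode_and_validate_key(stored)
```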
⸻

## 5. Suggested Git Issues (Validated)

Issue A: Persist Diarization Jobs to Database
Status: Confirmed
Evidence: Jobs live in-memory and `GetDiarizationJobStatus` reads from the dict. [src/noteflow/grpc/_mixins/diarization.py:241] [src/noteflow/grpc/_mixins/diarization.py:480]
Reusable code locations:
- SQLAlchemy models/migrations patterns. [src/noteflow/infrastructure/persistence/models.py:32] [src/noteflow/infrastructure/persistence/migrations/versions/b5c3e8a2d1f0_add_annotations_table.py:22]
Tasks:
- Add a JobModel with status fields.
- Update the mixin to persist to and query the DB.

Issue B: Implement Audio Writer Buffering
Status: Already implemented (close or re-scope)
Evidence: `MeetingAudioWriter` buffers and flushes based on `buffer_size`. [src/noteflow/infrastructure/audio/writer.py:126]
Reusable code locations:
- `AUDIO_BUFFER_SIZE_BYTES` constant. [src/noteflow/config/constants.py:26]
Tasks:
- None, unless you want to tune the buffer size.

Issue C: Fallback for Headless Keyring
Status: Confirmed
Evidence: `KeyringKeyStore` only falls back to an env var, not file storage. [src/noteflow/infrastructure/security/keystore.py:49]
Reusable code locations:
- `KeyringKeyStore` and `InMemoryKeyStore` live in the same module. [src/noteflow/infrastructure/security/keystore.py:35]
Tasks:
- Add `FileKeyStore` and wire the fallback in server/service initialization.

⸻
Issue D: Throttled VU Meter in Client
Status: Confirmed
Evidence: Each chunk schedules a UI update with no throttle. [src/noteflow/client/components/vu_meter.py:58]
Reusable code locations:
- Background worker/throttle pattern in `RecordingTimerComponent`. [src/noteflow/client/components/recording_timer.py:14]
Tasks:
- Add a `last_update_time` and update at a fixed cadence.

Issue E: Explicit Asset Path Storage
Status: Confirmed
Evidence: Meeting paths are derived from `meetings_dir / meeting_id`. [src/noteflow/infrastructure/audio/reader.py:72]
Reusable code locations:
- Meeting model + migrations. [src/noteflow/infrastructure/persistence/models.py:32]
Tasks:
- Add an `asset_path` column and persist it at create time.

Issue F: PGVector Index Creation
Status: Confirmed (requires product decision)
Evidence: The migration creates an `ivfflat` index immediately. [src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py:95]
Reusable code locations:
- Same migration file for index changes. [src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py:95]
Tasks:
- Consider switching to HNSW or defer index creation until data exists.

⸻
## 6. Code Quality & Nitpicks (Validated)

- The HTML export template is minimal inline CSS (no external framework). [src/noteflow/infrastructure/export/html.py:33]
  Example: optionally add a lightweight stylesheet (with offline fallback) for nicer exports.
- The transcript partial row is updated in place (no remove/re-add), so flicker risk is already mitigated. [src/noteflow/client/components/transcript.py:182]
- `Segment.word_count` uses `text.split()` when no words are present. [src/noteflow/domain/entities/segment.py:72]
  Example: for very large transcripts, a streaming count (for example, a regex iterator) avoids allocating a full list.

⸻
Biggest duplication & DRY pain points (and how to eliminate them)

1) gRPC mixins repeat "DB path vs memory path" everywhere

Example in ListMeetings:

```python
if self._use_database():
    async with self._create_uow() as uow:
        meetings = await uow.meetings.list_all()
else:
    store = self._get_memory_store()
    meetings = list(store._meetings.values())
return ListMeetingsResponse(meetings=[meeting_to_proto(m) for m in meetings])
```

The same pattern appears in summarization:

```python
if self._use_database():
    summary = await self._generate_summary_db(...)
else:
    summary = self._generate_summary_memory(...)
```

And streaming segment persistence is duplicated too (and commits per result; more on perf below).

Why this matters:
- This is a structural duplication multiplier.
- Code assistants will mirror the existing pattern and duplicate more.
- You'll fix a bug in one branch and forget the other.

Recommendation (best ROI refactor): create a single abstraction layer and make the gRPC layer depend on it:
- Define ports/protocols (async-friendly): MeetingRepository, SegmentRepository, SummaryRepository, etc.
- Provide two adapters: SqlAlchemyMeetingRepository and InMemoryMeetingRepository.
- In the servicer, inject one concrete implementation behind a unified interface: self.meetings_repo, self.segments_repo, self.summaries_repo.

Then ListMeetings becomes:

```python
meetings = await self.meetings_repo.list_all()
return ListMeetingsResponse(meetings=[meeting_to_proto(m) for m in meetings])
```

No branching in every RPC = dramatically less duplication, fewer assistant copy/pastes, and fewer inconsistencies.

⸻
Performance / optimization hotspots

1) DB add_batch is not actually batching

The segment repo's add_batch loops and calls add for each segment:

```python
async def add_batch(...):
    for seg in segments:
        await self.add(meeting_id, seg)
```

Why this hurts:
- If add() flushes/commits, or even just issues INSERTs repeatedly, you get N roundtrips/flushes.

Recommendation (a sketch follows):
- Build ORM models for all segments and call session.add_all([...]), then flush once.
- If you need IDs, flush once and read them back.
- Keep the commit decision at the UoW/service layer (not inside the repo).
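A minimal sketch of a truly batched add_batch, assuming the repo holds an async SQLAlchemy session and already has a `_to_model` converter:

```python
from collections.abc import Sequence


# Sketch: one flush for the whole batch; commit stays at the UoW layer.
async def add_batch(self, meeting_id, segments: Sequence) -> None:
    models = [self._to_model(meeting_id, seg) for seg in segments]
    self._session.add_all(models)  # stage everything in one go
    await self._session.flush()    # single roundtrip; IDs populated here
```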
⸻

2) Streaming persistence commits inside a tight loop

Inside _process_audio_segment (DB branch):

```python
for result in results:
    ...
    await uow.segments.add(meeting.id, segment)
    await uow.commit()
    yield segment_to_proto_update(meeting_id, segment)
```

Why this hurts:
- If ASR returns multiple results for one audio segment, you commit multiple times.
- Commits are expensive and can become the bottleneck.

Recommendation (a sketch follows):
- Commit once per processed audio segment (or once per gRPC request chunk batch), not once per ASR result.
- You can still yield after staging; if you must ensure durability before yielding, commit once after staging all results for that audio segment.
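A minimal sketch of the reshaped loop (`make_segment` is a hypothetical stand-in for the existing result-to-segment conversion):

```python
# Sketch: stage all results, commit once, then yield the updates.
async def persist_and_stream(uow, meeting, results, meeting_id):
    staged = []
    for result in results:
        segment = make_segment(result)  # hypothetical converter
        await uow.segments.add(meeting.id, segment)
        staged.append(segment)
    await uow.commit()  # one commit per audio segment, not per result
    for segment in staged:
        yield segment_to_proto_update(meeting_id, segment)
```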
⸻

3) Transcript search re-renders the entire list on every change (likely OK now, painful later)

**RESOLVED**: Added a 200ms debounce timer and visibility toggling instead of a full rebuild.
- Search input now uses `threading.Timer` with a 200ms debounce.
- All segment rows are created once, and visibility is toggled via `container.visible`.
- No more clearing and rebuilding on each keystroke.

~~On search changes, you do a full rebuild:~~

```python
self._segment_rows.clear()
self._list_view.controls.clear()
for idx, seg in enumerate(self._state.transcript_segments):
    ...
    self._list_view.controls.append(row)
```

~~And it is triggered directly by search input changes:~~

```python
self._search_query = (value or "").strip().lower()
...
self._rerender_all_segments()
```

~~Recommendation~~
- ~~Add a debounce (150–250ms) to search updates.~~
- ~~Consider incremental filtering (track which rows match and only hide/show).~~
- ~~If transcripts can become large: virtualized list rendering.~~
⸻

4) Client streaming stop/join can leak threads

stop_streaming joins with a timeout and clears the thread reference:

```python
self._stream_thread.join(timeout=2.0)
self._stream_thread = None
```

If the thread doesn't exit in time, you lose the handle while it's still running.

Recommendation (a sketch follows):
- After the join timeout, check is_alive() and log + keep the reference (or escalate).
- Better: make the worker reliably cancellable and join without a timeout in normal shutdown paths.
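A minimal sketch of the guarded join; this mirrors the audio-consumer fix in this commit's client app diff below:

```python
import logging

logger = logging.getLogger(__name__)


def stop_streaming(self) -> None:
    if self._stream_thread is None:
        return
    self._stream_thread.join(timeout=2.0)
    if self._stream_thread.is_alive():
        # Keep the handle so the leak stays visible and recoverable.
        logger.warning("Stream thread did not exit within timeout; keeping reference")
    else:
        self._stream_thread = None
```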
⸻

How to write tests that catch more of these bugs

You already have a good base: there are stress tests around streaming cleanup / leak prevention (nice). What's missing tends to be concurrency + contract + regression tests that mirror real failure modes.

1) Add "race tests" for the exact dangerous interleavings

Concrete targets from this review (see the sketch after this list):
- StopMeeting vs StreamTranscription write:
  - Start a stream, send some chunks, call StopMeeting mid-flight.
  - Assert: no exception, audio writer closed once, meeting state consistent.
- Two streams with diarization enabled:
  - Start stream A, then start stream B.
  - Assert: A's diarization state isn't reset (or assert the server rejects the second stream if you enforce single-stream).

Implementation approach:
- Use grpc.aio and asyncio.gather() to force overlap.
- Keep tests deterministic by using fixed chunk data and controlled scheduling (await asyncio.sleep(0) strategically).
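A minimal sketch of forcing the overlap (the `client` fixture and its `stream_transcription`/`stop_meeting` methods are stand-ins for the real test helpers):

```python
import asyncio

import pytest


@pytest.mark.asyncio
async def test_stop_meeting_mid_stream(client, meeting_id, chunks):
    async def stream():
        async for _update in client.stream_transcription(meeting_id, chunks):
            await asyncio.sleep(0)  # yield so stop() can interleave

    async def stop():
        await asyncio.sleep(0)  # let the stream start first
        await client.stop_meeting(meeting_id)

    # gather() forces both coroutines to overlap on one event loop;
    # the test passes only if the interleaving raises nothing.
    await asyncio.gather(stream(), stop())
```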
2) Contract tests for DB vs memory parity

As long as you support both backends, enforce "same behavior" (a parametrized-fixture sketch follows):
- For each operation (create meeting, add segments, summary generation, deletion):
  - Run the same test suite twice: once with the in-memory store, once with the DB store.
  - Assert responses and side effects match.

This directly prevents the "fixed in DB path, broken in memory path" class of bugs.
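A minimal pytest sketch of one suite over both backends (the `make_*_uow` factories and the repository `create` method are assumptions about the test harness):

```python
import pytest


@pytest.fixture(params=["memory", "database"])
def uow(request, make_memory_uow, make_db_uow):
    # One test body, two UnitOfWork implementations.
    return make_memory_uow() if request.param == "memory" else make_db_uow()


@pytest.mark.asyncio
async def test_create_then_get_meeting(uow):
    async with uow:
        meeting = await uow.meetings.create(title="Standup")  # assumed API
        await uow.commit()
    async with uow:
        fetched = await uow.meetings.get(meeting.id)
    assert fetched is not None and fetched.title == "Standup"
```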
3) Property-based tests for invariants (catches "weird" edge cases)

Use Hypothesis (or similar) for:
- Random sequences of MeetingState transitions: ensure only allowed transitions occur.
- Random segments + word timings:
  - word start/end are within the segment range,
  - segment durations are non-negative,
  - transcript ordering invariants.

These tests are excellent at flushing out hidden assumptions; a skeleton follows.
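A minimal Hypothesis skeleton for the state-machine invariant (it relies on the `can_transition_to` method used in the Meeting entity; real invariants would be layered onto the walk):

```python
from hypothesis import given, strategies as st

from noteflow.domain.value_objects import MeetingState


@given(st.lists(st.sampled_from(list(MeetingState)), max_size=20))
def test_walk_follows_only_allowed_transitions(targets):
    current = MeetingState.CREATED
    for target in targets:
        # Only follow edges the domain model itself permits.
        if current.can_transition_to(target):
            current = target
    assert isinstance(current, MeetingState)
```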
4) Mutation testing (biggest upgrade if "bugs slip through")

If you feel like "tests pass but stuff breaks," mutation testing will tell you where your suite is weak.
- Run mutation testing on core modules:
  - gRPC mixins (streaming/summarization)
  - repositories
  - crypto keystore/unwrap logic
- You'll quickly find untested branches and "assert-less" code.

5) Make regressions mandatory: "bug → test → fix"

A simple process change that works:
- When you fix a bug, first write a failing test that reproduces it.
- Only then fix the code.

This steadily raises your "bugs can't sneak back in" floor.

6) CI guardrails
- Enforce coverage thresholds on:
  - domain + application + grpc layers (UI excluded if unstable)
- Run stress/race tests nightly if they're expensive.

⸻
How to get code assistants to stop duplicating code

1) Remove the architectural duplication magnets

The #1 driver is the repeated DB-vs-memory branching in every RPC/method. If you refactor to a single repository/service interface, the assistant literally has fewer "duplicate-shaped" surfaces to copy.

2) Add an "assistant contract" file to your repo

Create a short, explicit guide your assistants must follow, e.g. ASSISTANT_GUIDELINES.md, with rules like:
- "Before writing code, search for existing helpers/converters/repositories and reuse them."
- "No new '_db' / '_memory' functions. Add to the repository interface instead."
- "If adding logic similar to existing code, refactor to a shared helper instead of copy/paste."
- "Prefer editing an existing module over creating a new one."

3) Prompt pattern that reduces duplication dramatically

When you ask an assistant to implement something, use a structure like:
- Step 1: "List existing functions/files that already do part of this."
- Step 2: "Propose the minimal diff touching the fewest files."
- Step 3: "If code would be duplicated, refactor first (create helper + update callers)."
- Step 4: "Show me the final patch."

It forces discovery + reuse before generation.

4) Add duplication detection to CI

Make duplication a failing signal, not a suggestion:
- Add a copy/paste detector (language-agnostic ones exist) and fail above a threshold.
- Pair it with lint rules that push toward DRY patterns.

5) Code review rubric (even if you're solo)

A quick checklist:
- "Is this new code already present elsewhere?"
- "Did we add a second way to do the same thing?"
- "Is there now both a DB and memory version of the same logic?"

⸻

If you only do 3 things next

1. Refactor the gRPC mixins to a unified repository/service interface (kills most duplication and reduces assistant copy/paste).
2. Fix diarization streaming state scoping (the global reset is the scariest race).
3. Add race tests for StopMeeting vs streaming, and enforce UTC-aware datetimes.

⸻

If you want, I can also propose a concrete refactor sketch (new interfaces + how to wire the servicer) that eliminates the DB/memory branching with minimal disruption, using the patterns you already have (UoW + repositories).
@@ -26,7 +26,7 @@
       "includeLogsCount": 50
     }
   },
-  "include": ["src/", "tests/"],
+  "include": ["src/", "tests/", "support/"],
   "ignore": {
     "useGitignore": true,
     "useDefaultPatterns": true,
@@ -509,7 +509,13 @@ class NoteFlowClientApp(TriggerMixin):
         self._audio_consumer_stop.set()
         if self._audio_consumer_thread is not None:
             self._audio_consumer_thread.join(timeout=1.0)
-            self._audio_consumer_thread = None
+            # Only clear reference if thread exited cleanly
+            if self._audio_consumer_thread.is_alive():
+                logger.warning(
+                    "Audio consumer thread did not exit within timeout, keeping reference"
+                )
+            else:
+                self._audio_consumer_thread = None
         # Drain remaining frames
         while not self._audio_frame_queue.empty():
             try:
@@ -8,6 +8,7 @@ from __future__ import annotations

 import hashlib
 from collections.abc import Callable
+from threading import Timer
 from typing import TYPE_CHECKING

 import flet as ft
@@ -15,6 +16,9 @@ import flet as ft
 # REUSE existing formatting - do not recreate
 from noteflow.infrastructure.export._formatting import format_timestamp

+# Debounce delay for search input (milliseconds)
+_SEARCH_DEBOUNCE_MS = 200
+
 if TYPE_CHECKING:
     from noteflow.client.state import AppState

@@ -42,10 +46,11 @@ class TranscriptComponent:
         self._state = state
         self._on_segment_click = on_segment_click
         self._list_view: ft.ListView | None = None
-        self._segment_rows: list[ft.Container | None] = []  # Track rows for highlighting
+        self._segment_rows: list[ft.Container] = []  # All rows, use visible property to filter
         self._search_field: ft.TextField | None = None
         self._search_query: str = ""
         self._partial_row: ft.Container | None = None  # Live partial at bottom
+        self._search_timer: Timer | None = None  # Debounce timer for search

     def build(self) -> ft.Column:
         """Build transcript list view with search.
@@ -110,6 +115,11 @@ class TranscriptComponent:

     def clear(self) -> None:
         """Clear all transcript segments and partials."""
+        # Cancel pending search timer
+        if self._search_timer is not None:
+            self._search_timer.cancel()
+            self._search_timer = None
+
         self._state.clear_transcript()
         self._segment_rows.clear()
         self._partial_row = None
@@ -121,13 +131,42 @@ class TranscriptComponent:
         self._state.request_update()

     def _on_search_change(self, e: ft.ControlEvent) -> None:
-        """Handle search field change.
+        """Handle search field change with debounce.

         Args:
             e: Control event with new search value.
         """
         self._search_query = (e.control.value or "").lower()
-        self._rerender_all_segments()
+
+        # Cancel pending timer
+        if self._search_timer is not None:
+            self._search_timer.cancel()
+
+        # Start new debounce timer
+        self._search_timer = Timer(
+            _SEARCH_DEBOUNCE_MS / 1000.0,
+            self._apply_search_filter,
+        )
+        self._search_timer.start()
+
+    def _apply_search_filter(self) -> None:
+        """Apply search filter to existing rows via visibility toggle."""
+        self._state.run_on_ui_thread(self._toggle_row_visibility)
+
+    def _toggle_row_visibility(self) -> None:
+        """Toggle visibility of rows based on search query (UI thread only)."""
+        if not self._list_view:
+            return
+
+        query = self._search_query
+        for idx, container in enumerate(self._segment_rows):
+            if idx >= len(self._state.transcript_segments):
+                continue
+            segment = self._state.transcript_segments[idx]
+            matches = not query or query in segment.text.lower()
+            container.visible = matches
+
+        self._state.request_update()

     def _rerender_all_segments(self) -> None:
         """Re-render all segments with current search filter."""
@@ -137,15 +176,11 @@ class TranscriptComponent:
         self._list_view.controls.clear()
         self._segment_rows.clear()

+        query = self._search_query
         for idx, segment in enumerate(self._state.transcript_segments):
-            # Filter by search query
-            if self._search_query and self._search_query not in segment.text.lower():
-                # Add placeholder to maintain index alignment
-                self._segment_rows.append(None)
-                continue
-
-            # Use original index for click handling
             container = self._create_segment_row(segment, idx)
+            # Set visibility based on search query
+            container.visible = not query or query in segment.text.lower()
             self._segment_rows.append(container)
             self._list_view.controls.append(container)

@@ -167,14 +202,12 @@ class TranscriptComponent:
         # Use the actual index from state (segments are appended before rendering)
         segment_index = len(self._state.transcript_segments) - 1

-        # Filter by search query during live rendering
-        if self._search_query and self._search_query not in segment.text.lower():
-            self._segment_rows.append(None)
-            return
-
         container = self._create_segment_row(segment, segment_index)
+
+        # Set visibility based on search query during live rendering
+        query = self._search_query
+        container.visible = not query or query in segment.text.lower()

         self._segment_rows.append(container)
         self._list_view.controls.append(container)
         self._state.request_update()
@@ -347,8 +380,6 @@ class TranscriptComponent:
             highlighted_index: Index of segment to highlight, or None to clear.
         """
         for idx, container in enumerate(self._segment_rows):
-            if container is None:
-                continue
             if idx == highlighted_index:
                 container.bgcolor = ft.Colors.YELLOW_100
                 container.border = ft.border.all(1, ft.Colors.YELLOW_700)
@@ -371,10 +402,6 @@ class TranscriptComponent:
         if not self._list_view or segment_index >= len(self._segment_rows):
             return

-        container = self._segment_rows[segment_index]
-        if container is None:
-            return
-
         # Estimate row height for scroll calculation
         estimated_row_height = 50
         offset = segment_index * estimated_row_height
@@ -9,6 +9,8 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import TYPE_CHECKING

+from noteflow.domain.utils.time import utc_now
+
 if TYPE_CHECKING:
     from noteflow.domain.value_objects import AnnotationId, AnnotationType, MeetingId

@@ -29,7 +31,7 @@ class Annotation:
     start_time: float
     end_time: float
     segment_ids: list[int] = field(default_factory=list)
-    created_at: datetime = field(default_factory=datetime.now)
+    created_at: datetime = field(default_factory=utc_now)

     # Database primary key (set after persistence)
     db_id: int | None = None
@@ -7,6 +7,7 @@ from datetime import datetime
 from typing import TYPE_CHECKING
 from uuid import UUID, uuid4

+from noteflow.domain.utils.time import utc_now
 from noteflow.domain.value_objects import MeetingId, MeetingState

 if TYPE_CHECKING:
@@ -25,7 +26,7 @@ class Meeting:
     id: MeetingId
     title: str
     state: MeetingState = MeetingState.CREATED
-    created_at: datetime = field(default_factory=datetime.now)
+    created_at: datetime = field(default_factory=utc_now)
     started_at: datetime | None = None
     ended_at: datetime | None = None
     segments: list[Segment] = field(default_factory=list)
@@ -50,7 +51,7 @@ class Meeting:
             New Meeting instance.
         """
         meeting_id = MeetingId(uuid4())
-        now = datetime.now()
+        now = utc_now()

         if not title:
             title = f"Meeting {now.strftime('%Y-%m-%d %H:%M')}"
@@ -98,7 +99,7 @@ class Meeting:
             id=meeting_id,
             title=title,
             state=state,
-            created_at=created_at or datetime.now(),
+            created_at=created_at or utc_now(),
             started_at=started_at,
             ended_at=ended_at,
             metadata=metadata or {},
@@ -115,7 +116,7 @@ class Meeting:
         if not self.state.can_transition_to(MeetingState.RECORDING):
             raise ValueError(f"Cannot start recording from state {self.state.name}")
         self.state = MeetingState.RECORDING
-        self.started_at = datetime.now()
+        self.started_at = utc_now()

     def begin_stopping(self) -> None:
         """Transition to stopping state for graceful shutdown.
@@ -140,7 +141,7 @@ class Meeting:
             raise ValueError(f"Cannot stop recording from state {self.state.name}")
         self.state = MeetingState.STOPPED
         if self.ended_at is None:
-            self.ended_at = datetime.now()
+            self.ended_at = utc_now()

     def complete(self) -> None:
         """Transition to completed state.
@@ -178,7 +179,7 @@ class Meeting:
         if self.ended_at and self.started_at:
             return (self.ended_at - self.started_at).total_seconds()
         if self.started_at:
-            return (datetime.now() - self.started_at).total_seconds()
+            return (utc_now() - self.started_at).total_seconds()
         return 0.0

     @property
@@ -9,6 +9,10 @@ from typing import TYPE_CHECKING, Protocol

 if TYPE_CHECKING:
     from noteflow.domain.entities import Annotation, Meeting, Segment, Summary
     from noteflow.domain.value_objects import AnnotationId, MeetingId, MeetingState
+    from noteflow.infrastructure.persistence.repositories import (
+        DiarizationJob,
+        StreamingTurn,
+    )


 class MeetingRepository(Protocol):
@@ -310,3 +314,154 @@ class AnnotationRepository(Protocol):
             True if deleted, False if not found.
         """
         ...
+
+
+class DiarizationJobRepository(Protocol):
+    """Repository protocol for DiarizationJob operations.
+
+    Tracks background speaker diarization jobs for crash resilience
+    and client visibility.
+    """
+
+    async def create(self, job: DiarizationJob) -> DiarizationJob:
+        """Persist a new diarization job.
+
+        Args:
+            job: DiarizationJob data object.
+
+        Returns:
+            Created job.
+        """
+        ...
+
+    async def get(self, job_id: str) -> DiarizationJob | None:
+        """Retrieve a job by ID.
+
+        Args:
+            job_id: Job identifier.
+
+        Returns:
+            Job if found, None otherwise.
+        """
+        ...
+
+    async def update_status(
+        self,
+        job_id: str,
+        status: int,
+        *,
+        segments_updated: int | None = None,
+        speaker_ids: list[str] | None = None,
+        error_message: str | None = None,
+    ) -> bool:
+        """Update job status and optional fields.
+
+        Args:
+            job_id: Job identifier.
+            status: New status value.
+            segments_updated: Optional segments count.
+            speaker_ids: Optional speaker IDs list.
+            error_message: Optional error message.
+
+        Returns:
+            True if job was updated, False if not found.
+        """
+        ...
+
+    async def prune_completed(self, ttl_seconds: float) -> int:
+        """Delete completed/failed jobs older than TTL.
+
+        Args:
+            ttl_seconds: Time-to-live in seconds.
+
+        Returns:
+            Number of jobs deleted.
+        """
+        ...
+
+    async def add_streaming_turns(
+        self,
+        meeting_id: str,
+        turns: Sequence[StreamingTurn],
+    ) -> int:
+        """Persist streaming diarization turns for crash resilience.
+
+        Args:
+            meeting_id: Meeting identifier.
+            turns: Speaker turns to persist.
+
+        Returns:
+            Number of turns added.
+        """
+        ...
+
+    async def get_streaming_turns(self, meeting_id: str) -> list[StreamingTurn]:
+        """Retrieve streaming diarization turns for a meeting.
+
+        Args:
+            meeting_id: Meeting identifier.
+
+        Returns:
+            List of streaming turns.
+        """
+        ...
+
+    async def clear_streaming_turns(self, meeting_id: str) -> int:
+        """Delete streaming diarization turns for a meeting.
+
+        Args:
+            meeting_id: Meeting identifier.
+
+        Returns:
+            Number of turns deleted.
+        """
+        ...
+
+    async def mark_running_as_failed(self, error_message: str = "Server restarted") -> int:
+        """Mark queued/running jobs as failed.
+
+        Args:
+            error_message: Error message to set on failed jobs.
+
+        Returns:
+            Number of jobs marked as failed.
+        """
+        ...
+
+
+class PreferencesRepository(Protocol):
+    """Repository protocol for user preferences operations.
+
+    Stores key-value user preferences for persistence across server restarts.
+    """
+
+    async def get(self, key: str) -> object | None:
+        """Get preference value by key.
+
+        Args:
+            key: Preference key.
+
+        Returns:
+            Preference value or None if not found.
+        """
+        ...
+
+    async def set(self, key: str, value: object) -> None:
+        """Set preference value.
+
+        Args:
+            key: Preference key.
+            value: Preference value.
+        """
+        ...
+
+    async def delete(self, key: str) -> bool:
+        """Delete a preference.
+
+        Args:
+            key: Preference key.
+
+        Returns:
+            True if deleted, False if not found.
+        """
+        ...
@@ -2,23 +2,30 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Protocol, Self
+from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable

 if TYPE_CHECKING:
     from .repositories import (
         AnnotationRepository,
+        DiarizationJobRepository,
         MeetingRepository,
+        PreferencesRepository,
         SegmentRepository,
         SummaryRepository,
     )


+@runtime_checkable
 class UnitOfWork(Protocol):
     """Unit of Work protocol for managing transactions across repositories.

     Provides transactional consistency when operating on multiple
     aggregates. Use as a context manager for automatic commit/rollback.

+    Implementations may be backed by either a database (SqlAlchemyUnitOfWork)
+    or in-memory storage (MemoryUnitOfWork). The `supports_*` properties
+    indicate which features are available in the current implementation.
+
     Example:
         async with uow:
             meeting = await uow.meetings.get(meeting_id)
@@ -26,10 +33,62 @@ class UnitOfWork(Protocol):
             await uow.commit()
     """

-    annotations: AnnotationRepository
-    meetings: MeetingRepository
-    segments: SegmentRepository
-    summaries: SummaryRepository
+    # Core repositories (always available)
+    @property
+    def meetings(self) -> MeetingRepository:
+        """Access the meetings repository."""
+        ...
+
+    @property
+    def segments(self) -> SegmentRepository:
+        """Access the segments repository."""
+        ...
+
+    @property
+    def summaries(self) -> SummaryRepository:
+        """Access the summaries repository."""
+        ...
+
+    # Optional repositories (check supports_* before use)
+    @property
+    def annotations(self) -> AnnotationRepository:
+        """Access the annotations repository."""
+        ...
+
+    @property
+    def diarization_jobs(self) -> DiarizationJobRepository:
+        """Access the diarization jobs repository."""
+        ...
+
+    @property
+    def preferences(self) -> PreferencesRepository:
+        """Access the preferences repository."""
+        ...
+
+    # Feature flags for DB-only capabilities
+    @property
+    def supports_annotations(self) -> bool:
+        """Check if annotation operations are supported.
+
+        Returns False for memory-only implementations.
+        """
+        ...
+
+    @property
+    def supports_diarization_jobs(self) -> bool:
+        """Check if diarization job persistence is supported.
+
+        Returns False for memory-only implementations.
+        """
+        ...
+
+    @property
+    def supports_preferences(self) -> bool:
+        """Check if user preferences persistence is supported.
+
+        Returns False for memory-only implementations.
+        """
+        ...

     async def __aenter__(self) -> Self:
         """Enter the unit of work context.
src/noteflow/domain/utils/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+"""Domain utility functions."""
+
+from noteflow.domain.utils.time import utc_now
+
+__all__ = ["utc_now"]
src/noteflow/domain/utils/time.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+"""Time utilities for consistent timezone handling.
+
+All datetime values in the domain should be UTC-aware to prevent:
+- Naive/aware comparison errors
+- Incorrect .timestamp() conversions
+- Timezone drift between components
+"""
+
+from datetime import UTC, datetime
+
+
+def utc_now() -> datetime:
+    """Return current UTC datetime (timezone-aware).
+
+    Use this instead of datetime.now() throughout the codebase to ensure
+    consistent timezone-aware datetime handling.
+
+    Returns:
+        Current datetime in UTC timezone.
+    """
+    return datetime.now(UTC)
@@ -3,15 +3,21 @@

 from __future__ import annotations

 from typing import TYPE_CHECKING
-from uuid import UUID, uuid4
+from uuid import uuid4

 import grpc.aio

 from noteflow.domain.entities import Annotation
-from noteflow.domain.value_objects import AnnotationId, MeetingId
+from noteflow.domain.value_objects import AnnotationId

 from ..proto import noteflow_pb2
-from .converters import annotation_to_proto, proto_to_annotation_type
+from .converters import (
+    annotation_to_proto,
+    parse_annotation_id,
+    parse_meeting_id,
+    proto_to_annotation_type,
+)
+from .errors import abort_database_required, abort_invalid_argument, abort_not_found

 if TYPE_CHECKING:
     from .protocols import ServicerHost
@@ -30,27 +36,28 @@ class AnnotationMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.Annotation:
         """Add an annotation to a meeting."""
-        if not self._use_database():
-            await context.abort(
-                grpc.StatusCode.UNIMPLEMENTED,
-                "Annotations require database persistence",
-            )
+        async with self._create_repository_provider() as repo:
+            if not repo.supports_annotations:
+                await abort_database_required(context, "Annotations")

-        annotation_type = proto_to_annotation_type(request.annotation_type)
-
-        annotation = Annotation(
-            id=AnnotationId(uuid4()),
-            meeting_id=MeetingId(UUID(request.meeting_id)),
-            annotation_type=annotation_type,
-            text=request.text,
-            start_time=request.start_time,
-            end_time=request.end_time,
-            segment_ids=list(request.segment_ids),
-        )
+            try:
+                meeting_id = parse_meeting_id(request.meeting_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid meeting_id")

-        async with self._create_uow() as uow:
-            saved = await uow.annotations.add(annotation)
-            await uow.commit()
+            annotation_type = proto_to_annotation_type(request.annotation_type)
+            annotation = Annotation(
+                id=AnnotationId(uuid4()),
+                meeting_id=meeting_id,
+                annotation_type=annotation_type,
+                text=request.text,
+                start_time=request.start_time,
+                end_time=request.end_time,
+                segment_ids=list(request.segment_ids),
+            )
+
+            saved = await repo.annotations.add(annotation)
+            await repo.commit()
         return annotation_to_proto(saved)

     async def GetAnnotation(
@@ -59,19 +66,18 @@ class AnnotationMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.Annotation:
         """Get an annotation by ID."""
-        if not self._use_database():
-            await context.abort(
-                grpc.StatusCode.UNIMPLEMENTED,
-                "Annotations require database persistence",
-            )
+        async with self._create_repository_provider() as repo:
+            if not repo.supports_annotations:
+                await abort_database_required(context, "Annotations")

-        async with self._create_uow() as uow:
-            annotation = await uow.annotations.get(AnnotationId(UUID(request.annotation_id)))
+            try:
+                annotation_id = parse_annotation_id(request.annotation_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid annotation_id")
+
+            annotation = await repo.annotations.get(annotation_id)
             if annotation is None:
-                await context.abort(
-                    grpc.StatusCode.NOT_FOUND,
-                    f"Annotation {request.annotation_id} not found",
-                )
+                await abort_not_found(context, "Annotation", request.annotation_id)
             return annotation_to_proto(annotation)

     async def ListAnnotations(
@@ -80,23 +86,23 @@ class AnnotationMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.ListAnnotationsResponse:
         """List annotations for a meeting."""
-        if not self._use_database():
-            await context.abort(
-                grpc.StatusCode.UNIMPLEMENTED,
-                "Annotations require database persistence",
-            )
+        async with self._create_repository_provider() as repo:
+            if not repo.supports_annotations:
+                await abort_database_required(context, "Annotations")

-        async with self._create_uow() as uow:
-            meeting_id = MeetingId(UUID(request.meeting_id))
+            try:
+                meeting_id = parse_meeting_id(request.meeting_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid meeting_id")
             # Check if time range filter is specified
             if request.start_time > 0 or request.end_time > 0:
-                annotations = await uow.annotations.get_by_time_range(
+                annotations = await repo.annotations.get_by_time_range(
                     meeting_id,
                     request.start_time,
                     request.end_time,
                 )
             else:
-                annotations = await uow.annotations.get_by_meeting(meeting_id)
+                annotations = await repo.annotations.get_by_meeting(meeting_id)

         return noteflow_pb2.ListAnnotationsResponse(
             annotations=[annotation_to_proto(a) for a in annotations]
@@ -108,19 +114,18 @@ class AnnotationMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.Annotation:
         """Update an existing annotation."""
-        if not self._use_database():
-            await context.abort(
-                grpc.StatusCode.UNIMPLEMENTED,
-                "Annotations require database persistence",
-            )
+        async with self._create_repository_provider() as repo:
+            if not repo.supports_annotations:
+                await abort_database_required(context, "Annotations")

-        async with self._create_uow() as uow:
-            annotation = await uow.annotations.get(AnnotationId(UUID(request.annotation_id)))
+            try:
+                annotation_id = parse_annotation_id(request.annotation_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid annotation_id")
+
+            annotation = await repo.annotations.get(annotation_id)
             if annotation is None:
-                await context.abort(
-                    grpc.StatusCode.NOT_FOUND,
-                    f"Annotation {request.annotation_id} not found",
-                )
+                await abort_not_found(context, "Annotation", request.annotation_id)

             # Update fields if provided
             if request.annotation_type != noteflow_pb2.ANNOTATION_TYPE_UNSPECIFIED:
@@ -134,8 +139,8 @@ class AnnotationMixin:
             if request.segment_ids:
                 annotation.segment_ids = list(request.segment_ids)

-            updated = await uow.annotations.update(annotation)
-            await uow.commit()
+            updated = await repo.annotations.update(annotation)
+            await repo.commit()
         return annotation_to_proto(updated)

     async def DeleteAnnotation(
@@ -144,18 +149,17 @@ class AnnotationMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.DeleteAnnotationResponse:
         """Delete an annotation."""
-        if not self._use_database():
-            await context.abort(
-                grpc.StatusCode.UNIMPLEMENTED,
-                "Annotations require database persistence",
-            )
+        async with self._create_repository_provider() as repo:
+            if not repo.supports_annotations:
+                await abort_database_required(context, "Annotations")

-        async with self._create_uow() as uow:
-            success = await uow.annotations.delete(AnnotationId(UUID(request.annotation_id)))
+            try:
+                annotation_id = parse_annotation_id(request.annotation_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid annotation_id")
+
+            success = await repo.annotations.delete(annotation_id)
             if success:
-                await uow.commit()
+                await repo.commit()
                 return noteflow_pb2.DeleteAnnotationResponse(success=True)
-            await context.abort(
-                grpc.StatusCode.NOT_FOUND,
-                f"Annotation {request.annotation_id} not found",
-            )
+            await abort_not_found(context, "Annotation", request.annotation_id)
@@ -4,10 +4,11 @@ from __future__ import annotations

 import time
 from typing import TYPE_CHECKING
+from uuid import UUID

 from noteflow.application.services.export_service import ExportFormat
-from noteflow.domain.entities import Annotation, Meeting, Segment, Summary
-from noteflow.domain.value_objects import AnnotationType, MeetingId
+from noteflow.domain.entities import Annotation, Meeting, Segment, Summary, WordTiming
+from noteflow.domain.value_objects import AnnotationId, AnnotationType, MeetingId
 from noteflow.infrastructure.converters import AsrConverter

 from ..proto import noteflow_pb2
@@ -16,6 +17,93 @@ if TYPE_CHECKING:
     from noteflow.infrastructure.asr.dto import AsrResult

# -----------------------------------------------------------------------------
|
||||
# ID Parsing Helpers (eliminate 11+ duplications)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_meeting_id(meeting_id_str: str) -> MeetingId:
|
||||
"""Parse string to MeetingId.
|
||||
|
||||
Consolidates the repeated `MeetingId(UUID(request.meeting_id))` pattern.
|
||||
|
||||
Args:
|
||||
meeting_id_str: Meeting ID as string (UUID format).
|
||||
|
||||
Returns:
|
||||
MeetingId value object.
|
||||
"""
|
||||
return MeetingId(UUID(meeting_id_str))
|
||||
|
||||
|
||||
def parse_annotation_id(annotation_id_str: str) -> AnnotationId:
|
||||
"""Parse string to AnnotationId.
|
||||
|
||||
Args:
|
||||
annotation_id_str: Annotation ID as string (UUID format).
|
||||
|
||||
Returns:
|
||||
AnnotationId value object.
|
||||
"""
|
||||
return AnnotationId(UUID(annotation_id_str))
|
||||
|
||||
|
||||
+# -----------------------------------------------------------------------------
+# Proto Construction Helpers (eliminate duplicate word/segment building)
+# -----------------------------------------------------------------------------
+
+
+def word_to_proto(word: WordTiming) -> noteflow_pb2.WordTiming:
+    """Convert domain WordTiming to protobuf.
+
+    Consolidates the repeated WordTiming construction pattern.
+
+    Args:
+        word: Domain WordTiming entity.
+
+    Returns:
+        Protobuf WordTiming message.
+    """
+    return noteflow_pb2.WordTiming(
+        word=word.word,
+        start_time=word.start_time,
+        end_time=word.end_time,
+        probability=word.probability,
+    )
+
+
+def segment_to_final_segment_proto(segment: Segment) -> noteflow_pb2.FinalSegment:
+    """Convert domain Segment to FinalSegment protobuf.
+
+    Consolidates the repeated FinalSegment construction pattern.
+
+    Args:
+        segment: Domain Segment entity.
+
+    Returns:
+        Protobuf FinalSegment message.
+    """
+    words = [word_to_proto(w) for w in segment.words]
+    return noteflow_pb2.FinalSegment(
+        segment_id=segment.segment_id,
+        text=segment.text,
+        start_time=segment.start_time,
+        end_time=segment.end_time,
+        words=words,
+        language=segment.language,
+        language_confidence=segment.language_confidence,
+        avg_logprob=segment.avg_logprob,
+        no_speech_prob=segment.no_speech_prob,
+        speaker_id=segment.speaker_id or "",
+        speaker_confidence=segment.speaker_confidence,
+    )
+
+
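# Minimal sketch of how the shared builders compose (field values are
# illustrative): word_to_proto converts one WordTiming, and
# segment_to_final_segment_proto reuses it for the words list.
#
#     word = WordTiming(word="hello", start_time=0.0, end_time=0.4, probability=0.98)
#     proto_word = word_to_proto(word)
#     proto_segment = segment_to_final_segment_proto(segment)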
+# -----------------------------------------------------------------------------
+# Main Converter Functions
+# -----------------------------------------------------------------------------
+
+
 def meeting_to_proto(
     meeting: Meeting,
     include_segments: bool = True,
@@ -24,31 +112,7 @@ def meeting_to_proto(
     """Convert domain Meeting to protobuf."""
     segments = []
     if include_segments:
-        for seg in meeting.segments:
-            words = [
-                noteflow_pb2.WordTiming(
-                    word=w.word,
-                    start_time=w.start_time,
-                    end_time=w.end_time,
-                    probability=w.probability,
-                )
-                for w in seg.words
-            ]
-            segments.append(
-                noteflow_pb2.FinalSegment(
-                    segment_id=seg.segment_id,
-                    text=seg.text,
-                    start_time=seg.start_time,
-                    end_time=seg.end_time,
-                    words=words,
-                    language=seg.language,
-                    language_confidence=seg.language_confidence,
-                    avg_logprob=seg.avg_logprob,
-                    no_speech_prob=seg.no_speech_prob,
-                    speaker_id=seg.speaker_id or "",
-                    speaker_confidence=seg.speaker_confidence,
-                )
-            )
+        segments = [segment_to_final_segment_proto(seg) for seg in meeting.segments]
 
     summary = None
     if include_summary and meeting.summary:
@@ -104,32 +168,10 @@ def segment_to_proto_update(
     segment: Segment,
 ) -> noteflow_pb2.TranscriptUpdate:
     """Convert domain Segment to protobuf TranscriptUpdate."""
-    words = [
-        noteflow_pb2.WordTiming(
-            word=w.word,
-            start_time=w.start_time,
-            end_time=w.end_time,
-            probability=w.probability,
-        )
-        for w in segment.words
-    ]
-    final_segment = noteflow_pb2.FinalSegment(
-        segment_id=segment.segment_id,
-        text=segment.text,
-        start_time=segment.start_time,
-        end_time=segment.end_time,
-        words=words,
-        language=segment.language,
-        language_confidence=segment.language_confidence,
-        avg_logprob=segment.avg_logprob,
-        no_speech_prob=segment.no_speech_prob,
-        speaker_id=segment.speaker_id or "",
-        speaker_confidence=segment.speaker_confidence,
-    )
     return noteflow_pb2.TranscriptUpdate(
         meeting_id=meeting_id,
         update_type=noteflow_pb2.UPDATE_TYPE_FINAL,
-        segment=final_segment,
+        segment=segment_to_final_segment_proto(segment),
        server_timestamp=time.time(),
    )
@@ -14,7 +14,7 @@ import numpy as np
 from numpy.typing import NDArray
 
 from noteflow.domain.entities import Segment
-from noteflow.domain.value_objects import MeetingId, MeetingState
+from noteflow.domain.value_objects import MeetingState
 from noteflow.infrastructure.audio.reader import MeetingAudioReader
 from noteflow.infrastructure.diarization import SpeakerTurn, assign_speaker
 from noteflow.infrastructure.persistence.repositories import (
@@ -23,13 +23,73 @@ from noteflow.infrastructure.persistence.repositories import (
 )
 
 from ..proto import noteflow_pb2
+from .converters import parse_meeting_id
+from .errors import abort_invalid_argument, abort_not_found
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from .protocols import ServicerHost
 
 logger = logging.getLogger(__name__)
 
 
+# -----------------------------------------------------------------------------
+# Helper Functions (Eliminate Duplication)
+# -----------------------------------------------------------------------------
+
+
+def _create_diarization_error_response(
+    error_message: str,
+    status: int = noteflow_pb2.JOB_STATUS_FAILED,
+) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
+    """Create error response for RefineSpeakerDiarization.
+
+    Consolidates the 6+ duplicated response construction patterns.
+
+    Args:
+        error_message: Error message describing the failure.
+        status: Job status code (default: JOB_STATUS_FAILED).
+
+    Returns:
+        Populated RefineSpeakerDiarizationResponse with error state.
+    """
+    response = noteflow_pb2.RefineSpeakerDiarizationResponse()
+    response.segments_updated = 0
+    response.speaker_ids[:] = []
+    response.error_message = error_message
+    response.job_id = ""
+    response.status = status
+    return response
+
+
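# Illustrative call sites for the consolidated error response (the messages
# shown match the ones used later in this diff):
#
#     return _create_diarization_error_response("Diarization not enabled on server")
#     return _create_diarization_error_response("Invalid meeting_id")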
+def _apply_speaker_to_segment(
+    segment: Segment,
+    turns: Sequence[SpeakerTurn],
+) -> bool:
+    """Assign speaker to segment from diarization turns.
+
+    Consolidates the 3 duplicated speaker assignment patterns.
+
+    Args:
+        segment: Domain segment to update.
+        turns: Sequence of speaker turns from diarization.
+
+    Returns:
+        True if speaker was assigned, False otherwise.
+    """
+    speaker_id, confidence = assign_speaker(
+        segment.start_time,
+        segment.end_time,
+        turns,
+    )
+    if speaker_id is None:
+        return False
+    segment.speaker_id = speaker_id
+    segment.speaker_confidence = confidence
+    return True
+
+
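# Minimal usage sketch: assign the overlapping speaker turn to a segment and
# count it only when assign_speaker finds a match.
#
#     if _apply_speaker_to_segment(segment, turns):
#         updated_count += 1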
 class DiarizationMixin:
     """Mixin providing speaker diarization functionality.
 
@@ -46,6 +106,9 @@
     ) -> None:
         """Process an audio chunk for streaming diarization (best-effort).
 
+        Uses per-meeting sessions to enable concurrent diarization without
+        race conditions. Each meeting has its own pipeline state.
+
         Offloads heavy ML inference to thread pool to avoid blocking the event loop.
         """
         if self._diarization_engine is None:
@@ -57,13 +120,21 @@
 
         loop = asyncio.get_running_loop()
 
+        # Get or create per-meeting session under lock
         async with self._diarization_lock:
-            if not self._diarization_engine.is_streaming_loaded:
+            session = self._diarization_sessions.get(meeting_id)
+            if session is None:
                 try:
-                    await loop.run_in_executor(
+                    session = await loop.run_in_executor(
                         None,
-                        self._diarization_engine.load_streaming_model,
+                        self._diarization_engine.create_streaming_session,
+                        meeting_id,
                     )
+                    prior_turns = self._diarization_turns.get(meeting_id, [])
+                    prior_stream_time = self._diarization_stream_time.get(meeting_id, 0.0)
+                    if prior_turns or prior_stream_time:
+                        session.restore(prior_turns, stream_time=prior_stream_time)
+                    self._diarization_sessions[meeting_id] = session
                 except (RuntimeError, ValueError) as exc:
                     logger.warning(
                         "Streaming diarization disabled for meeting %s: %s",
@@ -73,56 +144,48 @@
                     self._diarization_streaming_failed.add(meeting_id)
                     return
 
-        stream_time = self._diarization_stream_time.get(meeting_id, 0.0)
-        duration = len(audio) / self.DEFAULT_SAMPLE_RATE
-
-        try:
-            turns = await loop.run_in_executor(
-                None,
-                partial(
-                    self._diarization_engine.process_chunk,
-                    audio,
-                    sample_rate=self.DEFAULT_SAMPLE_RATE,
-                ),
-            )
-        except Exception as exc:
-            logger.warning(
-                "Streaming diarization failed for meeting %s: %s",
-                meeting_id,
-                exc,
-            )
-            self._diarization_streaming_failed.add(meeting_id)
-            return
-
-        diarization_turns = self._diarization_turns.setdefault(meeting_id, [])
-        adjusted_turns: list[SpeakerTurn] = []
-        for turn in turns:
-            adjusted = SpeakerTurn(
-                speaker=turn.speaker,
-                start=turn.start + stream_time,
-                end=turn.end + stream_time,
-                confidence=turn.confidence,
-            )
-            diarization_turns.append(adjusted)
-            adjusted_turns.append(adjusted)
+        # Process chunk in thread pool (outside lock for parallelism)
+        try:
+            new_turns = await loop.run_in_executor(
+                None,
+                partial(
+                    session.process_chunk,
+                    audio,
+                    sample_rate=self.DEFAULT_SAMPLE_RATE,
+                ),
+            )
+        except Exception as exc:
+            logger.warning(
+                "Streaming diarization failed for meeting %s: %s",
+                meeting_id,
+                exc,
+            )
+            self._diarization_streaming_failed.add(meeting_id)
+            return
 
-        self._diarization_stream_time[meeting_id] = stream_time + duration
+        # Populate _diarization_turns for compatibility with _maybe_assign_speaker
+        if new_turns:
+            diarization_turns = self._diarization_turns.setdefault(meeting_id, [])
+            diarization_turns.extend(new_turns)
 
-        # Persist turns immediately for crash resilience
-        if adjusted_turns and self._use_database():
-            try:
-                async with self._create_uow() as uow:
-                    repo_turns = [
-                        StreamingTurn(
-                            speaker=t.speaker,
-                            start_time=t.start,
-                            end_time=t.end,
-                            confidence=t.confidence,
-                        )
-                        for t in adjusted_turns
-                    ]
-                    await uow.diarization_jobs.add_streaming_turns(meeting_id, repo_turns)
-                    await uow.commit()
+        # Update stream time for legacy compatibility
+        self._diarization_stream_time[meeting_id] = session.stream_time
+
+        # Persist turns immediately for crash resilience (DB only)
+        try:
+            async with self._create_repository_provider() as repo:
+                if repo.supports_diarization_jobs:
+                    repo_turns = [
+                        StreamingTurn(
+                            speaker=t.speaker,
+                            start_time=t.start,
+                            end_time=t.end,
+                            confidence=t.confidence,
+                        )
+                        for t in new_turns
+                    ]
+                    await repo.diarization_jobs.add_streaming_turns(meeting_id, repo_turns)
+                    await repo.commit()
         except Exception:
            logger.exception("Failed to persist streaming turns for %s", meeting_id)
 
@@ -136,21 +199,11 @@
             return
         if meeting_id in self._diarization_streaming_failed:
             return
-        turns = self._diarization_turns.get(meeting_id)
-        if not turns:
+        if turns := self._diarization_turns.get(meeting_id):
+            _apply_speaker_to_segment(segment, turns)
+        else:
             return
 
-        speaker_id, confidence = assign_speaker(
-            segment.start_time,
-            segment.end_time,
-            turns,
-        )
-        if speaker_id is None:
-            return
-
-        segment.speaker_id = speaker_id
-        segment.speaker_confidence = confidence
-
     async def _prune_diarization_jobs(self: ServicerHost) -> None:
         """Remove completed diarization jobs older than retention window.
 
@@ -168,24 +221,25 @@
             noteflow_pb2.JOB_STATUS_FAILED,
         }
 
-        # Prune old completed jobs from database
-        if self._use_database():
-            async with self._create_uow() as uow:
-                pruned = await uow.diarization_jobs.prune_completed(
+        # Prune old completed jobs from database or in-memory store
+        async with self._create_repository_provider() as repo:
+            if repo.supports_diarization_jobs:
+                pruned = await repo.diarization_jobs.prune_completed(
                     self.DIARIZATION_JOB_TTL_SECONDS
                 )
-                await uow.commit()
+                await repo.commit()
                 if pruned > 0:
                     logger.debug("Pruned %d completed diarization jobs", pruned)
-        else:
-            cutoff = datetime.now() - timedelta(seconds=self.DIARIZATION_JOB_TTL_SECONDS)
-            expired = [
-                job_id
-                for job_id, job in self._diarization_jobs.items()
-                if job.status in terminal_statuses and job.updated_at < cutoff
-            ]
-            for job_id in expired:
-                self._diarization_jobs.pop(job_id, None)
+            else:
+                # In-memory fallback: prune from local dict
+                cutoff = datetime.now() - timedelta(seconds=self.DIARIZATION_JOB_TTL_SECONDS)
+                expired = [
+                    job_id
+                    for job_id, job in self._diarization_jobs.items()
+                    if job.status in terminal_statuses and job.updated_at < cutoff
+                ]
+                for job_id in expired:
+                    self._diarization_jobs.pop(job_id, None)
 
     async def RefineSpeakerDiarization(
         self: ServicerHost,
@@ -195,86 +249,52 @@
         """Run post-meeting speaker diarization refinement.
 
         Load the full meeting audio, run offline diarization, and update
-        segment speaker assignments. Job state is persisted to database.
+        segment speaker assignments. Job state is persisted when DB available.
         """
         await self._prune_diarization_jobs()
 
         if not self._diarization_refinement_enabled:
-            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
-            response.segments_updated = 0
-            response.speaker_ids[:] = []
-            response.error_message = "Diarization refinement disabled on server"
-            response.job_id = ""
-            response.status = noteflow_pb2.JOB_STATUS_FAILED
-            return response
+            return _create_diarization_error_response("Diarization refinement disabled on server")
 
         if self._diarization_engine is None:
-            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
-            response.segments_updated = 0
-            response.speaker_ids[:] = []
-            response.error_message = "Diarization not enabled on server"
-            response.job_id = ""
-            response.status = noteflow_pb2.JOB_STATUS_FAILED
-            return response
+            return _create_diarization_error_response("Diarization not enabled on server")
 
         try:
-            meeting_uuid = UUID(request.meeting_id)
+            UUID(request.meeting_id)  # Validate UUID format
         except ValueError:
-            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
-            response.segments_updated = 0
-            response.speaker_ids[:] = []
-            response.error_message = "Invalid meeting_id"
-            response.job_id = ""
-            response.status = noteflow_pb2.JOB_STATUS_FAILED
-            return response
+            return _create_diarization_error_response("Invalid meeting_id")
 
-        if self._use_database():
-            async with self._create_uow() as uow:
-                meeting = await uow.meetings.get(MeetingId(meeting_uuid))
-        else:
-            store = self._get_memory_store()
-            meeting = store.get(request.meeting_id)
-        if meeting is None:
-            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
-            response.segments_updated = 0
-            response.speaker_ids[:] = []
-            response.error_message = "Meeting not found"
-            response.job_id = ""
-            response.status = noteflow_pb2.JOB_STATUS_FAILED
-            return response
-        meeting_state = meeting.state
-        if meeting_state in (
-            MeetingState.UNSPECIFIED,
-            MeetingState.CREATED,
-            MeetingState.RECORDING,
-            MeetingState.STOPPING,
-        ):
-            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
-            response.segments_updated = 0
-            response.speaker_ids[:] = []
-            response.error_message = (
-                f"Meeting must be stopped before refinement (state: {meeting_state.name.lower()})"
-            )
-            response.job_id = ""
-            response.status = noteflow_pb2.JOB_STATUS_FAILED
-            return response
-
-        num_speakers = request.num_speakers if request.num_speakers > 0 else None
-
-        job_id = str(uuid4())
-        job = DiarizationJob(
-            job_id=job_id,
-            meeting_id=request.meeting_id,
-            status=noteflow_pb2.JOB_STATUS_QUEUED,
-        )
-
-        # Persist job to database
-        if self._use_database():
-            async with self._create_uow() as uow:
-                await uow.diarization_jobs.create(job)
-                await uow.commit()
-        else:
-            self._diarization_jobs[job_id] = job
+        async with self._create_repository_provider() as repo:
+            meeting = await repo.meetings.get(parse_meeting_id(request.meeting_id))
+            if meeting is None:
+                return _create_diarization_error_response("Meeting not found")
+
+            meeting_state = meeting.state
+            if meeting_state in (
+                MeetingState.UNSPECIFIED,
+                MeetingState.CREATED,
+                MeetingState.RECORDING,
+                MeetingState.STOPPING,
+            ):
+                return _create_diarization_error_response(
+                    f"Meeting must be stopped before refinement (state: {meeting_state.name.lower()})"
+                )
+
+            num_speakers = request.num_speakers if request.num_speakers > 0 else None
+
+            job_id = str(uuid4())
+            job = DiarizationJob(
+                job_id=job_id,
+                meeting_id=request.meeting_id,
+                status=noteflow_pb2.JOB_STATUS_QUEUED,
+            )
+
+            # Persist job if DB supports it, otherwise use in-memory dict
+            if repo.supports_diarization_jobs:
+                await repo.diarization_jobs.create(job)
+                await repo.commit()
+            else:
+                self._diarization_jobs[job_id] = job
 
         # Create background task and store reference for potential cancellation
         task = asyncio.create_task(self._run_diarization_job(job_id, num_speakers))
@@ -295,32 +315,31 @@
     ) -> None:
         """Run background diarization job.
 
-        Updates job status in database as the job progresses.
+        Updates job status in repository as the job progresses.
         """
-        # Get meeting_id from database
+        # Get meeting_id and update status to RUNNING
         meeting_id: str | None = None
         job: DiarizationJob | None = None
-        if self._use_database():
-            async with self._create_uow() as uow:
-                job = await uow.diarization_jobs.get(job_id)
+        async with self._create_repository_provider() as repo:
+            if repo.supports_diarization_jobs:
+                job = await repo.diarization_jobs.get(job_id)
                 if job is None:
                     logger.warning("Diarization job %s not found in database", job_id)
                     return
                 meeting_id = job.meeting_id
-                # Update status to RUNNING
-                await uow.diarization_jobs.update_status(
+                await repo.diarization_jobs.update_status(
                     job_id,
                     noteflow_pb2.JOB_STATUS_RUNNING,
                 )
-                await uow.commit()
-        else:
-            job = self._diarization_jobs.get(job_id)
-            if job is None:
-                logger.warning("Diarization job %s not found in memory", job_id)
-                return
-            meeting_id = job.meeting_id
-            job.status = noteflow_pb2.JOB_STATUS_RUNNING
-            job.updated_at = datetime.now()
+                await repo.commit()
+            else:
+                job = self._diarization_jobs.get(job_id)
+                if job is None:
+                    logger.warning("Diarization job %s not found in memory", job_id)
+                    return
+                meeting_id = job.meeting_id
+                job.status = noteflow_pb2.JOB_STATUS_RUNNING
+                job.updated_at = datetime.now()
 
         try:
             updated_count = await self.refine_speaker_diarization(
@@ -330,17 +349,16 @@
             speaker_ids = await self._collect_speaker_ids(meeting_id)
 
             # Update status to COMPLETED
-            if self._use_database():
-                async with self._create_uow() as uow:
-                    await uow.diarization_jobs.update_status(
+            async with self._create_repository_provider() as repo:
+                if repo.supports_diarization_jobs:
+                    await repo.diarization_jobs.update_status(
                         job_id,
                         noteflow_pb2.JOB_STATUS_COMPLETED,
                         segments_updated=updated_count,
                         speaker_ids=speaker_ids,
                     )
-                    await uow.commit()
-            else:
-                if job is not None:
+                    await repo.commit()
+                elif job is not None:
                     job.status = noteflow_pb2.JOB_STATUS_COMPLETED
                     job.segments_updated = updated_count
                     job.speaker_ids = speaker_ids
@@ -349,16 +367,15 @@
         except Exception as exc:
             logger.exception("Diarization failed for meeting %s", meeting_id)
             # Update status to FAILED
-            if self._use_database():
-                async with self._create_uow() as uow:
-                    await uow.diarization_jobs.update_status(
+            async with self._create_repository_provider() as repo:
+                if repo.supports_diarization_jobs:
+                    await repo.diarization_jobs.update_status(
                         job_id,
                         noteflow_pb2.JOB_STATUS_FAILED,
                         error_message=str(exc),
                     )
-                    await uow.commit()
-            else:
-                if job is not None:
+                    await repo.commit()
+                elif job is not None:
                     job.status = noteflow_pb2.JOB_STATUS_FAILED
                     job.error_message = str(exc)
                     job.updated_at = datetime.now()
@@ -452,53 +469,39 @@
         """Apply diarization turns to segments and return updated count."""
         updated_count = 0
 
-        if self._use_database():
-            async with self._create_uow() as uow:
-                segments = await uow.segments.get_by_meeting(MeetingId(UUID(meeting_id)))
-                for segment in segments:
-                    if segment.db_id is None:
-                        continue
-                    speaker_id, confidence = assign_speaker(
-                        segment.start_time,
-                        segment.end_time,
-                        turns,
-                    )
-                    if speaker_id is None:
-                        continue
-                    await uow.segments.update_speaker(
-                        segment.db_id,
-                        speaker_id,
-                        confidence,
-                    )
-                    updated_count += 1
-                await uow.commit()
-        else:
-            store = self._get_memory_store()
-            if meeting := store.get(meeting_id):
-                for segment in meeting.segments:
-                    speaker_id, confidence = assign_speaker(
-                        segment.start_time,
-                        segment.end_time,
-                        turns,
-                    )
-                    if speaker_id is None:
-                        continue
-                    segment.speaker_id = speaker_id
-                    segment.speaker_confidence = confidence
-                    updated_count += 1
+        async with self._create_repository_provider() as repo:
+            try:
+                parsed_meeting_id = parse_meeting_id(meeting_id)
+            except ValueError:
+                logger.warning("Invalid meeting_id %s while applying diarization turns", meeting_id)
+                return 0
+
+            segments = await repo.segments.get_by_meeting(parsed_meeting_id)
+            for segment in segments:
+                if _apply_speaker_to_segment(segment, turns):
+                    # For DB segments with db_id, use update_speaker
+                    if segment.db_id is not None:
+                        await repo.segments.update_speaker(
+                            segment.db_id,
+                            segment.speaker_id,
+                            segment.speaker_confidence,
+                        )
+                    updated_count += 1
+            await repo.commit()
 
         return updated_count
 
     async def _collect_speaker_ids(self: ServicerHost, meeting_id: str) -> list[str]:
         """Collect distinct speaker IDs for a meeting."""
-        if self._use_database():
-            async with self._create_uow() as uow:
-                segments = await uow.segments.get_by_meeting(MeetingId(UUID(meeting_id)))
-                return sorted({s.speaker_id for s in segments if s.speaker_id})
-        store = self._get_memory_store()
-        if meeting := store.get(meeting_id):
-            return sorted({s.speaker_id for s in meeting.segments if s.speaker_id})
-        return []
+        async with self._create_repository_provider() as repo:
+            try:
+                parsed_meeting_id = parse_meeting_id(meeting_id)
+            except ValueError:
+                logger.warning("Invalid meeting_id %s while collecting speaker ids", meeting_id)
+                return []
+
+            segments = await repo.segments.get_by_meeting(parsed_meeting_id)
+            return sorted({s.speaker_id for s in segments if s.speaker_id})
 
     async def RenameSpeaker(
         self: ServicerHost,
@@ -511,42 +514,35 @@
         to use new_speaker_name instead.
         """
         if not request.old_speaker_id or not request.new_speaker_name:
-            await context.abort(
-                grpc.StatusCode.INVALID_ARGUMENT,
-                "old_speaker_id and new_speaker_name are required",
+            await abort_invalid_argument(
+                context, "old_speaker_id and new_speaker_name are required"
             )
 
         try:
-            meeting_uuid = UUID(request.meeting_id)
+            meeting_id = parse_meeting_id(request.meeting_id)
         except ValueError:
-            await context.abort(
-                grpc.StatusCode.INVALID_ARGUMENT,
-                "Invalid meeting_id",
-            )
+            await abort_invalid_argument(context, "Invalid meeting_id")
 
         updated_count = 0
 
-        if self._use_database():
-            async with self._create_uow() as uow:
-                segments = await uow.segments.get_by_meeting(MeetingId(meeting_uuid))
+        async with self._create_repository_provider() as repo:
+            segments = await repo.segments.get_by_meeting(meeting_id)
 
-                for segment in segments:
-                    if segment.speaker_id == request.old_speaker_id and segment.db_id:
-                        await uow.segments.update_speaker(
-                            segment.db_id,
-                            request.new_speaker_name,
-                            segment.speaker_confidence,
-                        )
-                        updated_count += 1
-
-                await uow.commit()
-        else:
-            store = self._get_memory_store()
-            if meeting := store.get(request.meeting_id):
-                for segment in meeting.segments:
-                    if segment.speaker_id == request.old_speaker_id:
-                        segment.speaker_id = request.new_speaker_name
-                        updated_count += 1
+            for segment in segments:
+                if segment.speaker_id == request.old_speaker_id:
+                    # For DB segments with db_id, use update_speaker
+                    if segment.db_id is not None:
+                        await repo.segments.update_speaker(
+                            segment.db_id,
+                            request.new_speaker_name,
+                            segment.speaker_confidence,
+                        )
+                    else:
+                        # Memory segments: update directly
+                        segment.speaker_id = request.new_speaker_name
+                    updated_count += 1
+
+            await repo.commit()
 
         return noteflow_pb2.RenameSpeakerResponse(
             segments_updated=updated_count,
@@ -560,35 +556,23 @@
     ) -> noteflow_pb2.DiarizationJobStatus:
         """Return current status for a diarization job.
 
-        Queries job state from database for persistence across restarts.
+        Queries job state from repository for persistence across restarts.
         """
         await self._prune_diarization_jobs()
 
-        if self._use_database():
-            async with self._create_uow() as uow:
-                job = await uow.diarization_jobs.get(request.job_id)
-                if job is None:
-                    await context.abort(
-                        grpc.StatusCode.NOT_FOUND,
-                        "Diarization job not found",
-                    )
-                return noteflow_pb2.DiarizationJobStatus(
-                    job_id=job.job_id,
-                    status=job.status,
-                    segments_updated=job.segments_updated,
-                    speaker_ids=job.speaker_ids,
-                    error_message=job.error_message,
-                )
-        job = self._diarization_jobs.get(request.job_id)
-        if job is None:
-            await context.abort(
-                grpc.StatusCode.NOT_FOUND,
-                "Diarization job not found",
-            )
-        return noteflow_pb2.DiarizationJobStatus(
-            job_id=job.job_id,
-            status=job.status,
-            segments_updated=job.segments_updated,
-            speaker_ids=job.speaker_ids,
-            error_message=job.error_message,
-        )
+        async with self._create_repository_provider() as repo:
+            if repo.supports_diarization_jobs:
+                job = await repo.diarization_jobs.get(request.job_id)
+            else:
+                job = self._diarization_jobs.get(request.job_id)
+
+            if job is None:
+                await abort_not_found(context, "Diarization job", request.job_id)
+
+            return noteflow_pb2.DiarizationJobStatus(
+                job_id=job.job_id,
+                status=job.status,
+                segments_updated=job.segments_updated,
+                speaker_ids=job.speaker_ids,
+                error_message=job.error_message,
+            )
src/noteflow/grpc/_mixins/errors.py (new file, 116 lines)
@@ -0,0 +1,116 @@
+"""Error response helpers for gRPC service mixins.
+
+Consolidates the 34+ duplicated `await context.abort()` patterns
+into reusable helper functions.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, NoReturn
+
+import grpc
+
+if TYPE_CHECKING:
+    import grpc.aio
+
+
+async def abort_not_found(
+    context: grpc.aio.ServicerContext,
+    entity_type: str,
+    entity_id: str,
+) -> NoReturn:
+    """Abort with NOT_FOUND status.
+
+    Consolidates the repeated "entity not found" error pattern.
+
+    Args:
+        context: gRPC servicer context.
+        entity_type: Type of entity (e.g., "Meeting", "Annotation").
+        entity_id: ID of the missing entity.
+
+    Raises:
+        grpc.RpcError: Always raises NOT_FOUND.
+    """
+    await context.abort(
+        grpc.StatusCode.NOT_FOUND,
+        f"{entity_type} {entity_id} not found",
+    )
+    # This line is unreachable but helps type checkers
+    raise AssertionError("Unreachable")
+
+
+async def abort_database_required(
+    context: grpc.aio.ServicerContext,
+    feature: str,
+) -> NoReturn:
+    """Abort with UNIMPLEMENTED for DB-only features.
+
+    Consolidates the repeated "requires database persistence" pattern.
+
+    Args:
+        context: gRPC servicer context.
+        feature: Feature that requires database (e.g., "Annotations").
+
+    Raises:
+        grpc.RpcError: Always raises UNIMPLEMENTED.
+    """
+    await context.abort(
+        grpc.StatusCode.UNIMPLEMENTED,
+        f"{feature} require database persistence",
+    )
+    raise AssertionError("Unreachable")
+
+
+async def abort_invalid_argument(
+    context: grpc.aio.ServicerContext,
+    message: str,
+) -> NoReturn:
+    """Abort with INVALID_ARGUMENT status.
+
+    Args:
+        context: gRPC servicer context.
+        message: Error message describing the invalid argument.
+
+    Raises:
+        grpc.RpcError: Always raises INVALID_ARGUMENT.
+    """
+    await context.abort(grpc.StatusCode.INVALID_ARGUMENT, message)
+    raise AssertionError("Unreachable")
+
+
+async def abort_failed_precondition(
+    context: grpc.aio.ServicerContext,
+    message: str,
+) -> NoReturn:
+    """Abort with FAILED_PRECONDITION status.
+
+    Use when operation cannot proceed due to system state.
+
+    Args:
+        context: gRPC servicer context.
+        message: Error message describing the precondition failure.
+
+    Raises:
+        grpc.RpcError: Always raises FAILED_PRECONDITION.
+    """
+    await context.abort(grpc.StatusCode.FAILED_PRECONDITION, message)
+    raise AssertionError("Unreachable")
+
+
+async def abort_internal(
+    context: grpc.aio.ServicerContext,
+    message: str,
+) -> NoReturn:
+    """Abort with INTERNAL status.
+
+    Use for unexpected server errors.
+
+    Args:
+        context: gRPC servicer context.
+        message: Error message describing the internal error.
+
+    Raises:
+        grpc.RpcError: Always raises INTERNAL.
+    """
+    await context.abort(grpc.StatusCode.INTERNAL, message)
+    raise AssertionError("Unreachable")
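# Illustrative handler usage (context is the grpc.aio servicer context; the
# NoReturn annotation plus the trailing AssertionError keep type checkers
# satisfied about control flow after the abort):
#
#     try:
#         meeting_id = parse_meeting_id(request.meeting_id)
#     except ValueError:
#         await abort_invalid_argument(context, "Invalid meeting_id")
#     meeting = await repo.meetings.get(meeting_id)
#     if meeting is None:
#         await abort_not_found(context, "Meeting", request.meeting_id)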
@@ -3,15 +3,14 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
-from uuid import UUID
 
 import grpc.aio
 
 from noteflow.application.services.export_service import ExportFormat, ExportService
-from noteflow.domain.value_objects import MeetingId
 
 from ..proto import noteflow_pb2
-from .converters import proto_to_export_format
+from .converters import parse_meeting_id, proto_to_export_format
+from .errors import abort_invalid_argument, abort_not_found
 
 if TYPE_CHECKING:
     from .protocols import ServicerHost
@@ -21,7 +20,7 @@ class ExportMixin:
     """Mixin providing export functionality.
 
     Requires host to implement ServicerHost protocol.
-    Export requires database persistence.
+    Works with both database and memory backends via RepositoryProvider.
     """
 
     async def ExportTranscript(
@@ -30,19 +29,19 @@ class ExportMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.ExportTranscriptResponse:
         """Export meeting transcript to specified format."""
-        if not self._use_database():
-            await context.abort(
-                grpc.StatusCode.UNIMPLEMENTED,
-                "Export requires database persistence",
-            )
-
         # Map proto format to ExportFormat
         fmt = proto_to_export_format(request.format)
 
-        export_service = ExportService(self._create_uow())
+        # Use unified repository provider - works with both DB and memory
+        try:
+            meeting_id = parse_meeting_id(request.meeting_id)
+        except ValueError:
+            await abort_invalid_argument(context, "Invalid meeting_id")
+
+        export_service = ExportService(self._create_repository_provider())
         try:
             content = await export_service.export_transcript(
-                MeetingId(UUID(request.meeting_id)),
+                meeting_id,
                 fmt,
             )
             exporter_info = export_service.get_supported_formats()
@@ -61,8 +60,5 @@ class ExportMixin:
                 format_name=fmt_name,
                 file_extension=fmt_ext,
             )
-        except ValueError as e:
-            await context.abort(
-                grpc.StatusCode.NOT_FOUND,
-                str(e),
-            )
+        except ValueError:
+            await abort_not_found(context, "Meeting", request.meeting_id)
@@ -2,25 +2,30 @@
 
 from __future__ import annotations
 
+import asyncio
 from typing import TYPE_CHECKING
-from uuid import UUID
 
 import grpc.aio
 
 from noteflow.domain.entities import Meeting
-from noteflow.domain.value_objects import MeetingId, MeetingState
+from noteflow.domain.value_objects import MeetingState
 
 from ..proto import noteflow_pb2
-from .converters import meeting_to_proto
+from .converters import meeting_to_proto, parse_meeting_id
+from .errors import abort_invalid_argument, abort_not_found
 
 if TYPE_CHECKING:
     from .protocols import ServicerHost
 
+# Timeout for waiting for stream to exit gracefully
+STOP_WAIT_TIMEOUT_SECONDS: float = 2.0
+
 
 class MeetingMixin:
     """Mixin providing meeting CRUD functionality.
 
     Requires host to implement ServicerHost protocol.
+    Works with both database and memory backends via RepositoryProvider.
     """
 
     async def CreateMeeting(
@@ -31,63 +36,61 @@ class MeetingMixin:
         """Create a new meeting."""
         metadata = dict(request.metadata) if request.metadata else {}
 
-        if self._use_database():
-            async with self._create_uow() as uow:
-                meeting = Meeting.create(title=request.title, metadata=metadata)
-                saved = await uow.meetings.create(meeting)
-                await uow.commit()
-                return meeting_to_proto(saved)
-        else:
-            store = self._get_memory_store()
-            meeting = store.create(title=request.title, metadata=metadata)
-            return meeting_to_proto(meeting)
+        async with self._create_repository_provider() as repo:
+            meeting = Meeting.create(title=request.title, metadata=metadata)
+            saved = await repo.meetings.create(meeting)
+            await repo.commit()
+            return meeting_to_proto(saved)
 
     async def StopMeeting(
         self: ServicerHost,
         request: noteflow_pb2.StopMeetingRequest,
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.Meeting:
-        """Stop a meeting using graceful STOPPING -> STOPPED transition."""
+        """Stop a meeting using graceful STOPPING -> STOPPED transition.
+
+        If the meeting is actively streaming, signals the stream to stop
+        and waits briefly for it to exit before closing resources.
+        """
         meeting_id = request.meeting_id
 
-        # Close audio writer if open
+        # Signal stop to active stream and wait for graceful exit
+        if meeting_id in self._active_streams:
+            self._stop_requested.add(meeting_id)
+            # Wait briefly for stream to detect stop request and exit
+            wait_iterations = int(STOP_WAIT_TIMEOUT_SECONDS * 10)  # 100ms intervals
+            for _ in range(wait_iterations):
+                if meeting_id not in self._active_streams:
+                    break
+                await asyncio.sleep(0.1)
+            # Clean up stop request even if stream didn't exit
+            self._stop_requested.discard(meeting_id)
+
+        # Close audio writer if open (stream cleanup may have done this)
         if meeting_id in self._audio_writers:
             self._close_audio_writer(meeting_id)
 
-        if self._use_database():
-            async with self._create_uow() as uow:
-                meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
-                if meeting is None:
-                    await context.abort(
-                        grpc.StatusCode.NOT_FOUND,
-                        f"Meeting {meeting_id} not found",
-                    )
-                try:
-                    # Graceful shutdown: RECORDING -> STOPPING -> STOPPED
-                    meeting.begin_stopping()
-                    meeting.stop_recording()
-                except ValueError as e:
-                    await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
-                await uow.meetings.update(meeting)
-                # Clean up streaming diarization turns (no longer needed)
-                await uow.diarization_jobs.clear_streaming_turns(meeting_id)
-                await uow.commit()
-                return meeting_to_proto(meeting)
-        store = self._get_memory_store()
-        meeting = store.get(meeting_id)
-        if meeting is None:
-            await context.abort(
-                grpc.StatusCode.NOT_FOUND,
-                f"Meeting {meeting_id} not found",
-            )
-        try:
-            # Graceful shutdown: RECORDING -> STOPPING -> STOPPED
-            meeting.begin_stopping()
-            meeting.stop_recording()
-        except ValueError as e:
-            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
-        store.update(meeting)
-        return meeting_to_proto(meeting)
+        async with self._create_repository_provider() as repo:
+            try:
+                parsed_meeting_id = parse_meeting_id(meeting_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid meeting_id")
+
+            meeting = await repo.meetings.get(parsed_meeting_id)
+            if meeting is None:
+                await abort_not_found(context, "Meeting", meeting_id)
+            try:
+                # Graceful shutdown: RECORDING -> STOPPING -> STOPPED
+                meeting.begin_stopping()
+                meeting.stop_recording()
+            except ValueError as e:
+                await abort_invalid_argument(context, str(e))
+            await repo.meetings.update(meeting)
+            # Clean up streaming diarization turns if DB supports it
+            if repo.supports_diarization_jobs:
+                await repo.diarization_jobs.clear_streaming_turns(meeting_id)
+            await repo.commit()
+            return meeting_to_proto(meeting)
 
     async def ListMeetings(
         self: ServicerHost,
@@ -98,24 +101,10 @@ class MeetingMixin:
         limit = request.limit or 100
         offset = request.offset or 0
         sort_desc = request.sort_order != noteflow_pb2.SORT_ORDER_CREATED_ASC
+        states = [MeetingState(s) for s in request.states] if request.states else None
 
-        if self._use_database():
-            states = [MeetingState(s) for s in request.states] if request.states else None
-            async with self._create_uow() as uow:
-                meetings, total = await uow.meetings.list_all(
-                    states=states,
-                    limit=limit,
-                    offset=offset,
-                    sort_desc=sort_desc,
-                )
-                return noteflow_pb2.ListMeetingsResponse(
-                    meetings=[meeting_to_proto(m, include_segments=False) for m in meetings],
-                    total_count=total,
-                )
-        else:
-            store = self._get_memory_store()
-            states = [MeetingState(s) for s in request.states] if request.states else None
-            meetings, total = store.list_all(
+        async with self._create_repository_provider() as repo:
+            meetings, total = await repo.meetings.list_all(
                 states=states,
                 limit=limit,
                 offset=offset,
@@ -132,39 +121,28 @@ class MeetingMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.Meeting:
         """Get meeting details."""
-        if self._use_database():
-            async with self._create_uow() as uow:
-                meeting = await uow.meetings.get(MeetingId(UUID(request.meeting_id)))
-                if meeting is None:
-                    await context.abort(
-                        grpc.StatusCode.NOT_FOUND,
-                        f"Meeting {request.meeting_id} not found",
-                    )
-                # Load segments if requested
-                if request.include_segments:
-                    segments = await uow.segments.get_by_meeting(meeting.id)
-                    meeting.segments = list(segments)
-                # Load summary if requested
-                if request.include_summary:
-                    summary = await uow.summaries.get_by_meeting(meeting.id)
-                    meeting.summary = summary
-                return meeting_to_proto(
-                    meeting,
-                    include_segments=request.include_segments,
-                    include_summary=request.include_summary,
-                )
-        store = self._get_memory_store()
-        meeting = store.get(request.meeting_id)
-        if meeting is None:
-            await context.abort(
-                grpc.StatusCode.NOT_FOUND,
-                f"Meeting {request.meeting_id} not found",
-            )
-        return meeting_to_proto(
-            meeting,
-            include_segments=request.include_segments,
-            include_summary=request.include_summary,
-        )
+        async with self._create_repository_provider() as repo:
+            try:
+                meeting_id = parse_meeting_id(request.meeting_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid meeting_id")
+
+            meeting = await repo.meetings.get(meeting_id)
+            if meeting is None:
+                await abort_not_found(context, "Meeting", request.meeting_id)
+            # Load segments if requested
+            if request.include_segments:
+                segments = await repo.segments.get_by_meeting(meeting.id)
+                meeting.segments = list(segments)
+            # Load summary if requested
+            if request.include_summary:
+                summary = await repo.summaries.get_by_meeting(meeting.id)
+                meeting.summary = summary
+            return meeting_to_proto(
+                meeting,
+                include_segments=request.include_segments,
+                include_summary=request.include_summary,
+            )
 
     async def DeleteMeeting(
         self: ServicerHost,
@@ -172,21 +150,14 @@ class MeetingMixin:
         context: grpc.aio.ServicerContext,
     ) -> noteflow_pb2.DeleteMeetingResponse:
         """Delete a meeting."""
-        if self._use_database():
-            async with self._create_uow() as uow:
-                success = await uow.meetings.delete(MeetingId(UUID(request.meeting_id)))
-                if success:
-                    await uow.commit()
-                    return noteflow_pb2.DeleteMeetingResponse(success=True)
-                await context.abort(
-                    grpc.StatusCode.NOT_FOUND,
-                    f"Meeting {request.meeting_id} not found",
-                )
-        store = self._get_memory_store()
-        success = store.delete(request.meeting_id)
-        if not success:
-            await context.abort(
-                grpc.StatusCode.NOT_FOUND,
-                f"Meeting {request.meeting_id} not found",
-            )
-        return noteflow_pb2.DeleteMeetingResponse(success=True)
+        async with self._create_repository_provider() as repo:
+            try:
+                meeting_id = parse_meeting_id(request.meeting_id)
+            except ValueError:
+                await abort_invalid_argument(context, "Invalid meeting_id")
+
+            success = await repo.meetings.delete(meeting_id)
+            if success:
+                await repo.commit()
+                return noteflow_pb2.DeleteMeetingResponse(success=True)
+            await abort_not_found(context, "Meeting", request.meeting_id)
@@ -13,9 +13,14 @@ if TYPE_CHECKING:
     from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 
     from noteflow.domain.entities import Meeting
+    from noteflow.domain.ports.unit_of_work import UnitOfWork
     from noteflow.infrastructure.asr import FasterWhisperEngine, Segmenter, StreamingVad
     from noteflow.infrastructure.audio.writer import MeetingAudioWriter
-    from noteflow.infrastructure.diarization import DiarizationEngine, SpeakerTurn
+    from noteflow.infrastructure.diarization import (
+        DiarizationEngine,
+        DiarizationSession,
+        SpeakerTurn,
+    )
     from noteflow.infrastructure.persistence.repositories import DiarizationJob
     from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
     from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
@@ -53,6 +58,7 @@ class ServicerHost(Protocol):
     _segment_counters: dict[str, int]
     _stream_formats: dict[str, tuple[int, int]]
    _active_streams: set[str]
+    _stop_requested: set[str]  # Meeting IDs with pending stop requests
 
     # Partial transcription state per meeting
     _partial_buffers: dict[str, list[NDArray[np.float32]]]
@@ -63,6 +69,7 @@ class ServicerHost(Protocol):
     _diarization_turns: dict[str, list[SpeakerTurn]]
     _diarization_stream_time: dict[str, float]
     _diarization_streaming_failed: set[str]
+    _diarization_sessions: dict[str, DiarizationSession]
 
     # Background diarization task references (for cancellation)
     _diarization_jobs: dict[str, DiarizationJob]
@@ -84,7 +91,19 @@ class ServicerHost(Protocol):
         ...
 
     def _create_uow(self) -> SqlAlchemyUnitOfWork:
-        """Create a new Unit of Work."""
+        """Create a new Unit of Work (database-backed)."""
         ...
 
+    def _create_repository_provider(self) -> UnitOfWork:
+        """Create a repository provider (database or memory backed).
+
+        Returns a UnitOfWork implementation appropriate for the current
+        configuration. Use this for operations that can work with either
+        backend, eliminating the need for if/else branching.
+
+        Returns:
+            SqlAlchemyUnitOfWork if database configured, MemoryUnitOfWork otherwise.
+        """
+        ...
+
     def _next_segment_id(self, meeting_id: str, fallback: int = 0) -> int:
 
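# Sketch of the branch-free pattern this protocol enables (repo may be
# SQLAlchemy- or memory-backed; capability flags gate DB-only features):
#
#     async with self._create_repository_provider() as repo:
#         meeting = await repo.meetings.get(meeting_id)
#         if repo.supports_diarization_jobs:
#             await repo.diarization_jobs.create(job)
#         await repo.commit()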
@@ -8,19 +8,25 @@ import time
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
-from uuid import UUID
 
 import grpc.aio
 import numpy as np
 from numpy.typing import NDArray
 
-from noteflow.domain.value_objects import MeetingId
 from noteflow.infrastructure.diarization import SpeakerTurn
 
 from ..proto import noteflow_pb2
-from .converters import create_segment_from_asr, create_vad_update, segment_to_proto_update
+from .converters import (
+    create_segment_from_asr,
+    create_vad_update,
+    parse_meeting_id,
+    segment_to_proto_update,
+)
+from .errors import abort_failed_precondition, abort_invalid_argument
 
 if TYPE_CHECKING:
+    from noteflow.domain.entities import Segment
+
     from .protocols import ServicerHost
 
 logger = logging.getLogger(__name__)
@@ -55,12 +61,10 @@ class StreamingMixin:
 
         Receive audio chunks from client, process through ASR,
         persist segments, and yield transcript updates.
+        Works with both database and memory backends via RepositoryProvider.
         """
         if self._asr_engine is None or not self._asr_engine.is_loaded:
-            await context.abort(
-                grpc.StatusCode.FAILED_PRECONDITION,
-                "ASR engine not loaded",
-            )
+            await abort_failed_precondition(context, "ASR engine not loaded")
 
         current_meeting_id: str | None = None
 
@@ -68,10 +72,7 @@ class StreamingMixin:
         async for chunk in request_iterator:
             meeting_id = chunk.meeting_id
             if not meeting_id:
-                await context.abort(
-                    grpc.StatusCode.INVALID_ARGUMENT,
-                    "meeting_id required",
-                )
+                await abort_invalid_argument(context, "meeting_id required")
 
             # Initialize stream on first chunk
             if current_meeting_id is None:
@@ -80,11 +81,18 @@ class StreamingMixin:
                     return  # Error already sent via context.abort
                 current_meeting_id = meeting_id
             elif meeting_id != current_meeting_id:
-                await context.abort(
-                    grpc.StatusCode.INVALID_ARGUMENT,
-                    "Stream may only contain a single meeting_id",
+                await abort_invalid_argument(
+                    context, "Stream may only contain a single meeting_id"
                 )
 
+            # Check for stop request (graceful shutdown from StopMeeting)
+            if current_meeting_id in self._stop_requested:
+                logger.info(
+                    "Stop requested for meeting %s, exiting stream gracefully",
+                    current_meeting_id,
+                )
+                break
+
             # Process audio chunk
             async for update in self._process_stream_chunk(current_meeting_id, chunk, context):
                 yield update
@@ -112,6 +120,8 @@ class StreamingMixin:
     ) -> _StreamSessionInit | None:
         """Initialize streaming for a meeting.
 
+        Uses unified repository provider for both DB and memory backends.
+
         Args:
             meeting_id: Meeting ID string.
             context: gRPC context for error handling.
@@ -120,17 +130,11 @@ class StreamingMixin:
             Initialization result, or None if error was sent.
         """
         if meeting_id in self._active_streams:
-            await context.abort(
-                grpc.StatusCode.FAILED_PRECONDITION,
-                f"Meeting {meeting_id} already streaming",
-            )
+            await abort_failed_precondition(context, f"Meeting {meeting_id} already streaming")
 
         self._active_streams.add(meeting_id)
 
-        if self._use_database():
-            init_result = await self._init_stream_session_db(meeting_id)
-        else:
-            init_result = self._init_stream_session_memory(meeting_id)
+        init_result = await self._init_stream_session(meeting_id)
 
         if not init_result.success:
             self._active_streams.discard(meeting_id)
@@ -138,11 +142,11 @@ class StreamingMixin:
 
         return init_result
 
-    async def _init_stream_session_db(
+    async def _init_stream_session(
         self: ServicerHost,
         meeting_id: str,
     ) -> _StreamSessionInit:
-        """Initialize stream session using database persistence.
+        """Initialize stream session using unified repository provider.
 
         Args:
             meeting_id: Meeting ID string.
@@ -150,8 +154,17 @@ class StreamingMixin:
         Returns:
             Stream session initialization result.
         """
-        async with self._create_uow() as uow:
-            meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
+        async with self._create_repository_provider() as repo:
+            try:
+                parsed_meeting_id = parse_meeting_id(meeting_id)
+            except ValueError:
+                return _StreamSessionInit(
+                    next_segment_id=0,
+                    error_code=grpc.StatusCode.INVALID_ARGUMENT,
+                    error_message="Invalid meeting_id",
+                )
+
+            meeting = await repo.meetings.get(parsed_meeting_id)
             if meeting is None:
                 return _StreamSessionInit(
                     next_segment_id=0,
@@ -170,82 +183,43 @@ class StreamingMixin:
                 )
 
             if dek_updated or recording_updated:
-                await uow.meetings.update(meeting)
-                await uow.commit()
+                await repo.meetings.update(meeting)
+                await repo.commit()
 
-            next_segment_id = await uow.segments.get_next_segment_id(meeting.id)
+            next_segment_id = await repo.segments.get_next_segment_id(meeting.id)
             self._open_meeting_audio_writer(
                 meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path
             )
             self._init_streaming_state(meeting_id, next_segment_id)
 
-            # Load any persisted streaming turns (crash recovery)
-            persisted_turns = await uow.diarization_jobs.get_streaming_turns(meeting_id)
-            if persisted_turns:
-                domain_turns = [
-                    SpeakerTurn(
-                        speaker=t.speaker,
-                        start=t.start_time,
-                        end=t.end_time,
-                        confidence=t.confidence,
-                    )
-                    for t in persisted_turns
-                ]
-                self._diarization_turns[meeting_id] = domain_turns
-                # Advance stream time to avoid overlapping recovered turns
-                last_end = max(t.end_time for t in persisted_turns)
-                self._diarization_stream_time[meeting_id] = max(
-                    self._diarization_stream_time.get(meeting_id, 0.0),
-                    last_end,
-                )
-                logger.info(
-                    "Loaded %d streaming diarization turns for meeting %s",
-                    len(domain_turns),
-                    meeting_id,
-                )
+            # Load any persisted streaming turns (crash recovery) - DB only
+            if repo.supports_diarization_jobs:
+                persisted_turns = await repo.diarization_jobs.get_streaming_turns(meeting_id)
+                if persisted_turns:
+                    domain_turns = [
+                        SpeakerTurn(
+                            speaker=t.speaker,
+                            start=t.start_time,
+                            end=t.end_time,
+                            confidence=t.confidence,
+                        )
+                        for t in persisted_turns
+                    ]
+                    self._diarization_turns[meeting_id] = domain_turns
+                    # Advance stream time to avoid overlapping recovered turns
+                    last_end = max(t.end_time for t in persisted_turns)
+                    self._diarization_stream_time[meeting_id] = max(
+                        self._diarization_stream_time.get(meeting_id, 0.0),
+                        last_end,
+                    )
+                    logger.info(
+                        "Loaded %d streaming diarization turns for meeting %s",
+                        len(domain_turns),
+                        meeting_id,
+                    )
 
         return _StreamSessionInit(next_segment_id=next_segment_id)
 
-    def _init_stream_session_memory(
-        self: ServicerHost,
-        meeting_id: str,
-    ) -> _StreamSessionInit:
-        """Initialize stream session using in-memory store.
-
-        Args:
-            meeting_id: Meeting ID string.
-
-        Returns:
-            Stream session initialization result.
-        """
-        store = self._get_memory_store()
-        meeting = store.get(meeting_id)
-        if meeting is None:
-            return _StreamSessionInit(
-                next_segment_id=0,
-                error_code=grpc.StatusCode.NOT_FOUND,
-                error_message=f"Meeting {meeting_id} not found",
-            )
-
-        dek, wrapped_dek, dek_updated = self._ensure_meeting_dek(meeting)
-        recording_updated, error_msg = self._start_meeting_if_needed(meeting)
-
-        if error_msg:
-            return _StreamSessionInit(
-                next_segment_id=0,
-                error_code=grpc.StatusCode.INVALID_ARGUMENT,
-                error_message=error_msg,
-            )
-
-        if dek_updated or recording_updated:
-            store.update(meeting)
-
-        next_segment_id = meeting.next_segment_id
-        self._open_meeting_audio_writer(meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path)
-        self._init_streaming_state(meeting_id, next_segment_id)
-
-        return _StreamSessionInit(next_segment_id=next_segment_id)
-
     async def _process_stream_chunk(
         self: ServicerHost,
         meeting_id: str,
@@ -269,7 +243,7 @@ class StreamingMixin:
|
||||
chunk.channels,
|
||||
)
|
||||
except ValueError as e:
|
||||
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
||||
await abort_invalid_argument(context, str(e))
|
||||
|
||||
audio = self._decode_audio_chunk(chunk)
|
||||
if audio is None:
|
||||
@@ -278,7 +252,7 @@ class StreamingMixin:
|
||||
try:
|
||||
audio = self._convert_audio_format(audio, sample_rate, channels)
|
||||
except ValueError as e:
|
||||
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
||||
await abort_invalid_argument(context, str(e))
|
||||
|
||||
# Write to encrypted audio file
|
||||
self._write_audio_chunk_safe(meeting_id, audio)
|
@@ -550,6 +524,9 @@ class StreamingMixin:
    ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
        """Process a complete audio segment through ASR.

        Uses unified repository provider for both DB and memory backends.
        Batches all segments from ASR results and commits once per audio chunk.

        Args:
            meeting_id: Meeting identifier.
            audio: Complete audio segment.
@@ -561,37 +538,21 @@ class StreamingMixin:
        if len(audio) == 0 or self._asr_engine is None:
            return

        if self._use_database():
            async with self._create_uow() as uow:
                meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
                if meeting is None:
                    return
        async with self._create_repository_provider() as repo:
            try:
                parsed_meeting_id = parse_meeting_id(meeting_id)
            except ValueError:
                logger.warning("Invalid meeting_id %s in streaming segment", meeting_id)
                return

                results = await self._asr_engine.transcribe_async(audio)
                for result in results:
                    segment_id = self._next_segment_id(
                        meeting_id,
                        fallback=meeting.next_segment_id,
                    )
                    segment = create_segment_from_asr(
                        meeting.id,
                        segment_id,
                        result,
                        segment_start_time,
                    )
                    # Call diarization mixin method if available
                    if hasattr(self, "_maybe_assign_speaker"):
                        self._maybe_assign_speaker(meeting_id, segment)
                    meeting.add_segment(segment)
                    await uow.segments.add(meeting.id, segment)
                    await uow.commit()
                    yield segment_to_proto_update(meeting_id, segment)
        else:
            store = self._get_memory_store()
            meeting = store.get(meeting_id)
            meeting = await repo.meetings.get(parsed_meeting_id)
            if meeting is None:
                return

            results = await self._asr_engine.transcribe_async(audio)

            # Build all segments first
            segments_to_add: list[tuple[Segment, noteflow_pb2.TranscriptUpdate]] = []
            for result in results:
                segment_id = self._next_segment_id(
                    meeting_id,
@@ -606,5 +567,13 @@ class StreamingMixin:
                # Call diarization mixin method if available
                if hasattr(self, "_maybe_assign_speaker"):
                    self._maybe_assign_speaker(meeting_id, segment)
                store.add_segment(meeting_id, segment)
                yield segment_to_proto_update(meeting_id, segment)
                await repo.segments.add(meeting.id, segment)
                segments_to_add.append((segment, segment_to_proto_update(meeting_id, segment)))

            # Single commit for all segments in this audio chunk
            if segments_to_add:
                await repo.commit()

            # Yield updates after commit
            for _, update in segments_to_add:
                yield update
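The ordering here matters: segments are staged, committed once, and only then yielded, so clients never receive an update for a segment that could still roll back. A runnable sketch of the pattern with a stand-in repository (`FakeRepo` is illustrative, not the project API):

```python
import asyncio
from dataclasses import dataclass, field

@dataclass
class FakeRepo:
    """Stand-in for the repository provider; only add/commit are modeled."""
    staged: list[str] = field(default_factory=list)
    committed: bool = False

    async def add(self, seg: str) -> None:
        self.staged.append(seg)

    async def commit(self) -> None:
        self.committed = True

async def process_chunk(results: list[str]):
    repo = FakeRepo()
    updates = []
    for r in results:
        await repo.add(r)          # stage each segment
        updates.append(f"update:{r}")
    if updates:
        await repo.commit()        # one commit per audio chunk
    for u in updates:              # yield only after the commit succeeds
        yield u

async def main() -> None:
    async for update in process_chunk(["seg0", "seg1"]):
        print(update)

asyncio.run(main())
```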

@@ -4,7 +4,6 @@ from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from uuid import UUID

import grpc.aio

@@ -13,7 +12,8 @@ from noteflow.domain.summarization import ProviderUnavailableError
from noteflow.domain.value_objects import MeetingId

from ..proto import noteflow_pb2
from .converters import summary_to_proto
from .converters import parse_meeting_id, summary_to_proto
from .errors import abort_invalid_argument, abort_not_found

if TYPE_CHECKING:
    from noteflow.application.services.summarization_service import SummarizationService
@@ -27,6 +27,7 @@ class SummarizationMixin:
    """Mixin providing summarization functionality.

    Requires host to implement ServicerHost protocol.
    Works with both database and memory backends via RepositoryProvider.
    """

    _summarization_service: SummarizationService | None
@@ -36,70 +37,38 @@ class SummarizationMixin:
        request: noteflow_pb2.GenerateSummaryRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Summary:
        """Generate meeting summary using SummarizationService with fallback."""
        if self._use_database():
            return await self._generate_summary_db(request, context)
        """Generate meeting summary using SummarizationService with fallback.

        return await self._generate_summary_memory(request, context)

    async def _generate_summary_db(
        self: ServicerHost,
        request: noteflow_pb2.GenerateSummaryRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Summary:
        """Generate summary for a meeting stored in the database.

        The potentially slow summarization step is executed outside the UoW to
        avoid holding database connections while waiting on LLMs.
        The potentially slow summarization step is executed outside the repository
        context to avoid holding connections while waiting on LLMs.
        """
        meeting_id = MeetingId(UUID(request.meeting_id))
        try:
            meeting_id = parse_meeting_id(request.meeting_id)
        except ValueError:
            await abort_invalid_argument(context, "Invalid meeting_id")

        # 1) Load meeting, existing summary, and segments inside a short UoW
        async with self._create_uow() as uow:
            meeting = await uow.meetings.get(meeting_id)
        # 1) Load meeting, existing summary, and segments in a short transaction
        async with self._create_repository_provider() as repo:
            meeting = await repo.meetings.get(meeting_id)
            if meeting is None:
                await context.abort(
                    grpc.StatusCode.NOT_FOUND,
                    f"Meeting {request.meeting_id} not found",
                )
                await abort_not_found(context, "Meeting", request.meeting_id)

            existing = await uow.summaries.get_by_meeting(meeting.id)
            existing = await repo.summaries.get_by_meeting(meeting.id)
            if existing and not request.force_regenerate:
                return summary_to_proto(existing)

            segments = list(await uow.segments.get_by_meeting(meeting.id))
            segments = list(await repo.segments.get_by_meeting(meeting.id))

        # 2) Run summarization outside DB transaction
        # 2) Run summarization outside repository context (slow LLM call)
        summary = await self._summarize_or_placeholder(meeting_id, segments)

        # 3) Persist in a fresh UoW
        async with self._create_uow() as uow:
            saved = await uow.summaries.save(summary)
            await uow.commit()
        # 3) Persist in a fresh transaction
        async with self._create_repository_provider() as repo:
            saved = await repo.summaries.save(summary)
            await repo.commit()

        return summary_to_proto(saved)

    async def _generate_summary_memory(
        self: ServicerHost,
        request: noteflow_pb2.GenerateSummaryRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Summary:
        """Generate summary for meetings held in the in-memory store."""
        store = self._get_memory_store()
        meeting = store.get(request.meeting_id)
        if meeting is None:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                f"Meeting {request.meeting_id} not found",
            )

        if meeting.summary and not request.force_regenerate:
            return summary_to_proto(meeting.summary)

        summary = await self._summarize_or_placeholder(meeting.id, meeting.segments)
        store.set_summary(request.meeting_id, summary)
        return summary_to_proto(summary)

    async def _summarize_or_placeholder(
        self: ServicerHost,
        meeting_id: MeetingId,

@@ -420,7 +420,11 @@ class NoteFlowClient:

        if self._stream_thread:
            self._stream_thread.join(timeout=2.0)
            self._stream_thread = None
            # Only clear reference if thread exited cleanly
            if self._stream_thread.is_alive():
                logger.warning("Stream thread did not exit within timeout, keeping reference")
            else:
                self._stream_thread = None

        self._current_meeting_id = None
        logger.info("Stopped streaming")
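The fix avoids orphaning a thread that outlives its join timeout: the reference is kept so the failure stays observable instead of silently leaking. A self-contained sketch of the same shutdown shape (names illustrative):

```python
import logging
import threading
import time

logger = logging.getLogger(__name__)

def stop_stream_thread(thread: threading.Thread | None) -> threading.Thread | None:
    """Join with a timeout; keep the reference if the thread is still alive."""
    if thread is None:
        return None
    thread.join(timeout=2.0)
    if thread.is_alive():
        logger.warning("Stream thread did not exit within timeout, keeping reference")
        return thread
    return None

worker = threading.Thread(target=time.sleep, args=(0.1,))
worker.start()
print(stop_stream_thread(worker))  # None once the worker has exited
```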

@@ -46,6 +46,19 @@ class MeetingStore:

        return meeting

    def insert(self, meeting: Meeting) -> Meeting:
        """Insert an existing meeting into the store.

        Args:
            meeting: Meeting entity to store.

        Returns:
            Stored meeting.
        """
        with self._lock:
            self._meetings[str(meeting.id)] = meeting
        return meeting

    def get(self, meeting_id: str) -> Meeting | None:
        """Get a meeting by ID.

@@ -58,6 +71,18 @@ class MeetingStore:
        with self._lock:
            return self._meetings.get(meeting_id)

    def count_by_state(self, state: MeetingState) -> int:
        """Count meetings in a specific state.

        Args:
            state: Meeting state to count.

        Returns:
            Number of meetings in the specified state.
        """
        with self._lock:
            return sum(m.state == state for m in self._meetings.values())

    def list_all(
        self,
        states: Sequence[MeetingState] | None = None,
@@ -94,6 +119,22 @@ class MeetingStore:

        return meetings, total

    def find_older_than(self, cutoff: datetime) -> list[Meeting]:
        """Find completed meetings older than cutoff date.

        Args:
            cutoff: Cutoff datetime; meetings ended before this are returned.

        Returns:
            List of meetings with ended_at before cutoff.
        """
        with self._lock:
            return [
                m
                for m in self._meetings.values()
                if m.ended_at is not None and m.ended_at < cutoff
            ]

    def update(self, meeting: Meeting) -> Meeting:
        """Update a meeting in the store.

@@ -125,6 +166,19 @@ class MeetingStore:
            meeting.add_segment(segment)
        return meeting

    def get_segments(self, meeting_id: str) -> list[Segment]:
        """Get a copy of segments for a meeting.

        Args:
            meeting_id: Meeting ID.

        Returns:
            List of segments (copy), or empty list if meeting not found.
        """
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            return [] if meeting is None else list(meeting.segments)

    def set_summary(self, meeting_id: str, summary: Summary) -> Meeting | None:
        """Set meeting summary.

@@ -143,6 +197,35 @@ class MeetingStore:
            meeting.summary = summary
        return meeting

    def get_summary(self, meeting_id: str) -> Summary | None:
        """Get meeting summary.

        Args:
            meeting_id: Meeting ID.

        Returns:
            Summary or None if missing.
        """
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            return meeting.summary if meeting else None

    def clear_summary(self, meeting_id: str) -> bool:
        """Clear meeting summary.

        Args:
            meeting_id: Meeting ID.

        Returns:
            True if cleared, False if meeting not found or no summary set.
        """
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None or meeting.summary is None:
                return False
            meeting.summary = None
            return True

    def update_state(self, meeting_id: str, state: MeetingState) -> bool:
        """Atomically update meeting state.

@@ -194,6 +277,21 @@ class MeetingStore:
            meeting.end_time = end_time
        return True

    def get_next_segment_id(self, meeting_id: str) -> int:
        """Get next segment ID for a meeting.

        Args:
            meeting_id: Meeting ID.

        Returns:
            Next segment ID (0 if meeting or segments missing).
        """
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None or not meeting.segments:
                return 0
            return max(s.segment_id for s in meeting.segments) + 1

    def delete(self, meeting_id: str) -> bool:
        """Delete a meeting.

@@ -14,9 +14,12 @@ import numpy as np

from noteflow.config.constants import DEFAULT_SAMPLE_RATE as _DEFAULT_SAMPLE_RATE
from noteflow.domain.entities import Meeting
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingState
from noteflow.infrastructure.asr import Segmenter, SegmenterConfig, StreamingVad
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
from noteflow.infrastructure.diarization import DiarizationSession
from noteflow.infrastructure.persistence.memory import MemoryUnitOfWork
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
@@ -107,6 +110,7 @@ class NoteFlowServicer(
        self._segment_counters: dict[str, int] = {}
        self._stream_formats: dict[str, tuple[int, int]] = {}
        self._active_streams: set[str] = set()
        self._stop_requested: set[str] = set()

        # Partial transcription state per meeting
        self._partial_buffers: dict[str, list[NDArray[np.float32]]] = {}
@@ -117,6 +121,7 @@ class NoteFlowServicer(
        self._diarization_turns: dict[str, list[SpeakerTurn]] = {}
        self._diarization_stream_time: dict[str, float] = {}
        self._diarization_streaming_failed: set[str] = set()
        self._diarization_sessions: dict[str, DiarizationSession] = {}

        # Track audio write failures to avoid log spam
        self._audio_write_failed: set[str] = set()
@@ -155,11 +160,25 @@ class NoteFlowServicer(
        return self._memory_store

    def _create_uow(self) -> SqlAlchemyUnitOfWork:
        """Create a new Unit of Work."""
        """Create a new Unit of Work (database-backed)."""
        if self._session_factory is None:
            raise RuntimeError("Database not configured")
        return SqlAlchemyUnitOfWork(self._session_factory)

    def _create_repository_provider(self) -> UnitOfWork:
        """Create a repository provider (database or memory backed).

        Returns a UnitOfWork implementation appropriate for the current
        configuration. Use this for operations that can work with either
        backend, eliminating the need for if/else branching.

        Returns:
            SqlAlchemyUnitOfWork if database configured, MemoryUnitOfWork otherwise.
        """
        if self._session_factory is not None:
            return SqlAlchemyUnitOfWork(self._session_factory)
        return MemoryUnitOfWork(self._get_memory_store())
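How call sites shrink once the provider exists: one context manager replaces the backend branch. A hedged, self-contained sketch with stub UoWs standing in for the two real classes imported above:

```python
import asyncio

class _BaseUow:
    """Minimal async-context stand-in; the real protocol exposes repositories too."""

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc: object) -> None:
        return None

    async def commit(self) -> None:
        pass

class _DbUow(_BaseUow): ...
class _MemoryUow(_BaseUow): ...

class Host:
    def __init__(self, session_factory=None) -> None:
        self._session_factory = session_factory

    def _create_repository_provider(self) -> _BaseUow:
        # Same selection rule as above: DB when configured, memory otherwise.
        return _DbUow() if self._session_factory is not None else _MemoryUow()

async def handler(host: Host) -> None:
    # One code path regardless of backend; no `if self._use_database()` branch.
    async with host._create_repository_provider() as repo:
        await repo.commit()

asyncio.run(handler(Host()))
```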

    def _init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None:
        """Initialize VAD, Segmenter, speaking state, and partial buffers for a meeting."""
        self._vad_instances[meeting_id] = StreamingVad()
@@ -174,8 +193,8 @@ class NoteFlowServicer(
        self._diarization_turns[meeting_id] = []
        self._diarization_stream_time[meeting_id] = 0.0
        self._diarization_streaming_failed.discard(meeting_id)
        if self._diarization_engine is not None:
            self._diarization_engine.reset_streaming()
        # NOTE: Per-meeting diarization sessions are created lazily in
        # _process_streaming_diarization() to avoid blocking on model load

    def _cleanup_streaming_state(self, meeting_id: str) -> None:
        """Clean up VAD, Segmenter, speaking state, and partial buffers for a meeting."""
@@ -191,6 +210,10 @@ class NoteFlowServicer(
        self._diarization_stream_time.pop(meeting_id, None)
        self._diarization_streaming_failed.discard(meeting_id)

        # Clean up per-meeting diarization session
        if session := self._diarization_sessions.pop(meeting_id, None):
            session.close()

    def _ensure_meeting_dek(self, meeting: Meeting) -> tuple[bytes, bytes, bool]:
        """Ensure meeting has a DEK, generating one if needed.

@@ -344,6 +367,12 @@ class NoteFlowServicer(

        self._diarization_tasks.clear()

        # Close all diarization sessions
        for meeting_id, session in list(self._diarization_sessions.items()):
            logger.debug("Closing diarization session for meeting %s", meeting_id)
            session.close()
        self._diarization_sessions.clear()

        # Close all audio writers
        for meeting_id in list(self._audio_writers.keys()):
            logger.debug("Closing audio writer for meeting %s", meeting_id)

@@ -9,9 +9,11 @@ from noteflow.infrastructure.diarization.assigner import (
)
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from noteflow.infrastructure.diarization.engine import DiarizationEngine
from noteflow.infrastructure.diarization.session import DiarizationSession

__all__ = [
    "DiarizationEngine",
    "DiarizationSession",
    "SpeakerTurn",
    "assign_speaker",
    "assign_speakers_batch",

@@ -9,10 +9,12 @@ Requires optional dependencies: pip install noteflow[diarization]
from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from noteflow.infrastructure.diarization.session import DiarizationSession

if TYPE_CHECKING:
    from collections.abc import Sequence
@@ -60,6 +62,10 @@ class DiarizationEngine:
        self._streaming_pipeline = None
        self._offline_pipeline = None

        # Shared models for per-session pipelines (loaded once, reused)
        self._segmentation_model = None
        self._embedding_model = None

    def _resolve_device(self) -> str:
        """Resolve the actual device to use based on availability.

@@ -133,6 +139,80 @@ class DiarizationEngine:
        except Exception as e:
            raise RuntimeError(f"Failed to load streaming diarization model: {e}") from e

    def _ensure_streaming_models_loaded(self) -> None:
        """Ensure shared streaming models are loaded.

        Loads the segmentation and embedding models that are shared across
        all streaming sessions. These models are stateless and can be safely
        reused by multiple concurrent sessions.

        Raises:
            RuntimeError: If model loading fails.
            ValueError: If HuggingFace token is not provided.
        """
        if self._segmentation_model is not None and self._embedding_model is not None:
            return

        if not self._hf_token:
            raise ValueError("HuggingFace token required for pyannote models")

        device = self._resolve_device()
        logger.info("Loading shared streaming diarization models on %s...", device)

        try:
            from diart.models import EmbeddingModel, SegmentationModel

            self._segmentation_model = SegmentationModel.from_pretrained(
                "pyannote/segmentation-3.0",
                use_hf_token=self._hf_token,
            )
            self._embedding_model = EmbeddingModel.from_pretrained(
                "pyannote/wespeaker-voxceleb-resnet34-LM",
                use_hf_token=self._hf_token,
            )
            logger.info("Shared streaming models loaded successfully")

        except Exception as e:
            raise RuntimeError(f"Failed to load streaming models: {e}") from e

    def create_streaming_session(self, meeting_id: str) -> DiarizationSession:
        """Create a new per-meeting streaming diarization session.

        Each session maintains its own pipeline state, enabling concurrent
        diarization of multiple meetings without interference. The underlying
        models are shared across sessions for memory efficiency.

        Args:
            meeting_id: Unique identifier for the meeting.

        Returns:
            New DiarizationSession for the meeting.

        Raises:
            RuntimeError: If model loading fails.
            ValueError: If HuggingFace token is not provided.
        """
        self._ensure_streaming_models_loaded()

        from diart import SpeakerDiarization, SpeakerDiarizationConfig

        config = SpeakerDiarizationConfig(
            segmentation=self._segmentation_model,
            embedding=self._embedding_model,
            step=self._streaming_latency,
            latency=self._streaming_latency,
            device=self._resolve_device(),
        )

        pipeline = SpeakerDiarization(config)
        logger.info("Created streaming session for meeting %s", meeting_id)

        return DiarizationSession(
            meeting_id=meeting_id,
            _pipeline=pipeline,
            _sample_rate=DEFAULT_SAMPLE_RATE,
        )
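The point of the factory: each meeting gets its own pipeline object while the heavyweight models load once. A toy equivalent showing the isolation property (`FakePipeline` stands in for diart's `SpeakerDiarization`; the real models are the segmentation/embedding pair above):

```python
class FakePipeline:
    """Stand-in for a streaming pipeline; holds per-meeting state."""

    def __init__(self, shared_models: object) -> None:
        self.models = shared_models  # shared, read-only across sessions
        self.stream_time = 0.0       # private, per meeting

_SHARED_MODELS = object()  # loaded once, like the segmentation/embedding models

def create_streaming_session(meeting_id: str) -> FakePipeline:
    return FakePipeline(_SHARED_MODELS)

a = create_streaming_session("meeting-a")
b = create_streaming_session("meeting-b")
assert a.models is b.models  # models shared for memory efficiency
assert a is not b            # pipeline state isolated per meeting
```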

    def load_offline_model(self) -> None:
        """Load the offline diarization model (pyannote.audio).

@@ -287,7 +367,19 @@ class DiarizationEngine:
        return turns

    def reset_streaming(self) -> None:
        """Reset streaming pipeline state for a new recording."""
        """Reset streaming pipeline state for a new recording.

        .. deprecated::
            Use create_streaming_session() for per-meeting sessions instead.
            This method resets global state and will cause race conditions
            with concurrent meetings.
        """
        warnings.warn(
            "reset_streaming() is deprecated. Use create_streaming_session() for "
            "per-meeting sessions to avoid race conditions with concurrent meetings.",
            DeprecationWarning,
            stacklevel=2,
        )
        if self._streaming_pipeline is not None:
            self._streaming_pipeline.reset()
            logger.debug("Streaming pipeline state reset")
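With `stacklevel=2` the warning is attributed to the caller's line, which makes the remaining `reset_streaming()` call sites easy to locate during migration. Standard-library mechanics, shown standalone:

```python
import warnings

def reset_streaming() -> None:
    warnings.warn(
        "reset_streaming() is deprecated. Use create_streaming_session().",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this function
    )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    reset_streaming()

assert caught[0].category is DeprecationWarning
```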
@@ -296,6 +388,8 @@ class DiarizationEngine:
        """Unload all models to free memory."""
        self._streaming_pipeline = None
        self._offline_pipeline = None
        self._segmentation_model = None
        self._embedding_model = None
        self._device = None
        logger.info("Diarization models unloaded")

170
src/noteflow/infrastructure/diarization/session.py
Normal file
@@ -0,0 +1,170 @@
"""Per-meeting diarization session for streaming speaker identification.

Each session maintains its own pipeline state, enabling concurrent meetings
without cross-talk. Shared models are loaded once and reused across sessions.
"""

from __future__ import annotations

import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.dto import SpeakerTurn

if TYPE_CHECKING:
    import numpy as np
    from diart import SpeakerDiarization
    from numpy.typing import NDArray

logger = logging.getLogger(__name__)


@dataclass
class DiarizationSession:
    """Per-meeting streaming diarization session.

    Maintains independent pipeline state for a single meeting, enabling
    concurrent diarization of multiple meetings without interference.

    The session owns its own SpeakerDiarization pipeline instance but
    shares the underlying segmentation and embedding models with other
    sessions for memory efficiency.
    """

    meeting_id: str
    _pipeline: SpeakerDiarization
    _sample_rate: int = DEFAULT_SAMPLE_RATE
    _stream_time: float = field(default=0.0, init=False)
    _turns: list[SpeakerTurn] = field(default_factory=list, init=False)
    _closed: bool = field(default=False, init=False)

    def process_chunk(
        self,
        audio: NDArray[np.float32],
        sample_rate: int | None = None,
    ) -> Sequence[SpeakerTurn]:
        """Process an audio chunk and return new speaker turns.

        Args:
            audio: Audio samples as float32 array (mono).
            sample_rate: Audio sample rate (defaults to session's configured rate).

        Returns:
            Sequence of speaker turns detected in this chunk,
            with times adjusted to absolute stream position.

        Raises:
            RuntimeError: If session is closed.
        """
        if self._closed:
            raise RuntimeError(f"Session {self.meeting_id} is closed")

        if audio.size == 0:
            return []

        rate = sample_rate or self._sample_rate
        duration = len(audio) / rate

        # Import here to avoid import errors when diart not installed
        from pyannote.core import SlidingWindow, SlidingWindowFeature

        # Reshape audio for diart: (samples,) -> (1, samples)
        if audio.ndim == 1:
            audio = audio.reshape(1, -1)

        # Create SlidingWindowFeature for diart
        window = SlidingWindow(start=0.0, duration=duration, step=duration)
        waveform = SlidingWindowFeature(audio, window)

        # Process through pipeline
        results = self._pipeline([waveform])

        # Convert results to turns with absolute time offsets
        new_turns: list[SpeakerTurn] = []
        for annotation, _ in results:
            for track in annotation.itertracks(yield_label=True):
                if len(track) == 3:
                    segment, _, speaker = track
                    turn = SpeakerTurn(
                        speaker=str(speaker),
                        start=segment.start + self._stream_time,
                        end=segment.end + self._stream_time,
                    )
                    new_turns.append(turn)
                    self._turns.append(turn)

        self._stream_time += duration
        return new_turns

    def reset(self) -> None:
        """Reset session state for restarting diarization.

        Clears accumulated turns and resets stream time to zero.
        The underlying pipeline is also reset.
        """
        if self._closed:
            return

        self._pipeline.reset()
        self._stream_time = 0.0
        self._turns.clear()
        logger.debug("Session %s reset", self.meeting_id)

    def restore(
        self,
        turns: Sequence[SpeakerTurn],
        *,
        stream_time: float | None = None,
    ) -> None:
        """Restore session state from prior streaming turns.

        Use during crash recovery to continue diarization timelines.

        Args:
            turns: Previously collected speaker turns.
            stream_time: Optional stream time to restore. If not provided,
                uses the max end time from turns.
        """
        if self._closed:
            return

        self._turns = list(turns)
        if stream_time is None:
            stream_time = max((t.end for t in turns), default=0.0)
        self._stream_time = max(self._stream_time, stream_time)
        logger.debug(
            "Session %s restored stream_time=%.3f with %d turns",
            self.meeting_id,
            self._stream_time,
            len(self._turns),
        )

    def close(self) -> None:
        """Close the session and release resources.

        After closing, the session cannot be used for further processing.
        """
        if self._closed:
            return

        self._closed = True
        self._turns.clear()
        logger.debug("Session %s closed", self.meeting_id)

    @property
    def stream_time(self) -> float:
        """Current stream time position in seconds."""
        return self._stream_time

    @property
    def turns(self) -> list[SpeakerTurn]:
        """All accumulated speaker turns for this session."""
        return list(self._turns)

    @property
    def is_closed(self) -> bool:
        """Check if session is closed."""
        return self._closed
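The core of `process_chunk` is the stream-clock bookkeeping: chunk-relative turn times become absolute by adding the time already consumed, then the clock advances by the chunk's duration. Reduced to pure Python:

```python
from dataclasses import dataclass

@dataclass
class Turn:
    speaker: str
    start: float
    end: float

class StreamClock:
    """Illustrative reduction of DiarizationSession's time handling."""

    def __init__(self) -> None:
        self.t = 0.0

    def absolutize(self, turns: list[Turn], chunk_seconds: float) -> list[Turn]:
        # Offset chunk-relative times by the stream position, then advance.
        out = [Turn(x.speaker, x.start + self.t, x.end + self.t) for x in turns]
        self.t += chunk_seconds
        return out

clock = StreamClock()
print(clock.absolutize([Turn("S0", 0.2, 0.9)], 1.0))  # start=0.2 (first chunk)
print(clock.absolutize([Turn("S0", 0.1, 0.4)], 1.0))  # start=1.1 (absolute)
```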

10
src/noteflow/infrastructure/persistence/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""In-memory persistence implementations.

Provides repository-pattern wrappers around the MeetingStore for use
when no database is configured. Implements the UnitOfWork protocol
for uniform access across database and memory backends.
"""

from .unit_of_work import MemoryUnitOfWork

__all__ = ["MemoryUnitOfWork"]
259
src/noteflow/infrastructure/persistence/memory/repositories.py
Normal file
@@ -0,0 +1,259 @@
"""Memory repository implementations wrapping MeetingStore.

Provides repository interfaces over the in-memory MeetingStore for
uniform access pattern with database implementations.
"""

from __future__ import annotations

from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING

from noteflow.domain.entities import Meeting, Segment, Summary
from noteflow.domain.value_objects import MeetingState

if TYPE_CHECKING:
    from noteflow.domain.entities import Annotation
    from noteflow.domain.value_objects import AnnotationId, MeetingId
    from noteflow.grpc.meeting_store import MeetingStore
    from noteflow.infrastructure.persistence.repositories import (
        DiarizationJob,
        StreamingTurn,
    )


class MemoryMeetingRepository:
    """Meeting repository backed by MeetingStore."""

    def __init__(self, store: MeetingStore) -> None:
        """Initialize with meeting store.

        Args:
            store: In-memory meeting store.
        """
        self._store = store

    async def create(self, meeting: Meeting) -> Meeting:
        """Persist a new meeting."""
        return self._store.insert(meeting)

    async def get(self, meeting_id: MeetingId) -> Meeting | None:
        """Retrieve a meeting by ID."""
        return self._store.get(str(meeting_id))

    async def update(self, meeting: Meeting) -> Meeting:
        """Update an existing meeting."""
        return self._store.update(meeting)

    async def delete(self, meeting_id: MeetingId) -> bool:
        """Delete a meeting and all associated data."""
        return self._store.delete(str(meeting_id))

    async def list_all(
        self,
        states: list[MeetingState] | None = None,
        limit: int = 100,
        offset: int = 0,
        sort_desc: bool = True,
    ) -> tuple[Sequence[Meeting], int]:
        """List meetings with optional filtering."""
        return self._store.list_all(states, limit, offset, sort_desc)

    async def count_by_state(self, state: MeetingState) -> int:
        """Count meetings in a specific state."""
        return self._store.count_by_state(state)

    async def find_older_than(self, cutoff: datetime) -> Sequence[Meeting]:
        """Find completed meetings older than cutoff date."""
        return self._store.find_older_than(cutoff)


class MemorySegmentRepository:
    """Segment repository backed by MeetingStore."""

    def __init__(self, store: MeetingStore) -> None:
        """Initialize with meeting store."""
        self._store = store

    async def add(self, meeting_id: MeetingId, segment: Segment) -> Segment:
        """Add a segment to a meeting."""
        self._store.add_segment(str(meeting_id), segment)
        return segment

    async def add_batch(
        self,
        meeting_id: MeetingId,
        segments: Sequence[Segment],
    ) -> Sequence[Segment]:
        """Add multiple segments to a meeting in batch."""
        for segment in segments:
            self._store.add_segment(str(meeting_id), segment)
        return segments

    async def get_by_meeting(
        self,
        meeting_id: MeetingId,
        include_words: bool = True,
    ) -> Sequence[Segment]:
        """Get all segments for a meeting."""
        return self._store.get_segments(str(meeting_id))

    async def search_semantic(
        self,
        query_embedding: list[float],
        limit: int = 10,
        meeting_id: MeetingId | None = None,
    ) -> Sequence[tuple[Segment, float]]:
        """Semantic search not supported in memory mode."""
        return []

    async def update_embedding(
        self,
        segment_db_id: int,
        embedding: list[float],
    ) -> None:
        """Embeddings not supported in memory mode."""

    async def update_speaker(
        self,
        segment_db_id: int,
        speaker_id: str,
        speaker_confidence: float,
    ) -> None:
        """Update speaker for segment - not applicable in memory mode.

        In memory mode, segments are updated directly on the entity.
        This method exists for interface compatibility.
        """

    async def get_next_segment_id(self, meeting_id: MeetingId) -> int:
        """Get next segment ID for a meeting."""
        return self._store.get_next_segment_id(str(meeting_id))


class MemorySummaryRepository:
    """Summary repository backed by MeetingStore."""

    def __init__(self, store: MeetingStore) -> None:
        """Initialize with meeting store."""
        self._store = store

    async def save(self, summary: Summary) -> Summary:
        """Save or update a meeting summary."""
        self._store.set_summary(str(summary.meeting_id), summary)
        return summary

    async def get_by_meeting(self, meeting_id: MeetingId) -> Summary | None:
        """Get summary for a meeting."""
        return self._store.get_summary(str(meeting_id))

    async def delete_by_meeting(self, meeting_id: MeetingId) -> bool:
        """Delete summary for a meeting."""
        return self._store.clear_summary(str(meeting_id))


class UnsupportedAnnotationRepository:
    """Annotation repository that raises for unsupported operations.

    Used in memory mode where annotations require database persistence.
    """

    async def add(self, annotation: Annotation) -> Annotation:
        """Not supported in memory mode."""
        raise NotImplementedError("Annotations require database persistence")

    async def get(self, annotation_id: AnnotationId) -> Annotation | None:
        """Not supported in memory mode."""
        raise NotImplementedError("Annotations require database persistence")

    async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[Annotation]:
        """Not supported in memory mode."""
        raise NotImplementedError("Annotations require database persistence")

    async def get_by_time_range(
        self,
        meeting_id: MeetingId,
        start_time: float,
        end_time: float,
    ) -> Sequence[Annotation]:
        """Not supported in memory mode."""
        raise NotImplementedError("Annotations require database persistence")

    async def update(self, annotation: Annotation) -> Annotation:
        """Not supported in memory mode."""
        raise NotImplementedError("Annotations require database persistence")

    async def delete(self, annotation_id: AnnotationId) -> bool:
        """Not supported in memory mode."""
        raise NotImplementedError("Annotations require database persistence")


class UnsupportedDiarizationJobRepository:
    """Diarization job repository that raises for unsupported operations.

    Used in memory mode where jobs require database persistence.
    """

    async def create(self, job: DiarizationJob) -> DiarizationJob:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")

    async def get(self, job_id: str) -> DiarizationJob | None:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")

    async def update_status(
        self,
        job_id: str,
        status: int,
        *,
        segments_updated: int | None = None,
        speaker_ids: list[str] | None = None,
        error_message: str | None = None,
    ) -> bool:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")

    async def prune_completed(self, ttl_seconds: float) -> int:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")

    async def add_streaming_turns(
        self,
        meeting_id: str,
        turns: Sequence[StreamingTurn],
    ) -> int:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")

    async def get_streaming_turns(self, meeting_id: str) -> list[StreamingTurn]:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")

    async def clear_streaming_turns(self, meeting_id: str) -> int:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")

    async def mark_running_as_failed(self, error_message: str = "Server restarted") -> int:
        """Not supported in memory mode."""
        raise NotImplementedError("Diarization jobs require database persistence")


class UnsupportedPreferencesRepository:
    """Preferences repository that raises for unsupported operations.

    Used in memory mode where preferences require database persistence.
    """

    async def get(self, key: str) -> object | None:
        """Not supported in memory mode."""
        raise NotImplementedError("Preferences require database persistence")

    async def set(self, key: str, value: object) -> None:
        """Not supported in memory mode."""
        raise NotImplementedError("Preferences require database persistence")

    async def delete(self, key: str) -> bool:
        """Not supported in memory mode."""
        raise NotImplementedError("Preferences require database persistence")
133
src/noteflow/infrastructure/persistence/memory/unit_of_work.py
Normal file
@@ -0,0 +1,133 @@
"""In-memory Unit of Work implementation.

Provides a UnitOfWork implementation backed by the MeetingStore for use
when no database is configured. Implements the same interface as
SqlAlchemyUnitOfWork for uniform access.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Self

from .repositories import (
    MemoryMeetingRepository,
    MemorySegmentRepository,
    MemorySummaryRepository,
    UnsupportedAnnotationRepository,
    UnsupportedDiarizationJobRepository,
    UnsupportedPreferencesRepository,
)

if TYPE_CHECKING:
    from noteflow.grpc.meeting_store import MeetingStore


class MemoryUnitOfWork:
    """In-memory Unit of Work backed by MeetingStore.

    Implements the same interface as SqlAlchemyUnitOfWork for uniform
    access across database and memory backends.

    Commit and rollback are no-ops since changes are applied directly
    to the in-memory store.

    Example:
        async with MemoryUnitOfWork(store) as uow:
            meeting = await uow.meetings.get(meeting_id)
            await uow.segments.add(meeting_id, segment)
            await uow.commit()  # No-op, changes already applied
    """

    def __init__(self, store: MeetingStore) -> None:
        """Initialize unit of work with meeting store.

        Args:
            store: In-memory meeting storage.
        """
        self._store = store
        self._meetings = MemoryMeetingRepository(store)
        self._segments = MemorySegmentRepository(store)
        self._summaries = MemorySummaryRepository(store)
        self._annotations = UnsupportedAnnotationRepository()
        self._diarization_jobs = UnsupportedDiarizationJobRepository()
        self._preferences = UnsupportedPreferencesRepository()

    # Core repositories
    @property
    def meetings(self) -> MemoryMeetingRepository:
        """Get meetings repository."""
        return self._meetings

    @property
    def segments(self) -> MemorySegmentRepository:
        """Get segments repository."""
        return self._segments

    @property
    def summaries(self) -> MemorySummaryRepository:
        """Get summaries repository."""
        return self._summaries

    # Optional repositories (unsupported in memory mode)
    @property
    def annotations(self) -> UnsupportedAnnotationRepository:
        """Get annotations repository (unsupported)."""
        return self._annotations

    @property
    def diarization_jobs(self) -> UnsupportedDiarizationJobRepository:
        """Get diarization jobs repository (unsupported)."""
        return self._diarization_jobs

    @property
    def preferences(self) -> UnsupportedPreferencesRepository:
        """Get preferences repository (unsupported)."""
        return self._preferences

    # Feature flags - limited in memory mode
    @property
    def supports_annotations(self) -> bool:
        """Annotations not supported in memory mode."""
        return False

    @property
    def supports_diarization_jobs(self) -> bool:
        """Diarization job persistence not supported in memory mode."""
        return False

    @property
    def supports_preferences(self) -> bool:
        """User preferences not supported in memory mode."""
        return False

    async def __aenter__(self) -> Self:
        """Enter the unit of work context.

        Returns:
            Self for use in async with statement.
        """
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Exit the unit of work context.

        No-op for memory implementation since changes are applied directly.
        """

    async def commit(self) -> None:
        """Commit the current transaction.

        No-op for memory implementation - changes are applied directly.
        """

    async def rollback(self) -> None:
        """Rollback the current transaction.

        Note: Memory implementation does not support rollback.
        Changes are applied directly and cannot be undone.
        """

@@ -19,6 +19,8 @@ from sqlalchemy import (
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

from noteflow.domain.utils.time import utc_now

# Vector dimension for embeddings (OpenAI compatible)
EMBEDDING_DIM = 1536

@@ -45,7 +47,7 @@ class MeetingModel(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        default=utc_now,
    )
    started_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
@@ -121,7 +123,7 @@ class SegmentModel(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        default=utc_now,
    )

    # Relationships
@@ -178,7 +180,7 @@ class SummaryModel(Base):
    generated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        default=utc_now,
    )
    model_version: Mapped[str | None] = mapped_column(String(50), nullable=True)

@@ -296,7 +298,7 @@ class AnnotationModel(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        default=utc_now,
    )

    # Relationships
@@ -322,8 +324,8 @@ class UserPreferencesModel(Base):
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        onupdate=datetime.now,
        default=utc_now,
        onupdate=utc_now,
    )


@@ -355,13 +357,13 @@ class DiarizationJobModel(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        default=utc_now,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        onupdate=datetime.now,
        default=utc_now,
        onupdate=utc_now,
    )


@@ -390,5 +392,5 @@ class StreamingDiarizationTurnModel(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.now,
        default=utc_now,
    )
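The diff imports `utc_now` from `noteflow.domain.utils.time` without showing its body; the conventional implementation (an assumption, but the one these timezone-aware columns need) is:

```python
from datetime import datetime, timezone

def utc_now() -> datetime:
    """Current time as a timezone-aware UTC datetime (assumed implementation)."""
    return datetime.now(timezone.utc)

# Unlike bare datetime.now() (naive local time), the result carries tzinfo,
# which is what DateTime(timezone=True) columns expect.
assert utc_now().tzinfo is timezone.utc
```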

@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

@@ -78,3 +78,63 @@ class BaseRepository:
        """
        await self._session.delete(model)
        await self._session.flush()

    async def _add_all_and_flush(self, models: list[TModel]) -> list[TModel]:
        """Add multiple models to session and flush once.

        Use this for batching inserts to reduce database round-trips.

        Args:
            models: List of ORM model instances to persist.

        Returns:
            The persisted models with generated fields populated.
        """
        self._session.add_all(models)
        await self._session.flush()
        return models

    async def _execute_count(self, stmt: Select[tuple[int]]) -> int:
        """Execute count query and return result.

        Args:
            stmt: SQLAlchemy select statement returning a count.

        Returns:
            Integer count value.
        """
        result = await self._session.execute(stmt)
        return result.scalar_one()

    async def _execute_exists(self, stmt: Select[Any]) -> bool:
        """Check if any rows match the query.

        More efficient than fetching all rows - stops at first match.

        Args:
            stmt: SQLAlchemy select statement.

        Returns:
            True if at least one row exists.
        """
        result = await self._session.execute(stmt.limit(1))
        return result.scalar() is not None

    async def _update_fields(
        self,
        model: TModel,
        **fields: Any,
    ) -> TModel:
        """Update model fields and flush.

        Args:
            model: ORM model instance to update.
            **fields: Field name/value pairs to set.

        Returns:
            The updated model.
        """
        for key, value in fields.items():
            setattr(model, key, value)
        await self._session.flush()
        return model
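A sketch of a subclass leaning on the new helpers; `BaseRepository` and `SegmentModel` refer to the classes in this commit (so the snippet is illustrative rather than standalone), while the query shapes are standard SQLAlchemy 2.x:

```python
from sqlalchemy import func, select

class ExampleSegmentQueries(BaseRepository):
    """Illustrative subclass; assumes the BaseRepository defined above."""

    async def count_for_meeting(self, meeting_id) -> int:
        stmt = (
            select(func.count())
            .select_from(SegmentModel)
            .where(SegmentModel.meeting_id == meeting_id)
        )
        return await self._execute_count(stmt)

    async def has_segments(self, meeting_id) -> bool:
        stmt = select(SegmentModel.id).where(SegmentModel.meeting_id == meeting_id)
        return await self._execute_exists(stmt)  # applies LIMIT 1 under the hood
```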

@@ -15,15 +15,15 @@ from noteflow.infrastructure.persistence.repositories._base import BaseRepositor
class SqlAlchemySegmentRepository(BaseRepository):
    """SQLAlchemy implementation of SegmentRepository."""

    async def add(self, meeting_id: MeetingId, segment: Segment) -> Segment:
        """Add a segment to a meeting.
    def _create_segment_model(self, meeting_id: MeetingId, segment: Segment) -> SegmentModel:
        """Create ORM model for a segment without persisting.

        Args:
            meeting_id: Meeting identifier.
            segment: Segment to add.
            segment: Domain segment.

        Returns:
            Added segment with db_id populated.
            SegmentModel ready for persistence.
        """
        model = SegmentModel(
            meeting_id=UUID(str(meeting_id)),
@@ -46,6 +46,20 @@ class SqlAlchemySegmentRepository(BaseRepository):
            word_model = WordTimingModel(**word_kwargs)
            model.words.append(word_model)

        return model

    async def add(self, meeting_id: MeetingId, segment: Segment) -> Segment:
        """Add a segment to a meeting.

        Args:
            meeting_id: Meeting identifier.
            segment: Segment to add.

        Returns:
            Added segment with db_id populated.
        """
        model = self._create_segment_model(meeting_id, segment)

        self._session.add(model)
        await self._session.flush()

@@ -61,6 +75,8 @@ class SqlAlchemySegmentRepository(BaseRepository):
    ) -> Sequence[Segment]:
        """Add multiple segments to a meeting in batch.

        Uses add_all() with a single flush for better performance.

        Args:
            meeting_id: Meeting identifier.
            segments: Segments to add.
@@ -68,13 +84,25 @@ class SqlAlchemySegmentRepository(BaseRepository):
        Returns:
            Added segments with db_ids populated.
        """
        result_segments: list[Segment] = []
        if not segments:
            return []

        # Build all models upfront
        models: list[SegmentModel] = []
        for segment in segments:
            added = await self.add(meeting_id, segment)
            result_segments.append(added)
            model = self._create_segment_model(meeting_id, segment)
            models.append(model)

        return result_segments
        # Add all models and flush once
        self._session.add_all(models)
        await self._session.flush()

        # Update segments with db_ids
        for segment, model in zip(segments, models, strict=True):
            segment.db_id = model.id
            segment.meeting_id = meeting_id

        return list(segments)
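Why the rewrite pays off: with the async session, each `flush()` is a round-trip that issues the pending INSERTs, so the old per-segment `add()` cost N flushes while the new path costs one. Both shapes side by side (sketch; `session` is any SQLAlchemy `AsyncSession`):

```python
from collections.abc import Sequence

async def add_batch_slow(session, models: Sequence) -> None:
    for model in models:
        session.add(model)
        await session.flush()  # N round-trips, one per segment

async def add_batch_fast(session, models: Sequence) -> None:
    session.add_all(models)
    await session.flush()      # 1 round-trip; generated ids populate the models
```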

    async def get_by_meeting(
        self,

@@ -126,6 +126,22 @@ class SqlAlchemyUnitOfWork:
            raise RuntimeError("UnitOfWork not in context")
        return self._summaries_repo

    # Feature flags - all True for database-backed implementation
    @property
    def supports_annotations(self) -> bool:
        """Annotations are fully supported with database."""
        return True

    @property
    def supports_diarization_jobs(self) -> bool:
        """Diarization job persistence is fully supported with database."""
        return True

    @property
    def supports_preferences(self) -> bool:
        """User preferences persistence is fully supported with database."""
        return True

    async def __aenter__(self) -> Self:
        """Enter the unit of work context.

@@ -104,7 +104,7 @@ class KeyringKeyStore:
        stored = keyring.get_password(self._service_name, self._key_name)
        if stored is not None:
            logger.debug("Retrieved existing master key from keyring")
            return base64.b64decode(stored)
            return _decode_and_validate_key(stored, "Keyring storage")

        # Generate new key
        new_key, encoded = _generate_key()
@@ -322,7 +322,7 @@ class TestTranscriptSearch:
    """Tests for transcript search functionality."""

    def test_search_filters_segments(self) -> None:
        """Search should filter visible segments."""
        """Search should filter visible segments via visibility toggle."""
        state = MockAppState()
        component = TranscriptComponent(state)
        component.build()
@@ -338,8 +338,9 @@ class TestTranscriptSearch:
        component._search_query = "world"
        component._rerender_all_segments()

        # Should only show segments containing "world"
        visible_count = sum(row is not None for row in component._segment_rows)
        # All rows exist, but only matching ones are visible
        assert len(component._segment_rows) == 3
        visible_count = sum(row.visible for row in component._segment_rows)
        assert visible_count == 2

    def test_search_is_case_insensitive(self) -> None:
@@ -356,7 +357,9 @@ class TestTranscriptSearch:
        component._search_query = "world"
        component._rerender_all_segments()

        visible_count = sum(row is not None for row in component._segment_rows)
        # All rows exist, but only matching ones are visible
        assert len(component._segment_rows) == 2
        visible_count = sum(row.visible for row in component._segment_rows)
        assert visible_count == 1
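The assertions pin down the new contract: search toggles `row.visible` instead of rebuilding the row list, so the row count stays stable while the visible subset shrinks. A stand-in showing that shape:

```python
from dataclasses import dataclass

@dataclass
class Row:
    text: str
    visible: bool = True

rows = [Row("hello world"), Row("goodbye"), Row("world again")]
query = "world"
for row in rows:
    row.visible = query in row.text.lower()

assert len(rows) == 3                          # rows are kept...
assert sum(row.visible for row in rows) == 2   # ...only matches stay visible
```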

@@ -9,6 +9,7 @@ import pytest
from noteflow.domain.entities.meeting import Meeting
from noteflow.domain.entities.segment import Segment
from noteflow.domain.entities.summary import Summary
from noteflow.domain.utils.time import utc_now
from noteflow.domain.value_objects import MeetingState


@@ -197,7 +198,7 @@ class TestMeetingProperties:
    def test_duration_seconds_in_progress(self) -> None:
        """Test duration is > 0 when started but not ended."""
        meeting = Meeting.create()
        meeting.started_at = datetime.now() - timedelta(seconds=5)
        meeting.started_at = utc_now() - timedelta(seconds=5)
        assert meeting.duration_seconds >= 5.0

    def test_is_active_created(self) -> None: