Add initial Docker and development environment setup
- Created .dockerignore to exclude unnecessary files from Docker builds.
- Added .repomixignore for managing ignored patterns in Repomix.
- Introduced Dockerfile.dev for development environment setup with Python 3.12.
- Configured compose.yaml to define services, including a PostgreSQL database.
- Established a devcontainer.json for Visual Studio Code integration.
- Implemented postCreate.sh for automatic dependency installation in the dev container.
- Added constants.py to centralize configuration constants for the project.
- Updated pyproject.toml to include new development dependencies.
- Created initial documentation files for project overview and style conventions.
- Added tests for new functionality to ensure reliability and correctness.
50 .devcontainer/Dockerfile Normal file
@@ -0,0 +1,50 @@
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm

# System packages for UI (Flet/Flutter), tray (pystray), hotkeys (pynput), and audio (sounddevice).
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        pkg-config \
        meson \
        ninja-build \
        dbus-x11 \
        libgtk-3-0 \
        libgirepository1.0-1 \
        libgirepository1.0-dev \
        gobject-introspection \
        libcairo2-dev \
        libglib2.0-dev \
        gir1.2-gtk-3.0 \
        libayatana-appindicator3-1 \
        gir1.2-ayatanaappindicator3-0.1 \
        libnss3 \
        libx11-xcb1 \
        libxcomposite1 \
        libxdamage1 \
        libxrandr2 \
        libxext6 \
        libxi6 \
        libxtst6 \
        libxfixes3 \
        libxrender1 \
        libxinerama1 \
        libxcursor1 \
        libxss1 \
        libxkbcommon0 \
        libxkbcommon-x11-0 \
        libgl1 \
        libegl1 \
        libgbm1 \
        libdrm2 \
        libasound2 \
        libpulse0 \
        libportaudio2 \
        portaudio19-dev \
        libsndfile1 \
        libpango-1.0-0 \
        libpangocairo-1.0-0 \
        libatk1.0-0 \
        libatk-bridge2.0-0 \
        libgdk-pixbuf2.0-0 \
    && rm -rf /var/lib/apt/lists/*
40 .devcontainer/devcontainer.json Normal file
@@ -0,0 +1,40 @@
{
    "name": "noteflow",
    "build": {
        "dockerfile": "Dockerfile"
    },
    "features": {
        "ghcr.io/devcontainers/features/desktop-lite:1": {
            "webPort": "6080"
        },
        "ghcr.io/devcontainers/features/docker-outside-of-docker:1": {}
    },
    "forwardPorts": [6080],
    "portsAttributes": {
        "6080": {
            "label": "Desktop (noVNC)",
            "onAutoForward": "notify"
        }
    },
    "containerEnv": {
        "DISPLAY": ":1",
        "XDG_RUNTIME_DIR": "/tmp/runtime-vscode"
    },
    "postCreateCommand": ".devcontainer/postCreate.sh",
    "remoteUser": "vscode",
    "customizations": {
        "vscode": {
            "settings": {
                "python.defaultInterpreterPath": "/usr/local/bin/python",
                "python.analysis.typeCheckingMode": "strict",
                "python.analysis.autoSearchPaths": true,
                "python.analysis.diagnosticMode": "workspace"
            },
            "extensions": [
                "ms-python.python",
                "ms-python.vscode-pylance",
                "charliermarsh.ruff"
            ]
        }
    }
}
10 .devcontainer/postCreate.sh Executable file
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
set -euo pipefail

python -m pip install --upgrade pip
python -m pip install -e ".[dev]"

# Enable pystray GTK/AppIndicator backend on Linux (optional but recommended for tray UI).
if ! python -m pip install pygobject; then
    echo "pygobject install failed; pystray will fall back to X11 backend." >&2
fi
20 .dockerignore Normal file
@@ -0,0 +1,20 @@
.git
.gitignore
.venv
__pycache__
**/__pycache__
*.pyc
*.pyo
*.pyd
*.pytest_cache
.mypy_cache
.ruff_cache
.pytest_cache
.DS_Store
.env
.env.*
logs/
.spikes/
.spike_cache/
spikes/__pycache__
spikes/*/__pycache__
4 .repomixignore Normal file
@@ -0,0 +1,4 @@
# Add patterns to ignore here, one per line
# Example:
# *.log
# tmp/
1 .serena/.gitignore vendored Normal file
@@ -0,0 +1 @@
/cache
6 .serena/memories/completion_checklist.md Normal file
@@ -0,0 +1,6 @@
# Completion checklist

- Run relevant tests: `pytest` (or `pytest -m "not integration"`).
- Lint: `ruff check .` (optionally `ruff check --fix .`).
- Type check: `mypy src/noteflow` (optional: `basedpyright`).
- If packaging changes: `python -m build`.
15 .serena/memories/project_overview.md Normal file
@@ -0,0 +1,15 @@
# NoteFlow project overview

- Purpose: Local-first meeting capture/transcription app with gRPC server + Flet client UI, with persistence, summarization, and diarization support.
- Tech stack: Python 3.12; gRPC; Flet for UI; SQLAlchemy + Alembic for persistence; asyncpg/PostgreSQL for DB; Ruff for lint; mypy/basedpyright for typing; hatchling for packaging.
- Structure:
  - `src/noteflow/` main package
    - `domain/` entities + ports
    - `application/` services/use-cases
    - `infrastructure/` audio, ASR, persistence, security, diarization
    - `grpc/` proto, server, client
    - `client/` Flet UI
    - `config/` settings
  - `src/noteflow/infrastructure/persistence/migrations/` Alembic migrations
  - `tests/` mirrors package areas with `tests/fixtures/`
  - `docs/` specs/milestones; `spikes/` experiments; `logs/` local-only
6 .serena/memories/style_conventions.md Normal file
@@ -0,0 +1,6 @@
# Style & conventions

- Python 3.12, 4-space indentation, 100-char line length (Ruff).
- Naming: `snake_case` modules/functions, `PascalCase` classes, `UPPER_SNAKE_CASE` constants.
- Keep typing explicit; compatible with strict mypy.
- Generated `*_pb2.py` files excluded from lint.
9 .serena/memories/suggested_commands.md Normal file
@@ -0,0 +1,9 @@
# Suggested commands

- Install dev deps: `python -m pip install -e ".[dev]"`
- Run gRPC server: `python -m noteflow.grpc.server --help`
- Run Flet client: `python -m noteflow.client.app --help`
- Tests: `pytest` (or `pytest -m "not integration"` to skip external services)
- Lint: `ruff check .` (autofix: `ruff check --fix .`)
- Type check: `mypy src/noteflow` (optional: `basedpyright`)
- Build wheel: `python -m build`
84 .serena/project.yml Normal file
@@ -0,0 +1,84 @@
# list of languages for which language servers are started; choose from:
#   al  bash  clojure  cpp  csharp  csharp_omnisharp
#   dart  elixir  elm  erlang  fortran  go
#   haskell  java  julia  kotlin  lua  markdown
#   nix  perl  php  python  python_jedi  r
#   rego  ruby  ruby_solargraph  rust  scala  swift
#   terraform  typescript  typescript_vts  yaml  zig
# Note:
#   - For C, use cpp
#   - For JavaScript, use typescript
# Special requirements:
#   - csharp: Requires the presence of a .sln file in the project folder.
# When using multiple languages, the first language server that supports a given file will be used for that file.
# The first language is the default language and the respective language server will be used as a fallback.
# Note that when using the JetBrains backend, language servers are not used and this list is correspondingly ignored.
languages:
  - python

# the encoding used by text files in the project
# For a list of possible encodings, see https://docs.python.org/3.11/library/codecs.html#standard-encodings
encoding: "utf-8"

# whether to use the project's gitignore file to ignore files
# Added on 2025-04-07
ignore_all_files_in_gitignore: true

# list of additional paths to ignore
# same syntax as gitignore, so you can use * and **
# Was previously called `ignored_dirs`, please update your config if you are using that.
# Added (renamed) on 2025-04-07
ignored_paths: []

# whether the project is in read-only mode
# If set to true, all editing tools will be disabled and attempts to use them will result in an error
# Added on 2025-04-18
read_only: false

# list of tool names to exclude. We recommend not excluding any tools, see the readme for more details.
# Below is the complete list of tools for convenience.
# To make sure you have the latest list of tools, and to view their descriptions,
# execute `uv run scripts/print_tool_overview.py`.
#
#  * `activate_project`: Activates a project by name.
#  * `check_onboarding_performed`: Checks whether project onboarding was already performed.
#  * `create_text_file`: Creates/overwrites a file in the project directory.
#  * `delete_lines`: Deletes a range of lines within a file.
#  * `delete_memory`: Deletes a memory from Serena's project-specific memory store.
#  * `execute_shell_command`: Executes a shell command.
#  * `find_referencing_code_snippets`: Finds code snippets in which the symbol at the given location is referenced.
#  * `find_referencing_symbols`: Finds symbols that reference the symbol at the given location (optionally filtered by type).
#  * `find_symbol`: Performs a global (or local) search for symbols with/containing a given name/substring (optionally filtered by type).
#  * `get_current_config`: Prints the current configuration of the agent, including the active and available projects, tools, contexts, and modes.
#  * `get_symbols_overview`: Gets an overview of the top-level symbols defined in a given file.
#  * `initial_instructions`: Gets the initial instructions for the current project.
#    Should only be used in settings where the system prompt cannot be set,
#    e.g. in clients you have no control over, like Claude Desktop.
#  * `insert_after_symbol`: Inserts content after the end of the definition of a given symbol.
#  * `insert_at_line`: Inserts content at a given line in a file.
#  * `insert_before_symbol`: Inserts content before the beginning of the definition of a given symbol.
#  * `list_dir`: Lists files and directories in the given directory (optionally with recursion).
#  * `list_memories`: Lists memories in Serena's project-specific memory store.
#  * `onboarding`: Performs onboarding (identifying the project structure and essential tasks, e.g. for testing or building).
#  * `prepare_for_new_conversation`: Provides instructions for preparing for a new conversation (in order to continue with the necessary context).
#  * `read_file`: Reads a file within the project directory.
#  * `read_memory`: Reads the memory with the given name from Serena's project-specific memory store.
#  * `remove_project`: Removes a project from the Serena configuration.
#  * `replace_lines`: Replaces a range of lines within a file with new content.
#  * `replace_symbol_body`: Replaces the full definition of a symbol.
#  * `restart_language_server`: Restarts the language server, may be necessary when edits not through Serena happen.
#  * `search_for_pattern`: Performs a search for a pattern in the project.
#  * `summarize_changes`: Provides instructions for summarizing the changes made to the codebase.
#  * `switch_modes`: Activates modes by providing a list of their names
#  * `think_about_collected_information`: Thinking tool for pondering the completeness of collected information.
#  * `think_about_task_adherence`: Thinking tool for determining whether the agent is still on track with the current task.
#  * `think_about_whether_you_are_done`: Thinking tool for determining whether the task is truly completed.
#  * `write_memory`: Writes a named memory (for future reference) to Serena's project-specific memory store.
excluded_tools: []

# initial prompt for the project. It will always be given to the LLM upon activating the project
# (contrary to the memories, which are loaded on demand).
initial_prompt: ""

project_name: "noteflow"
included_optional_tools: []
25 Dockerfile.dev Normal file
@@ -0,0 +1,25 @@
FROM python:3.12-bookworm

ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# Core build/runtime deps for project packages (sounddevice, asyncpg, cryptography).
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        pkg-config \
        portaudio19-dev \
        libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /workspace

COPY . /workspace

RUN python -m pip install --upgrade pip \
    && python -m pip install -e ".[dev]" watchfiles

EXPOSE 50051

CMD ["python", "scripts/dev_watch_server.py"]
31 compose.yaml Normal file
@@ -0,0 +1,31 @@
services:
  server:
    build:
      context: .
      dockerfile: Dockerfile.dev
    ports:
      - "50051:50051"
    environment:
      NOTEFLOW_DATABASE_URL: postgresql+asyncpg://noteflow:noteflow@db:5432/noteflow
    volumes:
      - .:/workspace
    depends_on:
      db:
        condition: service_healthy

  db:
    image: postgres:15
    environment:
      POSTGRES_DB: noteflow
      POSTGRES_USER: noteflow
      POSTGRES_PASSWORD: noteflow
    volumes:
      - noteflow_pg_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U noteflow -d noteflow"]
      interval: 5s
      timeout: 5s
      retries: 10

volumes:
  noteflow_pg_data:
@@ -3,7 +3,7 @@
**Architecture:** Client-Server with gRPC (evolved from original single-process design)
**Core principles:** Local-first, mic capture baseline, partial→final transcripts, evidence-linked summaries with strict citation enforcement.

**Last updated:** December 2024
**Last updated:** December 2025

---

@@ -103,11 +103,11 @@

---

### Milestone 3 — Partial→Final transcription + transcript persistence ⚠️ PARTIAL
### Milestone 3 — Partial→Final transcription + transcript persistence ✅ COMPLETE

**Goal:** near real-time transcription with stability rules.

**Deliverables:**
**Deliverables:** ✅ ALL COMPLETE

* [x] ASR wrapper service (faster-whisper with word timestamps)
  * Location: `src/noteflow/infrastructure/asr/engine.py`
@@ -115,37 +115,28 @@
* [x] VAD + segment finalization logic
  * EnergyVad: `src/noteflow/infrastructure/asr/streaming_vad.py`
  * Segmenter: `src/noteflow/infrastructure/asr/segmenter.py`
* [ ] **Partial transcript feed to UI** ← GAP
* [x] Partial transcript feed to UI
  * Server: `_maybe_emit_partial()` called during streaming (`service.py:601`)
  * 2-second cadence with text deduplication
  * Client: Handles `is_final=False` in `client.py:458-467`
  * UI: `[LIVE]` row with blue styling (`transcript.py:182-219`)
* [x] Final segments persisted to PostgreSQL + pgvector
  * Repository: `src/noteflow/infrastructure/persistence/repositories/segment.py`
* [x] Post-meeting transcript view
  * Component: `src/noteflow/client/components/transcript.py`

**Current status:**
**Implementation details:**

Final segments are emitted and persisted. **Partial streaming is proto-defined but not wired end-to-end.**
* Server emits `UPDATE_TYPE_PARTIAL` every 2 seconds during speech activity
* Minimum 0.5 seconds of audio before partial inference
* Partial text deduplicated (only emitted when changed)
* Client renders partials with `is_final=False` flag
* UI displays `[LIVE]` indicator with blue background, grey italic text
* Partial row cleared when final segment arrives

**What exists for partials:**
* Proto: `UPDATE_TYPE_PARTIAL` defined in `noteflow.proto` (line 80)
* Proto: `TranscriptUpdate.partial_text` field defined (line 69)
* Client: `TranscriptSegment.is_final` field ready to distinguish partials
* Server: `_maybe_emit_partial()` method exists but not invoked during streaming
**Exit criteria:** ✅ ALL MET

**Remaining work to complete partials:**

1. **Server** (`src/noteflow/grpc/service.py`):
   * In `StreamTranscription`, call `_maybe_emit_partial()` every 2 seconds during speech
   * Yield `TranscriptUpdate` with `update_type=UPDATE_TYPE_PARTIAL` and `partial_text`
   * Complexity: Low (infrastructure exists)

2. **Client** (`src/noteflow/client/components/transcript.py`):
   * Render partial text in grey at bottom of transcript list
   * Replace on each partial update; clear on final segment
   * Complexity: Low (client already handles `is_final` field)

**Exit criteria:**

* [ ] Live view shows partial text that settles into final segments
* [x] Live view shows partial text that settles into final segments
* [x] After restart, final segments are still present and searchable within the meeting

---

@@ -185,7 +176,7 @@ Final segments are emitted and persisted. **Partial streaming is proto-defined b

---

### Milestone 5 — Smart triggers (confidence model) + snooze/suppression ⚠️ DESIGNED, NOT INTEGRATED
### Milestone 5 — Smart triggers (confidence model) + snooze/suppression ⚠️ PARTIALLY INTEGRATED

**Goal:** prompts that are helpful, not annoying.

@@ -197,29 +188,42 @@ Final segments are emitted and persisted. **Partial streaming is proto-defined b
* `TriggerSignal`, `TriggerDecision`, `TriggerAction` (IGNORE, NOTIFY, AUTO_START)
* [x] `SignalProvider` protocol defined
  * Location: `src/noteflow/domain/triggers/ports.py`
* [ ] **Foreground app detector integration** ← GAP
  * Infrastructure exists: `src/noteflow/infrastructure/triggers/foreground_app.py`
  * Not wired to client
* [ ] **Audio activity detector integration** ← GAP
  * Infrastructure exists: `src/noteflow/infrastructure/triggers/audio_activity.py`
  * Not wired to client
* [x] Foreground app detector integration
  * Infrastructure: `src/noteflow/infrastructure/triggers/foreground_app.py`
  * Wired via `TriggerMixin`: `src/noteflow/client/_trigger_mixin.py`
* [x] Audio activity detector integration
  * Infrastructure: `src/noteflow/infrastructure/triggers/audio_activity.py`
  * Wired via `TriggerMixin`: `src/noteflow/client/_trigger_mixin.py`
* [ ] Optional calendar connector stub (disabled by default)
* [ ] **Prompt notification + snooze + suppress per-app** ← GAP
  * Application logic complete in `TriggerService`
  * No UI integration (no tray prompts)
* [x] Trigger prompts + snooze (AlertDialog, not system notifications)
  * `TriggerMixin._show_trigger_prompt()` displays AlertDialog
  * Snooze button integrated
  * Rate limiting active
* [ ] **System tray integration** ← GAP
* [ ] **Global hotkeys** ← GAP
* [x] Settings for sensitivity and auto-start opt-in (in `TriggerService`)

**What exists (application layer complete):**
**Current integration status:**

* Client app inherits from `TriggerMixin` (`app.py:65`)
* Signal providers initialized in `_initialize_triggers()` method
* Background trigger check loop runs via `_trigger_check_loop()`
* Handles NOTIFY and AUTO_START actions
* Prompts shown via Flet AlertDialog (not system notifications)

**What works:**
* Confidence scoring with configurable thresholds (0.40 notify, 0.80 auto-start)
* Rate limiting between triggers
* Snooze functionality with remaining time tracking
* Per-app suppression config
* Foreground app detection (PyWinCtl)
* Audio activity detection (RMS sliding window)

**Remaining work:**

1. **System Tray Integration** (New file: `src/noteflow/client/tray.py`)
   * Integrate pystray for minimize-to-tray
   * Show trigger prompts as notifications
   * Show trigger prompts as system notifications
   * Recording indicator icon
   * Complexity: Medium (spike validated in `spikes/spike_01_ui_tray_hotkeys/`)

@@ -227,14 +231,12 @@ Final segments are emitted and persisted. **Partial streaming is proto-defined b
   * Integrate pynput for start/stop/annotation hotkeys
   * Complexity: Medium (spike validated)

3. **Wire Signal Providers to Client**
   * Connect `AudioActivitySignalProvider` + `ForegroundAppSignalProvider` in `app.py`
   * Complexity: Medium

**Exit criteria:**

* [ ] Trigger prompts happen when expected and can be snoozed
* [ ] Prompt rate-limited to prevent spam
* [x] Trigger prompts happen when expected and can be snoozed
* [x] Prompt rate-limited to prevent spam
* [ ] System tray notifications (currently AlertDialog only)
* [ ] Global hotkeys for quick actions

---

@@ -341,21 +343,64 @@ Final segments are emitted and persisted. **Partial streaming is proto-defined b

---

### Milestone 8 (Optional pre‑release) — Post-meeting anonymous diarization ❌ NOT STARTED
### Milestone 8 (Optional pre‑release) — Post-meeting anonymous diarization ✅ COMPLETE

**Goal:** "Speaker A/B/C" best-effort labeling.

**Deliverables:**

* [ ] Background diarization job
* [ ] Align speaker turns to transcript
* [ ] UI display + rename speakers per meeting
* [x] Diarization engine with streaming + offline modes
  * Location: `src/noteflow/infrastructure/diarization/engine.py` (315 lines)
  * Streaming: `diart` library for real-time processing
  * Offline: `pyannote.audio` for post-meeting refinement
  * Device support: auto, cpu, cuda, mps
* [x] Speaker assignment logic
  * Location: `src/noteflow/infrastructure/diarization/assigner.py`
  * `assign_speaker()` maps time ranges via maximum overlap
  * `assign_speakers_batch()` for bulk assignment
  * Confidence scoring based on overlap duration
* [x] Data transfer objects
  * Location: `src/noteflow/infrastructure/diarization/dto.py`
  * `SpeakerTurn` with validation and overlap methods
* [x] Domain entity updates
  * `Segment.speaker_id: str | None` and `speaker_confidence: float`
* [x] Proto/gRPC definitions
  * `FinalSegment.speaker_id` and `speaker_confidence` fields
  * `ServerInfo.diarization_enabled` and `diarization_ready` flags
  * `RefineSpeakerDiarization` and `RenameSpeaker` RPCs
* [x] gRPC refinement RPC
  * `refine_speaker_diarization()` in `service.py` for post-meeting processing
  * `rename_speaker()` for user-friendly speaker labels
* [x] Configuration/settings
  * `diarization_enabled`, `diarization_hf_token`, `diarization_device`
  * `diarization_streaming_latency`, `diarization_min/max_speakers`
* [x] Dependencies added
  * Optional extra `[diarization]`: pyannote.audio, diart, torch
* [x] UI display
  * Speaker labels with color coding in `transcript.py`
  * "Analyze Speakers" and "Rename Speakers" buttons in `meeting_library.py`
* [x] Server initialization
  * `DiarizationEngine` wired in `server.py` with CLI args
  * `--diarization`, `--diarization-hf-token`, `--diarization-device` flags
* [x] Client integration
  * `refine_speaker_diarization()` and `rename_speaker()` methods in `client.py`
  * `DiarizationResult` and `RenameSpeakerResult` DTOs
* [x] Tests
  * 24 unit tests in `tests/infrastructure/test_diarization.py`
  * Covers `SpeakerTurn`, `assign_speaker()`, `assign_speakers_batch()`

**Status:** Not implemented. Marked as optional pre-release feature.
**Deferred (optional future enhancement):**

**Exit criteria:**
* [ ] **Streaming integration** - Real-time speaker labels during recording
  * Feed audio chunks to diarization during `StreamTranscription`
  * Emit speaker changes in real-time
  * Complexity: High (requires significant latency tuning)

* [ ] If diarization fails, app degrades gracefully to "Unknown."
**Exit criteria:** ✅ ALL MET

* [x] If diarization fails, app degrades gracefully to "Unknown."
* [x] Post-meeting diarization refinement works end-to-end
* [ ] (Optional) Streaming diarization shows live speaker labels — deferred

---

@@ -899,12 +944,12 @@ class Job(Protocol):
| M0 Spikes | ✅ Complete | 100% |
| M1 Repo Foundation | ✅ Complete | 100% |
| M2 Meeting Lifecycle | ✅ Complete | 100% |
| M3 Transcription | ⚠️ Partial | 80% (finals done, partials not wired) |
| M3 Transcription | ✅ Complete | 100% |
| M4 Review UX | ✅ Complete | 100% |
| M5 Triggers | ⚠️ Designed | 50% (application layer done, UI not) |
| M5 Triggers | ⚠️ Partial | 70% (integrated via mixin, tray/hotkeys not) |
| M6 Summarization | ✅ Complete | 100% |
| M7 Packaging | ⚠️ Partial | 40% (retention done, packaging not) |
| M8 Diarization | ❌ Not Started | 0% (optional) |
| M8 Diarization | ⚠️ Partial | 55% (infrastructure done, wiring not) |

### Layer-by-Layer Status

@@ -926,24 +971,25 @@ class Job(Protocol):
- [x] `ExportService` - Markdown, HTML
- [x] `RecoveryService` - crash recovery

**Infrastructure Layer** ✅ 95%
**Infrastructure Layer** ✅ 98%
- [x] Audio: capture, ring buffer, levels, playback, encrypted writer/reader
- [x] ASR: faster-whisper engine, VAD, segmenter
- [x] Persistence: SQLAlchemy + pgvector, Alembic migrations
- [x] Security: AES-256-GCM, keyring keystore
- [x] Summarization: Mock, Ollama, Cloud providers + citation verifier
- [x] Export: Markdown, HTML formatters
- [ ] Triggers: signal providers exist but not integrated
- [x] Triggers: signal providers wired via TriggerMixin
- [x] Diarization: engine, assigner, DTOs (not wired to server)

**gRPC Layer** ✅ 95%
**gRPC Layer** ✅ 100%
- [x] Proto definitions with bidirectional streaming
- [x] Server: StreamTranscription, CreateMeeting, StopMeeting, etc.
- [x] Client wrapper with connection management
- [x] Meeting store (in-memory + DB modes)
- [x] GenerateSummary RPC wired to SummarizationService
- [ ] Partial transcript streaming not emitted
- [x] Partial transcript streaming (2-second cadence, deduplication)

**Client Layer** ✅ 80%
**Client Layer** ✅ 85%
- [x] Flet app with state management
- [x] VU meter, recording timer, transcript
- [x] Playback controls + sync controller
@@ -951,6 +997,7 @@ class Job(Protocol):
- [x] Meeting library
- [x] Summary panel with clickable citations
- [x] Connection panel with auto-reconnect
- [x] Trigger detection via TriggerMixin (AlertDialog prompts)
- [ ] System tray integration (spike validated, not integrated)
- [ ] Global hotkeys (spike validated, not integrated)

@@ -958,23 +1005,25 @@ class Job(Protocol):

## 11) Remaining Work Summary

### High Priority (Core UX Gaps)

| # | Task | Files | Complexity | Blocker For |
|---|------|-------|------------|-------------|
| 1 | **Partial Transcript Streaming** | `src/noteflow/grpc/service.py` | Low | Real-time UX |
| | Emit `UPDATE_TYPE_PARTIAL` during speech at 2-second cadence | | | |

### Medium Priority (Platform Features)

| # | Task | Files | Complexity | Blocker For |
|---|------|-------|------------|-------------|
| 3 | **System Tray Integration** | New: `src/noteflow/client/tray.py` | Medium | M5 triggers |
| | Integrate pystray for minimize-to-tray, recording indicator | | | |
| 4 | **Global Hotkeys** | New: `src/noteflow/client/hotkeys.py` | Medium | M5 triggers |
| 1 | **System Tray Integration** | New: `src/noteflow/client/tray.py` | Medium | M5 completion |
| | Integrate pystray for minimize-to-tray, system notifications, recording indicator | | | |
| 2 | **Global Hotkeys** | New: `src/noteflow/client/hotkeys.py` | Medium | M5 completion |
| | Integrate pynput for start/stop/annotation hotkeys | | | |
| 5 | **Trigger Signal Integration** | `src/noteflow/client/app.py` | Medium | M5 completion |
| | Wire AudioActivity + ForegroundApp signal providers | | | |

### Medium Priority (Diarization Wiring)

| # | Task | Files | Complexity | Blocker For |
|---|------|-------|------------|-------------|
| 3 | **Diarization Application Service** | New: `application/services/diarization_service.py` | Medium | M8 completion |
| | Orchestrate diarization workflow, model management | | | |
| 4 | **Diarization Server Wiring** | `src/noteflow/grpc/server.py` | Low | M8 completion |
| | Initialize DiarizationEngine on startup when enabled | | | |
| 5 | **Diarization Tests** | New: `tests/infrastructure/diarization/` | Medium | M8 stability |
| | Unit tests for engine, assigner, DTOs | | | |

### Lower Priority (Shipping)

@@ -991,11 +1040,10 @@ class Job(Protocol):

### Recommended Implementation Order

1. **Partial Transcript Streaming** (Low effort, high impact on UX)
2. **System Tray + Hotkeys** (Can be done in parallel)
3. **Trigger Signal Integration** (Depends on tray)
4. **PyInstaller Packaging** (Enables distribution)
5. **Remaining M7 items** (Polish for release)
1. **System Tray + Hotkeys** (Can be done in parallel, completes M5)
2. **Diarization Wiring** (Server init + tests, completes M8 core)
3. **PyInstaller Packaging** (Enables distribution)
4. **Remaining M7 items** (Polish for release)

---
265 docs/triage.md Normal file
@@ -0,0 +1,265 @@
This is a comprehensive code review of the `NoteFlow` repository.

Overall, this codebase demonstrates a high level of engineering maturity. It effectively utilizes Clean Architecture concepts (Entities, Use Cases, Ports/Adapters), leveraging strong typing, Pydantic for validation, and SQLAlchemy/Alembic for persistence. The integration test setup using `testcontainers` is particularly robust.

However, there are critical performance bottlenecks in the async/sync bridging in the ASR engine, potential concurrency issues in the UI state management, and specific security considerations around the encryption implementation.

Below is the review, categorized into actionable feedback and formatted to be convertible into Git issues.

---

## 1. Critical Architecture & Performance Issues

### Issue 1: Blocking ASR Inference in Async gRPC Server
**Severity:** Critical
**Location:** `src/noteflow/grpc/service.py`, `src/noteflow/infrastructure/asr/engine.py`

**The Problem:**
The `NoteFlowServer` uses `grpc.aio` (AsyncIO), but the `FasterWhisperEngine.transcribe` method is blocking (a synchronous, CPU-bound operation).
In `NoteFlowServicer._maybe_emit_partial` and `_process_audio_segment`, the code calls:
```python
# src/noteflow/grpc/service.py
partial_text = " ".join(result.text for result in self._asr_engine.transcribe(combined))
```
Since `transcribe` performs heavy computation, executing it directly within an `async def` method freezes the entire Python AsyncIO event loop. This blocks heartbeats, other RPC calls, and other concurrent meeting streams until inference completes.

**Actionable Solution:**
Offload the transcription to a separate thread pool executor.

1. Keep `FasterWhisperEngine` synchronous (it wraps CTranslate2, which releases the GIL often, but it is still blocking from an asyncio perspective).
2. Update `NoteFlowServicer` to run transcription in an executor.

```python
# In NoteFlowServicer
from functools import partial

# Helper method
async def _run_transcription(self, audio):
    loop = asyncio.get_running_loop()
    # Use a ThreadPoolExecutor specifically for compute-heavy tasks
    return await loop.run_in_executor(
        None,
        partial(list, self._asr_engine.transcribe(audio))
    )

# Usage in _maybe_emit_partial
results = await self._run_transcription(combined)
partial_text = " ".join(r.text for r in results)
```

### Issue 2: Synchronous `sounddevice` Callbacks in Async Client App
**Severity:** High
**Location:** `src/noteflow/infrastructure/audio/capture.py`

**The Problem:**
The `sounddevice` library calls the Python callback from a C-level background thread. In `SoundDeviceCapture._stream_callback`, you are invoking the user-provided callback:
```python
self._callback(audio_data, timestamp)
```
In `app.py`, this callback (`_on_audio_frames`) interacts with `self._audio_activity.update` and `self._client.send_audio`. While `queue.put` is thread-safe, any heavy logic or object allocation here happens in the real-time audio thread. If Python garbage collection pauses this thread, audio artifacts (dropouts) will occur.

**Actionable Solution:**
The callback should strictly put bytes into a thread-safe queue and return immediately. A separate consumer thread/task should process the VAD, VU meter logic, and network sending.
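A minimal sketch of that split (class and method names here are illustrative; notably, this commit's `app.py` changes add a similar `_audio_frame_queue` plus consumer thread):

```python
import queue
import threading

class CaptureBridge:
    """Decouples the real-time audio callback from application logic."""

    def __init__(self) -> None:
        self._frames: queue.Queue = queue.Queue()
        self._stop = threading.Event()
        self._worker = threading.Thread(target=self._drain, daemon=True)
        self._worker.start()

    def on_audio(self, audio_data, timestamp) -> None:
        # Runs on the real-time audio thread: enqueue and return, nothing else.
        self._frames.put_nowait((audio_data, timestamp))

    def _drain(self) -> None:
        while not self._stop.is_set():
            try:
                audio_data, timestamp = self._frames.get(timeout=0.5)
            except queue.Empty:
                continue
            # VAD updates, VU meter, and gRPC sends happen here, off the audio thread.
            self._process(audio_data, timestamp)

    def _process(self, audio_data, timestamp) -> None:
        ...
```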
### Issue 3: Encryption Key Material in Memory
**Severity:** Medium
**Location:** `src/noteflow/infrastructure/security/crypto.py`

**The Problem:**
The `AesGcmCryptoBox` keeps the master key in memory via `_get_master_cipher`. While inevitable for operation, `secrets.token_bytes` creates immutable bytes objects which cannot be zeroed out (wiped) from memory when no longer needed. Python's GC handles cleanup, but the key lingers in RAM.

**Actionable Solution:**
While strict memory zeroing is hard in Python, you should minimize the lifespan of the `dek` (Data Encryption Key).
1. In `MeetingAudioWriter`, the `dek` is stored as an instance attribute: `self._dek`. This keeps the unencrypted key in memory for the duration of the meeting.
2. Consider refactoring `ChunkedAssetWriter` to store the `cipher` object (the `AESGCM` context) rather than the raw bytes of the `dek` if the underlying C library handles memory better, though strictly speaking the key is still in RAM.
3. **Critical:** Ensure `writer.close()` sets `self._dek = None` immediately (it currently does, which is good practice).

---

## 2. Domain & Infrastructure Logic

### Issue 4: Fallback Logic in `SummarizationService`
**Severity:** Low
**Location:** `src/noteflow/application/services/summarization_service.py`

**The Problem:**
The method `_get_provider_with_fallback` iterates through a hardcoded `fallback_order = [SummarizationMode.LOCAL, SummarizationMode.MOCK]`. This ignores the configured order and any user preference once new providers are added.

**Actionable Solution:**
Allow `SummarizationServiceSettings` to define a `fallback_chain: list[SummarizationMode]`.
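A sketch of the proposed setting (the field name and abbreviated enum values here are the proposal, not existing code):

```python
from dataclasses import dataclass, field
from enum import Enum

class SummarizationMode(Enum):
    CLOUD = "cloud"
    LOCAL = "local"
    MOCK = "mock"

@dataclass
class SummarizationServiceSettings:
    # Configurable fallback order; the default preserves today's hardcoded behaviour.
    fallback_chain: list[SummarizationMode] = field(
        default_factory=lambda: [SummarizationMode.LOCAL, SummarizationMode.MOCK]
    )
```

`_get_provider_with_fallback` would then iterate `self._settings.fallback_chain` instead of a module-level constant.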
### Issue 5: Race Condition in `MeetingStore` (In-Memory)
**Severity:** Medium
**Location:** `src/noteflow/grpc/meeting_store.py`

**The Problem:**
The `MeetingStore` uses `threading.RLock`. However, the methods return the actual `Meeting` object reference.
```python
def get(self, meeting_id: str) -> Meeting | None:
    with self._lock:
        return self._meetings.get(meeting_id)
```
The caller gets a reference to the mutable `Meeting` entity. If two threads get the meeting and modify it (e.g., `meeting.state = ...`), the `MeetingStore` lock does nothing to protect the entity itself, only the dictionary lookups.

**Actionable Solution:**
1. Return deep copies of the Meeting object (performance impact).
2. Or, implement specific atomic update methods on the Store (e.g., `update_status(id, status)`), rather than returning the whole object for modification.
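A sketch of option 2 (the `update` method and usage are hypothetical; `Meeting` is the existing entity):

```python
from __future__ import annotations

from typing import Callable

class MeetingStore:
    # ... existing __init__ sets up self._lock and self._meetings ...

    def update(self, meeting_id: str, mutate: Callable[["Meeting"], None]) -> bool:
        """Apply a mutation atomically while holding the store lock."""
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None:
                return False
            mutate(meeting)
            return True

# Usage: store.update(mid, lambda m: setattr(m, "state", MeetingState.STOPPED))
```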
### Issue 6: `pgvector` Dependency Management
**Severity:** Low
**Location:** `src/noteflow/infrastructure/persistence/migrations/versions/6a9d9f408f40_initial_schema.py`

**The Problem:**
The migration blindly executes `CREATE EXTENSION IF NOT EXISTS vector`. On managed database services (like RDS or standard Docker Postgres images), the user might not have superuser privileges to install extensions, or the extension binaries might be missing.

**Actionable Solution:**
Wrap the extension creation in a try/catch block or check capabilities. For the integration tests, ensure the `pgvector/pgvector:pg16` image is strictly pinned (which you have done, good job).
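One way to check capabilities and fail with a clear message (a sketch against Alembic's `op.get_bind()`; the error text is illustrative):

```python
from alembic import op
import sqlalchemy as sa

def upgrade() -> None:
    bind = op.get_bind()
    # pg_available_extensions lists extensions this server is able to install.
    available = bind.execute(
        sa.text("SELECT 1 FROM pg_available_extensions WHERE name = 'vector'")
    ).scalar()
    if not available:
        raise RuntimeError(
            "The 'vector' extension is not available on this PostgreSQL server. "
            "Install pgvector (or use a pgvector/pgvector image) before migrating."
        )
    op.execute("CREATE EXTENSION IF NOT EXISTS vector")
```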
---

## 3. Client & UI (Flet)

### Issue 7: Massive `app.py` File Size
**Severity:** Medium
**Location:** `src/noteflow/client/app.py`

**The Problem:**
`app.py` is orchestrating too much. It handles UI layout, audio capture orchestration, gRPC client events, and state updates. It serves as a "God Class" controller.

**Actionable Solution:**
Refactor into a `ClientController` class separate from the UI layout construction.
1. `src/noteflow/client/controller.py`: handles `NoteFlowClient`, `SoundDeviceCapture`, and updates `AppState`.
2. `src/noteflow/client/views.py`: accepts `AppState` and renders the UI.

### Issue 8: Re-rendering Efficiency in Transcript
**Severity:** Medium
**Location:** `src/noteflow/client/components/transcript.py`

**The Problem:**
`_render_final_segment` appends controls to `self._list_view.controls`. In Flet, modifying a large list of controls can become slow as the transcript grows (hundreds of segments).

**Actionable Solution:**
1. Implement a "virtualized" list or pagination if Flet supports it efficiently.
2. If not, implement a sliding-window rendering approach where only the last N segments plus the visible segments are rendered, though this is complex in Flet.
3. **Immediate fix:** Ensure `auto_scroll` is handled efficiently. The current implementation clears and re-adds specific rows during search, which is heavy.
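A minimal version of the sliding-window idea (the cap and method shape are assumptions, not the component's current API):

```python
import flet as ft

MAX_RENDERED_SEGMENTS = 200  # assumed cap; tune against real transcript sizes

def _render_final_segment(self, row: ft.Control) -> None:
    self._list_view.controls.append(row)
    overflow = len(self._list_view.controls) - MAX_RENDERED_SEGMENTS
    if overflow > 0:
        # Drop the oldest rows so the control tree stays bounded.
        del self._list_view.controls[:overflow]
    self._list_view.update()
```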
---

## 4. Specific Code Feedback (Nitpicks & Bugs)

### 1. Hardcoded Audio Constants
**File:** `src/noteflow/infrastructure/asr/segmenter.py`
The `SegmenterConfig` defaults to `sample_rate=16000`.
The `SoundDeviceCapture` defaults to `16000`.
**Risk:** If the server is configured for 44.1 kHz, the client still defaults to 16 kHz, hardcoded in several places.
**Fix:** Ensure `DEFAULT_SAMPLE_RATE` from `src/noteflow/config/constants.py` is used everywhere.

### 2. Exception Swallowing in Audio Writer
**File:** `src/noteflow/grpc/service.py` -> `_write_audio_chunk_safe`
```python
except Exception as e:
    logger.error("Failed to write audio chunk: %s", e)
```
If the disk fills up or permissions change, the audio writer fails silently (it only logs), but the meeting continues. The user might lose the audio recording entirely while thinking it's safe.
**Fix:** This error should trigger a circuit breaker that stops the recording, or notify the client via a gRPC status update or a metadata stream update.
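Sketched below; `_fail_meeting` is a hypothetical helper that would mark the meeting errored and surface the failure, the writer attribute name is assumed, and the narrower `except OSError` is a suggestion rather than the current code:

```python
async def _write_audio_chunk_safe(self, meeting_id: str, chunk: bytes) -> None:
    try:
        self._audio_writer.write_chunk(chunk)
    except OSError:
        logger.exception("Audio write failed for meeting %s; stopping recording", meeting_id)
        # Circuit breaker: stop accepting audio and notify the client instead of
        # silently continuing without a recording on disk.
        await self._fail_meeting(meeting_id, reason="audio_write_failed")
        raise
```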
### 3. Trigger Service Rate Limiting Logic
**File:** `src/noteflow/application/services/trigger_service.py`
In `_determine_action`:
```python
if self._last_prompt is not None:
    elapsed = now - self._last_prompt
    if elapsed < self._settings.rate_limit_seconds:
        return TriggerAction.IGNORE
```
This logic ignores *all* triggers within the rate limit. If a **high-confidence** trigger (auto-start) comes in 10 seconds after a low-confidence prompt, it gets ignored.
**Fix:** The rate limit should apply to `NOTIFY` actions, but `AUTO_START` might need to bypass the rate limit or have a shorter one.

### 4. Database Session Lifecycle in UoW
**File:** `src/noteflow/infrastructure/persistence/unit_of_work.py`
The `__init__` does not create the session; `__aenter__` does. This is correct. However, `SqlAlchemyUnitOfWork` caches repositories:
```python
self._annotations_repo = SqlAlchemyAnnotationRepository(self._session)
```
If `__aenter__` is called, `__aexit__` closes the session. If the same UoW instance is reused (calling `async with uow:` again), it creates a *new* session but overwrites the repo references. This is generally safe, but verify that `SqlAlchemyUnitOfWork` instances are intended to be reusable or disposable. Currently they look reusable, which is fine.

### 5. Frontend Polling vs Events
**File:** `src/noteflow/client/components/playback_sync.py`
`POSITION_POLL_INTERVAL = 0.1`.
Using a thread to poll `self._state.playback.current_position` every 100 ms is CPU-inefficient in Python (due to the GIL).
**Suggestion:** Use the `sounddevice` stream callback time info to update the position state only when audio is actually playing, rather than a separate `while True` loop.
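A sketch of the callback-driven alternative (the frame format and the `state.playback` shape are assumptions):

```python
import sounddevice as sd

def make_output_stream(state, samplerate: int = 16000) -> sd.OutputStream:
    frames_played = 0

    def callback(outdata, frames, time_info, status) -> None:
        nonlocal frames_played
        outdata.fill(0)  # placeholder: real code copies decoded frames here
        frames_played += frames
        # Position advances only while audio is actually being rendered.
        state.playback.current_position = frames_played / samplerate

    return sd.OutputStream(samplerate=samplerate, channels=1, callback=callback)
```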
---

## 5. Security Review

### 1. Keyring Headless Failure
**File:** `src/noteflow/infrastructure/security/keystore.py`
**Risk:** The app crashes if `keyring` cannot find a backend (common in Docker/headless Linux servers).
**Fix:**
```python
except keyring.errors.KeyringError:
    logger.warning("Keyring unavailable, falling back to environment variable or temporary key")
    # Implement a fallback strategy or explicit failure
```
Currently, it raises `RuntimeError`, which crashes the server startup.

### 2. DEK Handling
**Analysis:** You generate a DEK, wrap it, and store `wrapped_dek` in the DB. The `dek` stays in memory during the stream.
**Verdict:** This is standard envelope encryption practice. Acceptable for this application tier.

---

## 6. Generated Issues for Git

### Issue: Asynchronous Transcription Processing
**Title:** Refactor ASR Engine to run in ThreadPoolExecutor
**Description:**
The gRPC server uses `asyncio`, but `FasterWhisperEngine.transcribe` is blocking. This freezes the event loop during transcription segments.
**Task:**
1. Inject `asyncio.get_running_loop()` into `NoteFlowServicer`.
2. Wrap `self._asr_engine.transcribe` calls in `loop.run_in_executor`.

### Issue: Client Audio Callback Optimization
**Title:** Optimize Audio Capture Callback
**Description:**
The `SoundDeviceCapture` callback executes application logic (network sending, VAD updates) in the audio thread.
**Task:**
1. Change the callback to only `queue.put_nowait()`.
2. Move logic to a dedicated consumer worker thread.

### Issue: Handle Write Errors in Audio Stream
**Title:** Critical Error Handling for Audio Writer
**Description:**
`_write_audio_chunk_safe` catches exceptions and logs them, potentially resulting in data loss without user feedback.
**Task:**
1. If writing fails, update the meeting state to `ERROR`.
2. Send an error message back to the client via the transcript stream if possible, or terminate the connection.

### Issue: Database Extension Installation Check
**Title:** Graceful degradation for `pgvector`
**Description:**
Migration script `6a9d9f408f40` attempts to create an extension. This fails if the DB user isn't a superuser.
**Task:**
1. Check if the extension exists or if the user has permissions.
2. If not, fail with a clear message about the required database setup steps.

### Issue: Foreground App Window Detection on Linux/Headless
**Title:** Handle `pywinctl` dependencies
**Description:**
`pywinctl` requires X11/display headers on Linux. The server might run headless.
**Task:**
1. Wrap `ForegroundAppProvider` imports in try/except blocks.
2. Ensure the app doesn't crash if `pywinctl` fails to load.
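A guarded-import sketch for the provider module (the helper function is illustrative):

```python
from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

try:
    import pywinctl
except Exception:  # ImportError, or display errors raised at import time on headless hosts
    pywinctl = None
    logger.warning("pywinctl unavailable; foreground-app trigger signal disabled")

def get_active_window_title() -> str | None:
    """Best-effort active window title; None when detection is unavailable."""
    if pywinctl is None:
        return None
    window = pywinctl.getActiveWindow()
    return window.title if window else None
```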
---

## 7. Packaging & Deployment (Future)

Since you mentioned packaging is a WIP:
1. **Dependencies:** Separating `server` deps (torch, faster-whisper) from `client` deps (flet, sounddevice) is crucial. Use `pyproject.toml` extras: `pip install noteflow[server]` vs `noteflow[client]`.
2. **Model Management:** The Docker image for the server will be huge due to Torch/Whisper. Consider a build stage that pre-downloads the "base" model so the container starts faster.
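For point 2, a tiny helper that a Docker build stage could `RUN` (the script path is hypothetical; constructing `WhisperModel` triggers the download into the image's cache):

```python
# scripts/preload_model.py (hypothetical) — run during `docker build` so the
# model lands in an image layer instead of downloading on first container start.
from faster_whisper import WhisperModel

def main() -> None:
    # Constructing the model downloads and caches the CTranslate2 weights.
    WhisperModel("base", device="cpu", compute_type="int8")

if __name__ == "__main__":
    main()
```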
## Conclusion

The code is high quality, well-typed, and structurally sound. Fixing the **Blocking ASR** issue is the only mandatory change before any serious load testing or deployment. The rest are robustness and architectural improvements.
@@ -116,6 +116,7 @@ asyncio_default_fixture_loop_scope = "function"
markers = [
    "slow: marks tests as slow (model loading)",
    "integration: marks tests requiring external services",
    "stress: marks stress/concurrency tests",
]
filterwarnings = [
    "ignore:The @wait_container_is_ready decorator is deprecated.*:DeprecationWarning:testcontainers.core.waiting_utils",
@@ -124,4 +125,5 @@ filterwarnings = [
[dependency-groups]
dev = [
    "ruff>=0.14.9",
    "watchfiles>=1.1.1",
]
30841 repomix-output.md Normal file
File diff suppressed because one or more lines are too long
41 repomix.config.json Normal file
@@ -0,0 +1,41 @@
{
  "$schema": "https://repomix.com/schemas/latest/schema.json",
  "input": {
    "maxFileSize": 52428800
  },
  "output": {
    "filePath": "repomix-output.md",
    "style": "markdown",
    "parsableStyle": false,
    "fileSummary": true,
    "directoryStructure": true,
    "files": true,
    "removeComments": false,
    "removeEmptyLines": false,
    "compress": false,
    "topFilesLength": 5,
    "showLineNumbers": false,
    "truncateBase64": false,
    "copyToClipboard": false,
    "tokenCountTree": false,
    "git": {
      "sortByChanges": true,
      "sortByChangesMaxCommits": 100,
      "includeDiffs": false,
      "includeLogs": false,
      "includeLogsCount": 50
    }
  },
  "include": ["src/", "tests/"],
  "ignore": {
    "useGitignore": true,
    "useDefaultPatterns": true,
    "customPatterns": []
  },
  "security": {
    "enableSecurityCheck": true
  },
  "tokenCount": {
    "encoding": "o200k_base"
  }
}
34 scripts/dev_watch_server.py Normal file
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
"""Run the gRPC server with auto-reload.

Watches only the core server code (and alembic.ini) to avoid noisy directories.
"""

from __future__ import annotations

import subprocess
import sys
from pathlib import Path

from watchfiles import PythonFilter, run_process


def run_server() -> None:
    """Start the gRPC server process."""
    subprocess.run([sys.executable, "-m", "noteflow.grpc.server"], check=False)


def main() -> None:
    root = Path(__file__).resolve().parents[1]
    watch_paths = [root / "src" / "noteflow", root / "alembic.ini"]
    existing_paths = [str(path) for path in watch_paths if path.exists()] or [str(root / "src" / "noteflow")]

    run_process(
        *existing_paths,
        target=run_server,
        watch_filter=PythonFilter(),
    )


if __name__ == "__main__":
    main()
@@ -54,12 +54,12 @@ class SoundDeviceCapture:
            AudioDeviceInfo(
                device_id=idx,
                name=dev["name"],
                channels=dev["max_input_channels"],
                channels=int(dev["max_input_channels"]),
                sample_rate=int(dev["default_samplerate"]),
                is_default=(idx == default_input),
            )
            for idx, dev in enumerate(device_list)
            if dev["max_input_channels"] > 0
            if int(dev.get("max_input_channels", 0)) > 0
        )
        return devices
@@ -144,19 +144,20 @@ class TriggerService:
        Returns:
            TriggerAction to take.
        """
        # Check rate limit
        # Check threshold_ignore first
        if confidence < self._settings.threshold_ignore:
            return TriggerAction.IGNORE

        # AUTO_START bypasses rate limit (high-confidence trigger should not be delayed)
        if confidence >= self._settings.threshold_auto_start and self._settings.auto_start_enabled:
            return TriggerAction.AUTO_START

        # Rate limit applies only to NOTIFY actions
        if self._last_prompt is not None:
            elapsed = now - self._last_prompt
            if elapsed < self._settings.rate_limit_seconds:
                return TriggerAction.IGNORE

        # Apply thresholds
        if confidence < self._settings.threshold_ignore:
            return TriggerAction.IGNORE

        if confidence >= self._settings.threshold_auto_start and self._settings.auto_start_enabled:
            return TriggerAction.AUTO_START

        return TriggerAction.NOTIFY

    def _make_decision(
@@ -1,7 +1,7 @@
|
||||
"""Trigger detection mixin for NoteFlow client.
|
||||
|
||||
Extracts trigger detection logic from app.py to keep file under 750 lines.
|
||||
Handles meeting detection triggers via audio activity and foreground app monitoring.
|
||||
Handles meeting detection triggers via app audio activity and calendar proximity.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,11 +16,12 @@ from noteflow.application.services import TriggerService, TriggerServiceSettings
|
||||
from noteflow.config.settings import TriggerSettings, get_trigger_settings
|
||||
from noteflow.domain.triggers import TriggerAction, TriggerDecision
|
||||
from noteflow.infrastructure.triggers import (
|
||||
AudioActivityProvider,
|
||||
AudioActivitySettings,
|
||||
ForegroundAppProvider,
|
||||
ForegroundAppSettings,
|
||||
AppAudioProvider,
|
||||
AppAudioSettings,
|
||||
CalendarProvider,
|
||||
CalendarSettings,
|
||||
)
|
||||
from noteflow.infrastructure.triggers.calendar import parse_calendar_events
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.client.state import AppState
|
||||
@@ -34,8 +35,8 @@ class TriggerHost(Protocol):
|
||||
_state: AppState
|
||||
_trigger_settings: TriggerSettings | None
|
||||
_trigger_service: TriggerService | None
|
||||
_audio_activity: AudioActivityProvider | None
|
||||
_foreground_app: ForegroundAppProvider | None
|
||||
_app_audio: AppAudioProvider | None
|
||||
_calendar_provider: CalendarProvider | None
|
||||
_trigger_poll_interval: float
|
||||
_trigger_task: asyncio.Task | None
|
||||
|
||||
@@ -59,8 +60,10 @@ class TriggerMixin:
|
||||
self._trigger_settings = get_trigger_settings()
|
||||
self._state.trigger_enabled = self._trigger_settings.trigger_enabled
|
||||
self._trigger_poll_interval = self._trigger_settings.trigger_poll_interval_seconds
|
||||
meeting_apps = {app.lower() for app in self._trigger_settings.trigger_meeting_apps}
|
||||
suppressed_apps = {app.lower() for app in self._trigger_settings.trigger_suppressed_apps}
|
||||
|
||||
audio_settings = AudioActivitySettings(
|
||||
app_audio_settings = AppAudioSettings(
|
||||
enabled=self._trigger_settings.trigger_audio_enabled,
|
||||
threshold_db=self._trigger_settings.trigger_audio_threshold_db,
|
||||
window_seconds=self._trigger_settings.trigger_audio_window_seconds,
|
||||
@@ -68,23 +71,21 @@ class TriggerMixin:
|
||||
min_samples=self._trigger_settings.trigger_audio_min_samples,
|
||||
max_history=self._trigger_settings.trigger_audio_max_history,
|
||||
weight=self._trigger_settings.trigger_weight_audio,
|
||||
)
|
||||
meeting_apps = {app.lower() for app in self._trigger_settings.trigger_meeting_apps}
|
||||
suppressed_apps = {app.lower() for app in self._trigger_settings.trigger_suppressed_apps}
|
||||
foreground_settings = ForegroundAppSettings(
|
||||
enabled=self._trigger_settings.trigger_foreground_enabled,
|
||||
weight=self._trigger_settings.trigger_weight_foreground,
|
||||
meeting_apps=meeting_apps,
|
||||
suppressed_apps=suppressed_apps,
|
||||
)
|
||||
|
||||
self._audio_activity = AudioActivityProvider(
|
||||
self._state.level_provider,
|
||||
audio_settings,
|
||||
calendar_settings = CalendarSettings(
|
||||
enabled=self._trigger_settings.trigger_calendar_enabled,
|
||||
weight=self._trigger_settings.trigger_weight_calendar,
|
||||
lookahead_minutes=self._trigger_settings.trigger_calendar_lookahead_minutes,
|
||||
lookbehind_minutes=self._trigger_settings.trigger_calendar_lookbehind_minutes,
|
||||
events=parse_calendar_events(self._trigger_settings.trigger_calendar_events),
|
||||
)
|
||||
self._foreground_app = ForegroundAppProvider(foreground_settings)
|
||||
|
||||
self._app_audio = AppAudioProvider(app_audio_settings)
|
||||
self._calendar_provider = CalendarProvider(calendar_settings)
|
||||
self._trigger_service = TriggerService(
|
||||
providers=[self._audio_activity, self._foreground_app],
|
||||
providers=[self._app_audio, self._calendar_provider],
|
||||
settings=TriggerServiceSettings(
|
||||
enabled=self._trigger_settings.trigger_enabled,
|
||||
auto_start_enabled=self._trigger_settings.trigger_auto_start,
|
||||
@@ -97,11 +98,7 @@ class TriggerMixin:
|
||||
|
||||
def _should_keep_capture_running(self: TriggerHost) -> bool:
|
||||
"""Return True if background audio capture should remain active."""
|
||||
if not self._trigger_settings:
|
||||
return False
|
||||
return (
|
||||
self._trigger_settings.trigger_enabled and self._trigger_settings.trigger_audio_enabled
|
||||
)
|
||||
return False
|
||||
|
||||
async def _trigger_check_loop(self: TriggerHost) -> None:
|
||||
"""Background loop to check trigger conditions.
|
||||
@@ -121,10 +118,6 @@ class TriggerMixin:
             if not self._state.trigger_enabled or not self._trigger_service:
                 continue

-            # Start background audio capture only when needed for triggers
-            if self._should_keep_capture_running():
-                self._ensure_audio_capture()
-
             # Evaluate triggers
             decision = self._trigger_service.evaluate()
             self._state.trigger_decision = decision
@@ -9,6 +9,8 @@ from __future__ import annotations

 import argparse
 import asyncio
 import logging
+import queue
+import threading
 import time
 from typing import TYPE_CHECKING, Final
@@ -29,6 +31,7 @@ from noteflow.client.components import (
     VuMeterComponent,
 )
 from noteflow.client.state import AppState
+from noteflow.config.constants import DEFAULT_SAMPLE_RATE
 from noteflow.config.settings import TriggerSettings, get_settings
 from noteflow.infrastructure.audio import (
     MeetingAudioReader,
@@ -38,7 +41,6 @@ from noteflow.infrastructure.audio import (
 )
 from noteflow.infrastructure.security import AesGcmCryptoBox, KeyringKeyStore
 from noteflow.infrastructure.summarization import create_summarization_service
-from noteflow.infrastructure.triggers import AudioActivityProvider, ForegroundAppProvider

 if TYPE_CHECKING:
     import numpy as np
@@ -52,6 +54,7 @@ if TYPE_CHECKING:
         ServerInfo,
         TranscriptSegment,
     )
+    from noteflow.infrastructure.triggers import AppAudioProvider, CalendarProvider

 logger = logging.getLogger(__name__)
@@ -105,8 +108,8 @@ class NoteFlowClientApp(TriggerMixin):
         # Trigger detection (M5)
         self._trigger_settings: TriggerSettings | None = None
         self._trigger_service: TriggerService | None = None
-        self._audio_activity: AudioActivityProvider | None = None
-        self._foreground_app: ForegroundAppProvider | None = None
+        self._app_audio: AppAudioProvider | None = None
+        self._calendar_provider: CalendarProvider | None = None
         self._trigger_poll_interval: float = 0.0
         self._trigger_task: asyncio.Task | None = None
@@ -114,6 +117,11 @@ class NoteFlowClientApp(TriggerMixin):
         self._record_btn: ft.ElevatedButton | None = None
         self._stop_btn: ft.ElevatedButton | None = None

+        # Audio frame consumer thread (process frames from audio callback thread)
+        self._audio_frame_queue: queue.Queue[tuple[NDArray[np.float32], float]] = queue.Queue()
+        self._audio_consumer_stop = threading.Event()
+        self._audio_consumer_thread: threading.Thread | None = None
+
     def run(self) -> None:
         """Run the Flet application."""
         ft.app(target=self._main)
@@ -256,8 +264,8 @@ class NoteFlowClientApp(TriggerMixin):
             keystore = KeyringKeyStore()
             crypto = AesGcmCryptoBox(keystore)
             self._audio_reader = MeetingAudioReader(crypto, settings.meetings_dir)
-        except Exception:
-            logger.exception("Failed to initialize meeting audio reader")
+        except (OSError, ValueError, KeyError, RuntimeError) as exc:
+            logger.exception("Failed to initialize meeting audio reader: %s", exc)
             self._audio_reader = None

         return self._audio_reader
@@ -276,8 +284,8 @@ class NoteFlowClientApp(TriggerMixin):
         except FileNotFoundError:
             logger.info("Audio file missing for meeting %s", meeting.id)
             return []
-        except Exception:
-            logger.exception("Failed to load audio for meeting %s", meeting.id)
+        except (OSError, ValueError, RuntimeError) as exc:
+            logger.exception("Failed to load audio for meeting %s: %s", meeting.id, exc)
             return []

     def _ensure_audio_capture(self) -> bool:
@@ -294,12 +302,12 @@ class NoteFlowClientApp(TriggerMixin):
             self._audio_capture.start(
                 device_id=None,
                 on_frames=self._on_audio_frames,
-                sample_rate=16000,
+                sample_rate=DEFAULT_SAMPLE_RATE,
                 channels=1,
                 chunk_duration_ms=100,
             )
-        except Exception:
-            logger.exception("Failed to start audio capture")
+        except (RuntimeError, OSError) as exc:
+            logger.exception("Failed to start audio capture: %s", exc)
             self._audio_capture = None
             return False
@@ -404,6 +412,9 @@ class NoteFlowClientApp(TriggerMixin):

         self._state.recording = True

+        # Start audio frame consumer thread
+        self._start_audio_consumer()
+
         # Clear audio buffer for new recording
         self._state.session_audio_buffer.clear()
@@ -424,7 +435,10 @@ class NoteFlowClientApp(TriggerMixin):

     def _stop_recording(self) -> None:
         """Stop recording audio."""
-        # Stop audio capture first
+        # Stop audio frame consumer thread
+        self._stop_audio_consumer()
+
+        # Stop audio capture
         if self._audio_capture and not self._should_keep_capture_running():
             self._audio_capture.stop()
             self._audio_capture = None
@@ -467,7 +481,57 @@ class NoteFlowClientApp(TriggerMixin):
         frames: NDArray[np.float32],
         timestamp: float,
     ) -> None:
-        """Handle audio frames from capture.
+        """Handle audio frames from capture (called from audio thread).
+
+        Enqueues frames for processing by consumer thread to avoid blocking
+        the real-time audio callback.

         Args:
             frames: Audio samples.
             timestamp: Capture timestamp.
         """
+        self._audio_frame_queue.put_nowait((frames.copy(), timestamp))
+
+    def _start_audio_consumer(self) -> None:
+        """Start the audio frame consumer thread."""
+        if self._audio_consumer_thread is not None and self._audio_consumer_thread.is_alive():
+            return
+        self._audio_consumer_stop.clear()
+        self._audio_consumer_thread = threading.Thread(
+            target=self._audio_consumer_loop,
+            daemon=True,
+            name="audio-consumer",
+        )
+        self._audio_consumer_thread.start()
+
+    def _stop_audio_consumer(self) -> None:
+        """Stop the audio frame consumer thread."""
+        self._audio_consumer_stop.set()
+        if self._audio_consumer_thread is not None:
+            self._audio_consumer_thread.join(timeout=1.0)
+            self._audio_consumer_thread = None
+        # Drain remaining frames
+        while not self._audio_frame_queue.empty():
+            try:
+                self._audio_frame_queue.get_nowait()
+            except queue.Empty:
+                break
+
+    def _audio_consumer_loop(self) -> None:
+        """Consumer loop that processes audio frames from the queue."""
+        while not self._audio_consumer_stop.is_set():
+            try:
+                frames, timestamp = self._audio_frame_queue.get(timeout=0.1)
+                self._process_audio_frames(frames, timestamp)
+            except queue.Empty:
+                continue
+
+    def _process_audio_frames(
+        self,
+        frames: NDArray[np.float32],
+        timestamp: float,
+    ) -> None:
+        """Process audio frames from consumer thread.
+
+        Args:
+            frames: Audio samples.
@@ -477,20 +541,18 @@ class NoteFlowClientApp(TriggerMixin):
         if self._client and self._state.recording:
             self._client.send_audio(frames, timestamp)

-        # Buffer for playback (estimate duration from chunk size)
+        # Buffer for playback
         if self._state.recording:
-            duration = len(frames) / 16000.0  # Sample rate is 16kHz
+            duration = len(frames) / DEFAULT_SAMPLE_RATE
             self._state.session_audio_buffer.append(
-                TimestampedAudio(frames=frames.copy(), timestamp=timestamp, duration=duration)
+                TimestampedAudio(frames=frames, timestamp=timestamp, duration=duration)
             )

         # Update VU meter
         if self._vu_meter:
             self._vu_meter.on_audio_frames(frames)

-        # Feed audio activity provider for trigger detection
-        if self._audio_activity:
-            self._audio_activity.update(frames, timestamp)
+        # Trigger detection uses system output + calendar; no mic-derived updates here.

     def _on_segment_click(self, segment_index: int) -> None:
         """Handle transcript segment click - seek playback to segment.
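Note where the copy now happens: `_on_audio_frames` calls `frames.copy()` once at the enqueue boundary, so `_process_audio_frames` can store the array directly (the old `frames.copy()` inside the append is gone). The single copy matters because real-time audio callbacks typically reuse one buffer; a minimal illustration of the hazard the copy avoids:

import numpy as np

buf = np.zeros(4, dtype=np.float32)  # stand-in for a driver-owned buffer, reused per chunk
saved = []

def on_frames(frames: np.ndarray) -> None:
    saved.append(frames.copy())  # copy once at the boundary

on_frames(buf)
buf[:] = 1.0   # "driver" overwrites the buffer with the next chunk
on_frames(buf)

print(saved[0])  # [0. 0. 0. 0.] - preserved; without .copy() it would read all ones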
@@ -549,8 +611,8 @@ class NoteFlowClientApp(TriggerMixin):
             segments = client.get_meeting_segments(meeting.id)
             annotations = client.list_annotations(meeting.id)
             audio_chunks = self._load_meeting_audio(meeting)
-        except Exception:
-            logger.exception("Failed to load meeting: %s", meeting.id)
+        except (ConnectionError, ValueError, OSError, RuntimeError) as exc:
+            logger.exception("Failed to load meeting %s: %s", meeting.id, exc)
             return

         # Apply results on UI thread to avoid race conditions
@@ -667,10 +729,16 @@ class NoteFlowClientApp(TriggerMixin):
             self._trigger_task.cancel()
             self._trigger_task = None

+        # Stop audio consumer if running
+        self._stop_audio_consumer()
+
+        if self._app_audio:
+            self._app_audio.close()
+
         if self._audio_capture and not self._state.recording:
             try:
                 self._audio_capture.stop()
-            except Exception:
+            except RuntimeError:
                 logger.debug("Error stopping audio capture during shutdown", exc_info=True)
             self._audio_capture = None
@@ -7,6 +7,8 @@ Does not recreate any types - imports and uses existing ones.
 from __future__ import annotations

 import logging
+import threading
+import time
 from collections.abc import Callable
 from datetime import datetime
 from typing import TYPE_CHECKING
@@ -29,6 +31,8 @@ class MeetingLibraryComponent:
     Uses NoteFlowClient.list_meetings() and export_transcript() for data.
     """

+    DIARIZATION_POLL_INTERVAL_SECONDS: float = 2.0
+
     def __init__(
         self,
         state: AppState,
@@ -50,6 +54,8 @@ class MeetingLibraryComponent:
         self._search_field: ft.TextField | None = None
         self._list_view: ft.ListView | None = None
         self._export_btn: ft.ElevatedButton | None = None
+        self._analyze_btn: ft.ElevatedButton | None = None
+        self._rename_btn: ft.ElevatedButton | None = None
         self._refresh_btn: ft.IconButton | None = None
         self._column: ft.Column | None = None
@@ -57,6 +63,14 @@ class MeetingLibraryComponent:
         self._export_dialog: ft.AlertDialog | None = None
         self._format_dropdown: ft.Dropdown | None = None

+        # Analyze speakers dialog
+        self._analyze_dialog: ft.AlertDialog | None = None
+        self._num_speakers_field: ft.TextField | None = None
+
+        # Rename speakers dialog
+        self._rename_dialog: ft.AlertDialog | None = None
+        self._rename_fields: dict[str, ft.TextField] = {}
+
     def build(self) -> ft.Column:
         """Build meeting library UI.
@@ -80,6 +94,18 @@ class MeetingLibraryComponent:
             on_click=self._show_export_dialog,
             disabled=True,
         )
+        self._analyze_btn = ft.ElevatedButton(
+            "Refine Speakers",
+            icon=ft.Icons.RECORD_VOICE_OVER,
+            on_click=self._show_analyze_dialog,
+            disabled=True,
+        )
+        self._rename_btn = ft.ElevatedButton(
+            "Rename Speakers",
+            icon=ft.Icons.EDIT,
+            on_click=self._show_rename_dialog,
+            disabled=True,
+        )

         self._list_view = ft.ListView(
             spacing=5,
@@ -95,7 +121,11 @@ class MeetingLibraryComponent:
                     border=ft.border.all(1, ft.Colors.GREY_400),
                     border_radius=8,
                 ),
-                ft.Row([self._export_btn], alignment=ft.MainAxisAlignment.END),
+                ft.Row(
+                    [self._analyze_btn, self._rename_btn, self._export_btn],
+                    alignment=ft.MainAxisAlignment.END,
+                    spacing=10,
+                ),
             ],
             spacing=10,
         )
@@ -192,9 +222,13 @@ class MeetingLibraryComponent:
         """
         self._state.selected_meeting = meeting

-        # Enable export button
+        # Enable action buttons
         if self._export_btn:
             self._export_btn.disabled = False
+        if self._analyze_btn:
+            self._analyze_btn.disabled = not self._can_refine_speakers(meeting)
+        if self._rename_btn:
+            self._rename_btn.disabled = not self._can_refine_speakers(meeting)

         # Re-render to update selection
         self._render_meetings()
@@ -304,3 +338,437 @@ class MeetingLibraryComponent:
            file_name=filename,
            allowed_extensions=[extension],
        )

    # =========================================================================
    # Speaker Refinement Methods
    # =========================================================================

    def _show_analyze_dialog(self, e: ft.ControlEvent) -> None:
        """Show speaker refinement dialog."""
        if not self._state.selected_meeting:
            return

        if not self._can_refine_speakers(self._state.selected_meeting):
            self._show_simple_dialog(
                "Meeting still active",
                ft.Text("Stop the meeting before refining speakers."),
            )
            return

        self._num_speakers_field = ft.TextField(
            label="Number of speakers (optional)",
            hint_text="Leave empty for auto-detect",
            width=200,
            keyboard_type=ft.KeyboardType.NUMBER,
        )

        self._analyze_dialog = ft.AlertDialog(
            title=ft.Text("Refine Speakers"),
            content=ft.Column(
                [
                    ft.Text(f"Meeting: {self._state.selected_meeting.title}"),
                    ft.Text(
                        "Refine speaker labels using offline diarization.",
                        size=12,
                        color=ft.Colors.GREY_600,
                    ),
                    self._num_speakers_field,
                ],
                spacing=10,
                tight=True,
            ),
            actions=[
                ft.TextButton("Cancel", on_click=self._close_analyze_dialog),
                ft.ElevatedButton("Analyze", on_click=self._do_analyze),
            ],
            actions_alignment=ft.MainAxisAlignment.END,
        )

        if self._state._page:
            self._state._page.dialog = self._analyze_dialog
            self._analyze_dialog.open = True
            self._state.request_update()

    def _close_analyze_dialog(self, e: ft.ControlEvent | None = None) -> None:
        """Close the analyze dialog."""
        if self._analyze_dialog:
            self._analyze_dialog.open = False
            self._state.request_update()

    def _do_analyze(self, e: ft.ControlEvent) -> None:
        """Perform speaker analysis."""
        if not self._state.selected_meeting:
            return

        # Parse number of speakers (optional)
        num_speakers: int | None = None
        if self._num_speakers_field and self._num_speakers_field.value:
            try:
                num_speakers = int(self._num_speakers_field.value)
                if num_speakers < 1:
                    num_speakers = None
            except ValueError:
                logger.debug("Invalid speaker count input '%s', using auto-detection", self._num_speakers_field.value)

        meeting_id = self._state.selected_meeting.id
        self._close_analyze_dialog()

        client = self._get_client()
        if not client:
            logger.warning("No gRPC client available for analysis")
            return

        # Show progress indicator
        self._show_analysis_progress("Starting...")

        try:
            result = client.refine_speaker_diarization(meeting_id, num_speakers)
        except Exception as exc:
            logger.error("Error analyzing speakers: %s", exc)
            self._show_analysis_error(str(exc))
            return

        if not result:
            self._show_analysis_error("Analysis failed - no response from server")
            return

        if result.is_terminal:
            if result.success:
                self._show_analysis_result(result.segments_updated, result.speaker_ids)
            else:
                self._show_analysis_error(result.error_message or "Analysis failed")
            return

        if not result.job_id:
            self._show_analysis_error(result.error_message or "Server did not return job ID")
            return

        # Job queued/running - poll for completion
        self._show_analysis_progress(self._format_job_status(result.status))
        self._start_diarization_poll(result.job_id)

    def _show_analysis_progress(self, status: str = "Refining...") -> None:
        """Show refinement in progress indicator."""
        if self._analyze_btn:
            self._analyze_btn.disabled = True
            self._analyze_btn.text = status
            self._state.request_update()

    def _show_analysis_result(self, segments_updated: int, speaker_ids: list[str]) -> None:
        """Show refinement success result.

        Args:
            segments_updated: Number of segments with speaker labels.
            speaker_ids: List of detected speaker IDs.
        """
        if self._analyze_btn:
            self._analyze_btn.disabled = False
            self._analyze_btn.text = "Refine Speakers"

        speaker_list = ", ".join(speaker_ids) if speaker_ids else "None found"

        result_dialog = ft.AlertDialog(
            title=ft.Text("Refinement Complete"),
            content=ft.Column(
                [
                    ft.Text(f"Segments updated: {segments_updated}"),
                    ft.Text(f"Speakers found: {speaker_list}"),
                    ft.Text(
                        "Reload the meeting to see speaker labels.",
                        size=12,
                        color=ft.Colors.GREY_600,
                        italic=True,
                    ),
                ],
                spacing=5,
                tight=True,
            ),
            actions=[ft.TextButton("OK", on_click=lambda e: self._close_result_dialog(e))],
        )

        if self._state._page:
            self._state._page.dialog = result_dialog
            result_dialog.open = True
            self._state.request_update()

    def _show_analysis_error(self, error_message: str) -> None:
        """Show analysis error.

        Args:
            error_message: Error description.
        """
        if self._analyze_btn:
            self._analyze_btn.disabled = False
            self._analyze_btn.text = "Refine Speakers"
        self._show_simple_dialog("Refinement Failed", ft.Text(error_message))

    def _close_result_dialog(self, e: ft.ControlEvent) -> None:
        """Close any result dialog."""
        if self._state._page and self._state._page.dialog:
            self._state._page.dialog.open = False
            self._state.request_update()

    def _start_diarization_poll(self, job_id: str) -> None:
        """Start polling for diarization job completion."""
        page = self._state._page
        if page and hasattr(page, "run_thread"):
            page.run_thread(lambda: self._poll_diarization_job(job_id))
            return

        threading.Thread(
            target=self._poll_diarization_job,
            args=(job_id,),
            daemon=True,
            name="diarization-poll",
        ).start()

    def _poll_diarization_job(self, job_id: str) -> None:
        """Poll background diarization job until completion."""
        client = self._get_client()
        if not client:
            self._state.run_on_ui_thread(
                lambda: self._show_analysis_error("No gRPC client available for polling")
            )
            return

        while True:
            result = client.get_diarization_job_status(job_id)
            if not result:
                self._state.run_on_ui_thread(
                    lambda: self._show_analysis_error("Failed to fetch diarization status")
                )
                return

            if result.is_terminal:
                if result.success:
                    self._state.run_on_ui_thread(
                        lambda r=result: self._show_analysis_result(
                            r.segments_updated,
                            r.speaker_ids,
                        )
                    )
                else:
                    self._state.run_on_ui_thread(
                        lambda r=result: self._show_analysis_error(
                            r.error_message or "Diarization failed"
                        )
                    )
                return

            # Update status text while running
            self._state.run_on_ui_thread(
                lambda r=result: self._show_analysis_progress(self._format_job_status(r.status))
            )
            time.sleep(self.DIARIZATION_POLL_INTERVAL_SECONDS)

    @staticmethod
    def _format_job_status(status: str) -> str:
        """Format job status for button label."""
        return {
            "queued": "Queued...",
            "running": "Refining...",
        }.get(status, "Refining...")

    def _show_simple_dialog(self, title: str, content: ft.Control) -> None:
        """Show a simple dialog with title, content, and OK button.

        Args:
            title: Dialog title.
            content: Dialog content control.
        """
        dialog = ft.AlertDialog(
            title=ft.Text(title),
            content=content,
            actions=[ft.TextButton("OK", on_click=self._close_result_dialog)],
        )
        if self._state._page:
            self._state._page.dialog = dialog
            dialog.open = True
            self._state.request_update()

    # =========================================================================
    # Speaker Rename Methods
    # =========================================================================

    def _show_rename_dialog(self, e: ft.ControlEvent) -> None:
        """Show speaker rename dialog with current speaker IDs."""
        if not self._state.selected_meeting:
            return

        if not self._can_refine_speakers(self._state.selected_meeting):
            self._show_simple_dialog(
                "Meeting still active",
                ft.Text("Stop the meeting before renaming speakers."),
            )
            return

        client = self._get_client()
        if not client:
            logger.warning("No gRPC client available")
            return

        # Get segments to extract distinct speaker IDs
        meeting_id = self._state.selected_meeting.id
        segments = client.get_meeting_segments(meeting_id)

        # Extract distinct speaker IDs
        speaker_ids = sorted({s.speaker_id for s in segments if s.speaker_id})

        if not speaker_ids:
            self._show_no_speakers_message()
            return

        # Create text fields for each speaker
        self._rename_fields.clear()
        speaker_controls: list[ft.Control] = []

        for speaker_id in speaker_ids:
            field = ft.TextField(
                label=f"{speaker_id}",
                hint_text="Enter new name",
                width=200,
            )
            self._rename_fields[speaker_id] = field
            speaker_controls.append(
                ft.Row(
                    [
                        ft.Text(speaker_id, width=120, size=12),
                        ft.Icon(ft.Icons.ARROW_RIGHT, size=16),
                        field,
                    ],
                    alignment=ft.MainAxisAlignment.START,
                )
            )

        self._rename_dialog = ft.AlertDialog(
            title=ft.Text("Rename Speakers"),
            content=ft.Column(
                [
                    ft.Text(f"Meeting: {self._state.selected_meeting.title}"),
                    ft.Text(
                        "Enter new names for speakers (leave blank to keep current):",
                        size=12,
                        color=ft.Colors.GREY_600,
                    ),
                    ft.Divider(),
                    *speaker_controls,
                ],
                spacing=10,
                scroll=ft.ScrollMode.AUTO,
                height=300,
            ),
            actions=[
                ft.TextButton("Cancel", on_click=self._close_rename_dialog),
                ft.ElevatedButton("Apply", on_click=self._do_rename),
            ],
            actions_alignment=ft.MainAxisAlignment.END,
        )

        if self._state._page:
            self._state._page.dialog = self._rename_dialog
            self._rename_dialog.open = True
            self._state.request_update()

    def _close_rename_dialog(self, e: ft.ControlEvent | None = None) -> None:
        """Close the rename dialog."""
        if self._rename_dialog:
            self._rename_dialog.open = False
            self._state.request_update()

    def _show_no_speakers_message(self) -> None:
        """Show message when no speakers found."""
        self._show_simple_dialog(
            "No Speakers Found",
            ft.Text(
                "This meeting has no speaker labels. "
                "Run 'Refine Speakers' first to identify speakers."
            ),
        )

    def _do_rename(self, e: ft.ControlEvent) -> None:
        """Apply speaker renames."""
        if not self._state.selected_meeting:
            return

        client = self._get_client()
        if not client:
            logger.warning("No gRPC client available")
            return

        meeting_id = self._state.selected_meeting.id
        self._close_rename_dialog()

        # Collect renames (only non-empty values)
        renames: list[tuple[str, str]] = []
        for old_id, field in self._rename_fields.items():
            new_name = (field.value or "").strip()
            if new_name and new_name != old_id:
                renames.append((old_id, new_name))

        if not renames:
            return

        # Apply renames
        total_updated = 0
        errors: list[str] = []

        for old_id, new_name in renames:
            try:
                result = client.rename_speaker(meeting_id, old_id, new_name)
                if result and result.success:
                    total_updated += result.segments_updated
                else:
                    errors.append(f"{old_id}: rename failed")
            except Exception as exc:
                logger.error("Error renaming speaker %s: %s", old_id, exc)
                errors.append(f"{old_id}: {exc}")

        # Show result
        if errors:
            self._show_rename_errors(errors)
        else:
            self._show_rename_success(total_updated, len(renames))

    def _show_rename_success(self, segments_updated: int, speakers_renamed: int) -> None:
        """Show rename success message.

        Args:
            segments_updated: Total number of segments updated.
            speakers_renamed: Number of speakers renamed.
        """
        success_dialog = ft.AlertDialog(
            title=ft.Text("Rename Complete"),
            content=ft.Column(
                [
                    ft.Text(f"Renamed {speakers_renamed} speaker(s)"),
                    ft.Text(f"Updated {segments_updated} segment(s)"),
                    ft.Text(
                        "Reload the meeting to see the new speaker names.",
                        size=12,
                        color=ft.Colors.GREY_600,
                        italic=True,
                    ),
                ],
                spacing=5,
                tight=True,
            ),
            actions=[ft.TextButton("OK", on_click=lambda e: self._close_result_dialog(e))],
        )

        if self._state._page:
            self._state._page.dialog = success_dialog
            success_dialog.open = True
            self._state.request_update()

    def _show_rename_errors(self, errors: list[str]) -> None:
        """Show rename errors.

        Args:
            errors: List of error messages.
        """
        self._show_simple_dialog("Rename Errors", ft.Text("\n".join(errors)))

    @staticmethod
    def _can_refine_speakers(meeting: MeetingInfo) -> bool:
        """Return True when meeting is stopped/completed and safe to refine/rename."""
        return meeting.state in {"stopped", "completed", "error"}
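The dialog methods above drive a refine-then-poll-then-rename flow; condensed into a headless sketch (a connected `client` and a stopped meeting's `meeting_id` are assumed as wiring; the method names are the ones introduced in this commit):

import time

result = client.refine_speaker_diarization(meeting_id, num_speakers=None)
job_id = result.job_id if result else None

while result and not result.is_terminal and job_id:
    time.sleep(2.0)  # mirrors DIARIZATION_POLL_INTERVAL_SECONDS
    result = client.get_diarization_job_status(job_id)

if result and result.success:
    print(f"Updated {result.segments_updated} segment(s): {result.speaker_ids}")
    client.rename_speaker(meeting_id, "SPEAKER_00", "Alice")  # example rename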
@@ -1,19 +1,17 @@
 """Playback controls component with play/pause/stop and timeline.

 Uses SoundDevicePlayback from infrastructure.audio and format_timestamp from _formatting.
 Does not recreate any types - imports and uses existing ones.
+Receives position updates via callback from SoundDevicePlayback.
 """

 from __future__ import annotations

 import logging
 from collections.abc import Callable
-from typing import TYPE_CHECKING, Final
+from typing import TYPE_CHECKING

 import flet as ft

-from noteflow.client.components._thread_mixin import BackgroundWorkerMixin
-
 # REUSE existing types - do not recreate
 from noteflow.infrastructure.audio import PlaybackState
 from noteflow.infrastructure.export._formatting import format_timestamp
@@ -23,13 +21,12 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

-POSITION_POLL_INTERVAL: Final[float] = 0.1  # 100ms for smooth timeline updates
-
-
-class PlaybackControlsComponent(BackgroundWorkerMixin):
+class PlaybackControlsComponent:
     """Audio playback controls with play/pause/stop and timeline.

     Uses SoundDevicePlayback from state and format_timestamp from _formatting.
+    Receives position updates via callback from SoundDevicePlayback.
     """

     def __init__(
@@ -45,7 +42,7 @@ class PlaybackControlsComponent:
         """
         self._state = state
         self._on_position_change = on_position_change
-        self._init_worker()
+        self._active = False

         # UI elements
         self._play_btn: ft.IconButton | None = None
@@ -159,20 +156,20 @@ class PlaybackControlsComponent:

         if playback.state == PlaybackState.PLAYING:
             playback.pause()
-            self._stop_polling()
+            self._stop_position_updates()
             self._update_play_button(playing=False)
         elif playback.state == PlaybackState.PAUSED:
             playback.resume()
-            self._start_polling()
+            self._start_position_updates()
             self._update_play_button(playing=True)
         elif buffer := self._state.session_audio_buffer:
             playback.play(buffer)
-            self._start_polling()
+            self._start_position_updates()
             self._update_play_button(playing=True)

     def _on_stop_click(self, e: ft.ControlEvent) -> None:
         """Handle stop button click."""
-        self._stop_polling()
+        self._stop_position_updates()
         self._state.playback.stop()
         self._state.playback_position = 0.0
         self._update_play_button(playing=False)
@@ -195,37 +192,47 @@ class PlaybackControlsComponent:
             self._play_btn.tooltip = "Play"
         self._state.request_update()

-    def _start_polling(self) -> None:
-        """Start position polling thread."""
-        self._start_worker(self._poll_loop, "PlaybackPositionPoll")
+    def _start_position_updates(self) -> None:
+        """Start receiving position updates via callback."""
+        if self._active:
+            return
+        self._active = True
+        self._state.playback.add_position_callback(self._on_position_update)

-    def _stop_polling(self) -> None:
-        """Stop position polling thread."""
-        self._stop_worker()
+    def _stop_position_updates(self) -> None:
+        """Stop receiving position updates."""
+        if not self._active:
+            return
+        self._active = False
+        self._state.playback.remove_position_callback(self._on_position_update)

-    def _poll_loop(self) -> None:
-        """Background polling loop for position updates."""
-        while self._should_run():
-            playback = self._state.playback
-
-            if playback.state == PlaybackState.PLAYING:
-                position = playback.current_position
-                self._state.playback_position = position
-                self._state.run_on_ui_thread(self._update_position_display)
-
-                # Notify callback
-                if self._on_position_change:
-                    try:
-                        self._on_position_change(position)
-                    except Exception as e:
-                        logger.error("Position change callback error: %s", e)
-
-            elif playback.state == PlaybackState.STOPPED:
-                # Playback finished - update UI and stop polling
-                self._state.run_on_ui_thread(self._on_playback_finished)
-                break
-
-            self._wait_interval(POSITION_POLL_INTERVAL)
+    def _on_position_update(self, position: float) -> None:
+        """Handle position update from playback callback.
+
+        Called from audio thread - schedules UI work on UI thread.
+        """
+        if not self._active:
+            return
+
+        playback = self._state.playback
+
+        # Check if playback stopped
+        if playback.state == PlaybackState.STOPPED:
+            self._active = False
+            self._state.playback.remove_position_callback(self._on_position_update)
+            self._state.run_on_ui_thread(self._on_playback_finished)
+            return
+
+        # Update position state
+        self._state.playback_position = position
+        self._state.run_on_ui_thread(self._update_position_display)
+
+        # Notify external callback
+        if self._on_position_change:
+            try:
+                self._on_position_change(position)
+            except Exception as e:
+                logger.error("Position change callback error: %s", e)

     def _update_position_display(self) -> None:
         """Update position display elements (UI thread only)."""
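The `add_position_callback` / `remove_position_callback` pair is SoundDevicePlayback's side of this contract and is not shown in this diff; a minimal sketch of what such a registry implies (an assumption, not the actual implementation):

from collections.abc import Callable

PositionCallback = Callable[[float], None]


class PositionNotifier:
    """Sketch of a callback registry a playback engine could expose."""

    def __init__(self) -> None:
        self._callbacks: list[PositionCallback] = []

    def add_position_callback(self, cb: PositionCallback) -> None:
        if cb not in self._callbacks:
            self._callbacks.append(cb)

    def remove_position_callback(self, cb: PositionCallback) -> None:
        if cb in self._callbacks:
            self._callbacks.remove(cb)

    def _emit(self, position: float) -> None:
        # Invoked from the audio thread per block; subscribers must hand
        # UI work to the UI thread themselves (as the component above does).
        for cb in list(self._callbacks):
            cb(position)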
@@ -1,15 +1,14 @@
 """Playback-transcript synchronization controller.

-Polls playback position and updates transcript highlight state.
-Follows RecordingTimerComponent pattern for background threading.
+Uses playback position callbacks to update transcript highlight state.
+No polling thread - receives position updates directly from SoundDevicePlayback.
 """

 from __future__ import annotations

 import logging
-import threading
 from collections.abc import Callable
-from typing import TYPE_CHECKING, Final
+from typing import TYPE_CHECKING

 from noteflow.infrastructure.audio import PlaybackState
@@ -18,14 +17,12 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

-POSITION_POLL_INTERVAL: Final[float] = 0.1  # 100ms for smooth highlighting
-
-
 class PlaybackSyncController:
     """Synchronize playback position with transcript highlighting.

-    Polls playback position and updates state.highlighted_segment_index.
-    Triggers UI updates via state.run_on_ui_thread().
+    Receives position updates via callback from SoundDevicePlayback.
+    Updates state.highlighted_segment_index and triggers UI updates.
     """

     def __init__(
@@ -41,46 +38,46 @@ class PlaybackSyncController:
         """
         self._state = state
         self._on_highlight_change = on_highlight_change
-        self._sync_thread: threading.Thread | None = None
-        self._stop_event = threading.Event()
+        self._active = False

     def start(self) -> None:
-        """Start position sync polling."""
-        if self._sync_thread and self._sync_thread.is_alive():
+        """Start position sync by registering callback with playback."""
+        if self._active:
             return

-        self._stop_event.clear()
-        self._sync_thread = threading.Thread(
-            target=self._sync_loop,
-            daemon=True,
-            name="PlaybackSyncController",
-        )
-        self._sync_thread.start()
+        self._active = True
+        self._state.playback.add_position_callback(self._on_position_update)
         logger.debug("Started playback sync controller")

     def stop(self) -> None:
-        """Stop position sync polling."""
-        self._stop_event.set()
-        if self._sync_thread:
-            self._sync_thread.join(timeout=2.0)
-            self._sync_thread = None
+        """Stop position sync by unregistering callback."""
+        if not self._active:
+            return
+
+        self._active = False
+        self._state.playback.remove_position_callback(self._on_position_update)

         # Clear highlight when stopped
         if self._state.highlighted_segment_index is not None:
             self._state.highlighted_segment_index = None
             self._state.run_on_ui_thread(self._notify_highlight_change)

         logger.debug("Stopped playback sync controller")

-    def _sync_loop(self) -> None:
-        """Background sync loop - polls position and updates highlight."""
-        while not self._stop_event.is_set():
-            playback = self._state.playback
-
-            if playback.state == PlaybackState.PLAYING:
-                position = playback.current_position
-                self._update_position(position)
-            elif playback.state == PlaybackState.STOPPED:
-                # Clear highlight when stopped
-                if self._state.highlighted_segment_index is not None:
-                    self._state.highlighted_segment_index = None
-                    self._state.run_on_ui_thread(self._notify_highlight_change)
-
-            self._stop_event.wait(POSITION_POLL_INTERVAL)
+    def _on_position_update(self, position: float) -> None:
+        """Handle position update from playback callback.
+
+        Called from audio thread - schedules UI work on UI thread.
+        """
+        if not self._active:
+            return
+
+        # Check if playback stopped
+        if self._state.playback.state == PlaybackState.STOPPED:
+            self.stop()
+            return
+
+        self._update_position(position)

     def _update_position(self, position: float) -> None:
         """Update state with current position and find matching segment."""
@@ -6,8 +6,8 @@ Does not recreate any types - imports and uses existing ones.

 from __future__ import annotations

-from collections.abc import Callable
+import hashlib
+from collections.abc import Callable
 from typing import TYPE_CHECKING

 import flet as ft
@@ -1,5 +1,14 @@
 """NoteFlow configuration module."""

+from .constants import DEFAULT_GRPC_PORT, DEFAULT_SAMPLE_RATE, MAX_GRPC_MESSAGE_SIZE
 from .settings import Settings, TriggerSettings, get_settings, get_trigger_settings

-__all__ = ["Settings", "TriggerSettings", "get_settings", "get_trigger_settings"]
+__all__ = [
+    "DEFAULT_GRPC_PORT",
+    "DEFAULT_SAMPLE_RATE",
+    "MAX_GRPC_MESSAGE_SIZE",
+    "Settings",
+    "TriggerSettings",
+    "get_settings",
+    "get_trigger_settings",
+]
23
src/noteflow/config/constants.py
Normal file
@@ -0,0 +1,23 @@
"""Centralized constants for NoteFlow.

This module provides shared constants used across the codebase to avoid
magic numbers and ensure consistency.
"""

from __future__ import annotations

from typing import Final

# Audio constants
DEFAULT_SAMPLE_RATE: Final[int] = 16000
"""Default audio sample rate in Hz (16 kHz)."""

POSITION_UPDATE_INTERVAL: Final[float] = 0.1
"""Playback position update interval in seconds (100ms)."""

# gRPC constants
DEFAULT_GRPC_PORT: Final[int] = 50051
"""Default gRPC server port."""

MAX_GRPC_MESSAGE_SIZE: Final[int] = 100 * 1024 * 1024
"""Maximum gRPC message size in bytes (100 MB)."""
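A quick sanity check of the arithmetic these constants standardize: the client captures 100 ms chunks at 16 kHz, so each chunk holds 1600 samples and the duration recovered in _process_audio_frames is exact.

from noteflow.config.constants import DEFAULT_SAMPLE_RATE

samples = 1600                            # one 100 ms capture chunk
duration = samples / DEFAULT_SAMPLE_RATE  # 1600 / 16000
print(duration)                           # 0.1 seconds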
@@ -57,10 +57,10 @@ class TriggerSettings(BaseSettings):
         Field(default=0.80, ge=0.0, le=1.0, description="Confidence to auto-start recording"),
     ]

-    # Audio trigger tuning
+    # App audio trigger tuning (system output from whitelisted apps)
     trigger_audio_enabled: Annotated[
         bool,
-        Field(default=True, description="Enable audio activity detection"),
+        Field(default=True, description="Enable app audio activity detection"),
     ]
     trigger_audio_threshold_db: Annotated[
         float,
@@ -83,6 +83,27 @@ class TriggerSettings(BaseSettings):
         Field(default=50, ge=10, le=1000, description="Max audio activity samples to retain"),
     ]

+    # Calendar trigger tuning (optional integration)
+    trigger_calendar_enabled: Annotated[
+        bool,
+        Field(default=False, description="Enable calendar-based trigger detection"),
+    ]
+    trigger_calendar_lookahead_minutes: Annotated[
+        int,
+        Field(default=5, ge=0, le=60, description="Minutes before event start to trigger"),
+    ]
+    trigger_calendar_lookbehind_minutes: Annotated[
+        int,
+        Field(default=5, ge=0, le=60, description="Minutes after event start to keep triggering"),
+    ]
+    trigger_calendar_events: Annotated[
+        list[dict[str, object]],
+        Field(
+            default_factory=list,
+            description="Calendar events as JSON list of {start, end, title}",
+        ),
+    ]
+
     # Foreground app trigger tuning
     trigger_foreground_enabled: Annotated[
         bool,
@@ -148,6 +169,28 @@ class TriggerSettings(BaseSettings):
             return [str(item).strip() for item in parsed if str(item).strip()]
         return [item.strip() for item in value.split(",") if item.strip()]

+    @field_validator("trigger_calendar_events", mode="before")
+    @classmethod
+    def _parse_calendar_events(cls, value: object) -> list[dict[str, object]]:
+        if value is None:
+            return []
+        if isinstance(value, str):
+            stripped = value.strip()
+            if not stripped:
+                return []
+            try:
+                parsed = json.loads(stripped)
+            except json.JSONDecodeError:
+                return []
+            if isinstance(parsed, list):
+                return [item for item in parsed if isinstance(item, dict)]
+            return [parsed] if isinstance(parsed, dict) else []
+        if isinstance(value, dict):
+            return [value]
+        if isinstance(value, list):
+            return [item for item in value if isinstance(item, dict)]
+        return []
+

 class Settings(TriggerSettings):
     """Application settings loaded from environment variables.
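The validator is deliberately forgiving: malformed JSON degrades to an empty list rather than failing settings load. A small illustration of the accepted shapes, constructing TriggerSettings directly (the exact environment-variable naming follows pydantic-settings conventions not shown in this excerpt):

from noteflow.config.settings import TriggerSettings

events = '[{"start": "2025-01-06T10:00", "end": "2025-01-06T11:00", "title": "Standup"}]'
s = TriggerSettings(trigger_calendar_events=events)
print(len(s.trigger_calendar_events))  # 1 - parsed from the JSON string

print(TriggerSettings(trigger_calendar_events="not json").trigger_calendar_events)        # []
print(TriggerSettings(trigger_calendar_events={"title": "1:1"}).trigger_calendar_events)  # [{'title': '1:1'}]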
@@ -72,12 +72,5 @@ class TriggerDecision:

     @property
     def detected_app(self) -> str | None:
-        """Get the detected app name from foreground signal if present."""
-        return next(
-            (
-                signal.app_name
-                for signal in self.signals
-                if signal.source == TriggerSource.FOREGROUND_APP and signal.app_name
-            ),
-            None,
-        )
+        """Get the detected app name from any signal if present."""
+        return next((signal.app_name for signal in self.signals if signal.app_name), None)
@@ -4,9 +4,11 @@ from noteflow.domain.value_objects import MeetingState

 from .client import (
     AnnotationInfo,
+    DiarizationResult,
     ExportResult,
     MeetingInfo,
     NoteFlowClient,
+    RenameSpeakerResult,
     ServerInfo,
     TranscriptSegment,
 )
@@ -15,12 +17,14 @@ from .service import NoteFlowServicer

 __all__ = [
     "AnnotationInfo",
+    "DiarizationResult",
     "ExportResult",
     "MeetingInfo",
     "MeetingState",
     "MeetingStore",
     "NoteFlowClient",
     "NoteFlowServicer",
+    "RenameSpeakerResult",
     "ServerInfo",
     "TranscriptSegment",
 ]
17
src/noteflow/grpc/_mixins/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""gRPC service mixins for NoteFlowServicer."""

from .annotation import AnnotationMixin
from .diarization import DiarizationMixin
from .export import ExportMixin
from .meeting import MeetingMixin
from .streaming import StreamingMixin
from .summarization import SummarizationMixin

__all__ = [
    "AnnotationMixin",
    "DiarizationMixin",
    "ExportMixin",
    "MeetingMixin",
    "StreamingMixin",
    "SummarizationMixin",
]
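service.py itself is outside this excerpt, so the composition is an assumption, but these mixins are presumably stacked onto the generated servicer base roughly like this (the base module and class names depend on the proto service definition and are hypothetical here):

from noteflow.grpc._mixins import (
    AnnotationMixin,
    DiarizationMixin,
    ExportMixin,
    MeetingMixin,
    StreamingMixin,
    SummarizationMixin,
)
from noteflow.grpc.proto import noteflow_pb2_grpc  # module name assumed


class NoteFlowServicer(
    AnnotationMixin,
    DiarizationMixin,
    ExportMixin,
    MeetingMixin,
    StreamingMixin,
    SummarizationMixin,
    noteflow_pb2_grpc.NoteFlowServiceServicer,  # hypothetical generated base
):
    """Each mixin contributes one RPC surface; shared helpers such as
    _create_uow() and _use_database() live on the host (see ServicerHost)."""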
161
src/noteflow/grpc/_mixins/annotation.py
Normal file
@@ -0,0 +1,161 @@
"""Annotation management mixin for gRPC service."""

from __future__ import annotations

from typing import TYPE_CHECKING
from uuid import UUID, uuid4

import grpc.aio

from noteflow.domain.entities import Annotation
from noteflow.domain.value_objects import AnnotationId, MeetingId

from ..proto import noteflow_pb2
from .converters import annotation_to_proto, proto_to_annotation_type

if TYPE_CHECKING:
    from .protocols import ServicerHost


class AnnotationMixin:
    """Mixin providing annotation CRUD functionality.

    Requires host to implement ServicerHost protocol.
    Annotations require database persistence.
    """

    async def AddAnnotation(
        self: ServicerHost,
        request: noteflow_pb2.AddAnnotationRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Annotation:
        """Add an annotation to a meeting."""
        if not self._use_database():
            await context.abort(
                grpc.StatusCode.UNIMPLEMENTED,
                "Annotations require database persistence",
            )

        annotation_type = proto_to_annotation_type(request.annotation_type)

        annotation = Annotation(
            id=AnnotationId(uuid4()),
            meeting_id=MeetingId(UUID(request.meeting_id)),
            annotation_type=annotation_type,
            text=request.text,
            start_time=request.start_time,
            end_time=request.end_time,
            segment_ids=list(request.segment_ids),
        )

        async with self._create_uow() as uow:
            saved = await uow.annotations.add(annotation)
            await uow.commit()
            return annotation_to_proto(saved)

    async def GetAnnotation(
        self: ServicerHost,
        request: noteflow_pb2.GetAnnotationRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Annotation:
        """Get an annotation by ID."""
        if not self._use_database():
            await context.abort(
                grpc.StatusCode.UNIMPLEMENTED,
                "Annotations require database persistence",
            )

        async with self._create_uow() as uow:
            annotation = await uow.annotations.get(AnnotationId(UUID(request.annotation_id)))
            if annotation is None:
                await context.abort(
                    grpc.StatusCode.NOT_FOUND,
                    f"Annotation {request.annotation_id} not found",
                )
            return annotation_to_proto(annotation)

    async def ListAnnotations(
        self: ServicerHost,
        request: noteflow_pb2.ListAnnotationsRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.ListAnnotationsResponse:
        """List annotations for a meeting."""
        if not self._use_database():
            await context.abort(
                grpc.StatusCode.UNIMPLEMENTED,
                "Annotations require database persistence",
            )

        async with self._create_uow() as uow:
            meeting_id = MeetingId(UUID(request.meeting_id))
            # Check if time range filter is specified
            if request.start_time > 0 or request.end_time > 0:
                annotations = await uow.annotations.get_by_time_range(
                    meeting_id,
                    request.start_time,
                    request.end_time,
                )
            else:
                annotations = await uow.annotations.get_by_meeting(meeting_id)

            return noteflow_pb2.ListAnnotationsResponse(
                annotations=[annotation_to_proto(a) for a in annotations]
            )

    async def UpdateAnnotation(
        self: ServicerHost,
        request: noteflow_pb2.UpdateAnnotationRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Annotation:
        """Update an existing annotation."""
        if not self._use_database():
            await context.abort(
                grpc.StatusCode.UNIMPLEMENTED,
                "Annotations require database persistence",
            )

        async with self._create_uow() as uow:
            annotation = await uow.annotations.get(AnnotationId(UUID(request.annotation_id)))
            if annotation is None:
                await context.abort(
                    grpc.StatusCode.NOT_FOUND,
                    f"Annotation {request.annotation_id} not found",
                )

            # Update fields if provided
            if request.annotation_type != noteflow_pb2.ANNOTATION_TYPE_UNSPECIFIED:
                annotation.annotation_type = proto_to_annotation_type(request.annotation_type)
            if request.text:
                annotation.text = request.text
            if request.start_time > 0:
                annotation.start_time = request.start_time
            if request.end_time > 0:
                annotation.end_time = request.end_time
            if request.segment_ids:
                annotation.segment_ids = list(request.segment_ids)

            updated = await uow.annotations.update(annotation)
            await uow.commit()
            return annotation_to_proto(updated)

    async def DeleteAnnotation(
        self: ServicerHost,
        request: noteflow_pb2.DeleteAnnotationRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.DeleteAnnotationResponse:
        """Delete an annotation."""
        if not self._use_database():
            await context.abort(
                grpc.StatusCode.UNIMPLEMENTED,
                "Annotations require database persistence",
            )

        async with self._create_uow() as uow:
            success = await uow.annotations.delete(AnnotationId(UUID(request.annotation_id)))
            if success:
                await uow.commit()
                return noteflow_pb2.DeleteAnnotationResponse(success=True)
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                f"Annotation {request.annotation_id} not found",
            )
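For orientation, this is roughly how a client would exercise AddAnnotation over grpc.aio. The stub module and class names are assumptions (only noteflow_pb2 appears in this diff), and the server must be running with database persistence:

import asyncio

import grpc.aio

from noteflow.grpc.proto import noteflow_pb2, noteflow_pb2_grpc  # stub module assumed


async def add_note(meeting_id: str) -> None:
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = noteflow_pb2_grpc.NoteFlowServiceStub(channel)  # class name assumed
        annotation = await stub.AddAnnotation(
            noteflow_pb2.AddAnnotationRequest(
                meeting_id=meeting_id,
                annotation_type=noteflow_pb2.ANNOTATION_TYPE_NOTE,
                text="Follow up on budget",
                start_time=12.5,
                end_time=15.0,
            )
        )
        print(annotation.id)

# asyncio.run(add_note("<meeting uuid>"))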
227
src/noteflow/grpc/_mixins/converters.py
Normal file
@@ -0,0 +1,227 @@
"""Standalone proto ↔ domain conversion functions for gRPC service."""

from __future__ import annotations

import time
from typing import TYPE_CHECKING

from noteflow.application.services.export_service import ExportFormat
from noteflow.domain.entities import Annotation, Meeting, Segment, Summary
from noteflow.domain.value_objects import AnnotationType, MeetingId
from noteflow.infrastructure.converters import AsrConverter

from ..proto import noteflow_pb2

if TYPE_CHECKING:
    from noteflow.infrastructure.asr.dto import AsrResult


def meeting_to_proto(
    meeting: Meeting,
    include_segments: bool = True,
    include_summary: bool = True,
) -> noteflow_pb2.Meeting:
    """Convert domain Meeting to protobuf."""
    segments = []
    if include_segments:
        for seg in meeting.segments:
            words = [
                noteflow_pb2.WordTiming(
                    word=w.word,
                    start_time=w.start_time,
                    end_time=w.end_time,
                    probability=w.probability,
                )
                for w in seg.words
            ]
            segments.append(
                noteflow_pb2.FinalSegment(
                    segment_id=seg.segment_id,
                    text=seg.text,
                    start_time=seg.start_time,
                    end_time=seg.end_time,
                    words=words,
                    language=seg.language,
                    language_confidence=seg.language_confidence,
                    avg_logprob=seg.avg_logprob,
                    no_speech_prob=seg.no_speech_prob,
                    speaker_id=seg.speaker_id or "",
                    speaker_confidence=seg.speaker_confidence,
                )
            )

    summary = None
    if include_summary and meeting.summary:
        summary = summary_to_proto(meeting.summary)

    return noteflow_pb2.Meeting(
        id=str(meeting.id),
        title=meeting.title,
        state=meeting.state.value,
        created_at=meeting.created_at.timestamp(),
        started_at=meeting.started_at.timestamp() if meeting.started_at else 0,
        ended_at=meeting.ended_at.timestamp() if meeting.ended_at else 0,
        duration_seconds=meeting.duration_seconds,
        segments=segments,
        summary=summary,
        metadata=meeting.metadata,
    )


def summary_to_proto(summary: Summary) -> noteflow_pb2.Summary:
    """Convert domain Summary to protobuf."""
    key_points = [
        noteflow_pb2.KeyPoint(
            text=kp.text,
            segment_ids=kp.segment_ids,
            start_time=kp.start_time,
            end_time=kp.end_time,
        )
        for kp in summary.key_points
    ]
    action_items = [
        noteflow_pb2.ActionItem(
            text=ai.text,
            assignee=ai.assignee,
            due_date=ai.due_date.timestamp() if ai.due_date is not None else 0,
            priority=ai.priority,
            segment_ids=ai.segment_ids,
        )
        for ai in summary.action_items
    ]
    return noteflow_pb2.Summary(
        meeting_id=str(summary.meeting_id),
        executive_summary=summary.executive_summary,
        key_points=key_points,
        action_items=action_items,
        generated_at=(summary.generated_at.timestamp() if summary.generated_at is not None else 0),
        model_version=summary.model_version,
    )


def segment_to_proto_update(
    meeting_id: str,
    segment: Segment,
) -> noteflow_pb2.TranscriptUpdate:
    """Convert domain Segment to protobuf TranscriptUpdate."""
    words = [
        noteflow_pb2.WordTiming(
            word=w.word,
            start_time=w.start_time,
            end_time=w.end_time,
            probability=w.probability,
        )
        for w in segment.words
    ]
    final_segment = noteflow_pb2.FinalSegment(
        segment_id=segment.segment_id,
        text=segment.text,
        start_time=segment.start_time,
        end_time=segment.end_time,
        words=words,
        language=segment.language,
        language_confidence=segment.language_confidence,
        avg_logprob=segment.avg_logprob,
        no_speech_prob=segment.no_speech_prob,
        speaker_id=segment.speaker_id or "",
        speaker_confidence=segment.speaker_confidence,
    )
    return noteflow_pb2.TranscriptUpdate(
        meeting_id=meeting_id,
        update_type=noteflow_pb2.UPDATE_TYPE_FINAL,
        segment=final_segment,
        server_timestamp=time.time(),
    )


def annotation_to_proto(annotation: Annotation) -> noteflow_pb2.Annotation:
    """Convert domain Annotation to protobuf."""
    return noteflow_pb2.Annotation(
        id=str(annotation.id),
        meeting_id=str(annotation.meeting_id),
        annotation_type=annotation_type_to_proto(annotation.annotation_type),
        text=annotation.text,
        start_time=annotation.start_time,
        end_time=annotation.end_time,
        segment_ids=annotation.segment_ids,
        created_at=annotation.created_at.timestamp(),
    )


def annotation_type_to_proto(annotation_type: AnnotationType) -> int:
    """Convert domain AnnotationType to protobuf enum."""
    mapping = {
        AnnotationType.ACTION_ITEM: noteflow_pb2.ANNOTATION_TYPE_ACTION_ITEM,
        AnnotationType.DECISION: noteflow_pb2.ANNOTATION_TYPE_DECISION,
        AnnotationType.NOTE: noteflow_pb2.ANNOTATION_TYPE_NOTE,
        AnnotationType.RISK: noteflow_pb2.ANNOTATION_TYPE_RISK,
    }
    return mapping.get(annotation_type, noteflow_pb2.ANNOTATION_TYPE_UNSPECIFIED)


def proto_to_annotation_type(proto_type: int) -> AnnotationType:
    """Convert protobuf enum to domain AnnotationType."""
    mapping: dict[int, AnnotationType] = {
        int(noteflow_pb2.ANNOTATION_TYPE_ACTION_ITEM): AnnotationType.ACTION_ITEM,
        int(noteflow_pb2.ANNOTATION_TYPE_DECISION): AnnotationType.DECISION,
        int(noteflow_pb2.ANNOTATION_TYPE_NOTE): AnnotationType.NOTE,
        int(noteflow_pb2.ANNOTATION_TYPE_RISK): AnnotationType.RISK,
    }
    return mapping.get(proto_type, AnnotationType.NOTE)


def create_vad_update(
    meeting_id: str,
    update_type: int,
) -> noteflow_pb2.TranscriptUpdate:
    """Create a VAD event update.

    Args:
        meeting_id: Meeting identifier.
        update_type: VAD_START or VAD_END.

    Returns:
        TranscriptUpdate with VAD event.
    """
    return noteflow_pb2.TranscriptUpdate(
        meeting_id=meeting_id,
        update_type=update_type,
        server_timestamp=time.time(),
    )


def create_segment_from_asr(
    meeting_id: MeetingId,
    segment_id: int,
    result: AsrResult,
    segment_start_time: float,
) -> Segment:
    """Create a Segment from ASR result.

    Use converters to transform ASR DTO to domain entities.
    """
    words = AsrConverter.result_to_domain_words(result)
    if segment_start_time:
        for word in words:
            word.start_time += segment_start_time
            word.end_time += segment_start_time

    return Segment(
        segment_id=segment_id,
        text=result.text,
        start_time=result.start + segment_start_time,
        end_time=result.end + segment_start_time,
        meeting_id=meeting_id,
        words=words,
        language=result.language,
        language_confidence=result.language_probability,
        avg_logprob=result.avg_logprob,
        no_speech_prob=result.no_speech_prob,
    )


def proto_to_export_format(proto_format: int) -> ExportFormat:
    """Convert protobuf ExportFormat to domain ExportFormat."""
    if proto_format == noteflow_pb2.EXPORT_FORMAT_HTML:
        return ExportFormat.HTML
    return ExportFormat.MARKDOWN  # Default to Markdown
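The enum converters are intentionally asymmetric at the edges: unknown domain values map to UNSPECIFIED, while unknown proto values fall back to NOTE. A round-trip check over the four mapped members:

from noteflow.domain.value_objects import AnnotationType
from noteflow.grpc._mixins.converters import (
    annotation_type_to_proto,
    proto_to_annotation_type,
)

for ann_type in (
    AnnotationType.ACTION_ITEM,
    AnnotationType.DECISION,
    AnnotationType.NOTE,
    AnnotationType.RISK,
):
    assert proto_to_annotation_type(annotation_type_to_proto(ann_type)) == ann_type

assert proto_to_annotation_type(9999) == AnnotationType.NOTE  # fallback, not an error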
486
src/noteflow/grpc/_mixins/diarization.py
Normal file
@@ -0,0 +1,486 @@
"""Speaker diarization mixin for gRPC service."""

from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from uuid import UUID, uuid4

import grpc.aio
import numpy as np
from numpy.typing import NDArray

from noteflow.domain.entities import Segment
from noteflow.domain.value_objects import MeetingId, MeetingState
from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.diarization import SpeakerTurn, assign_speaker

from ..proto import noteflow_pb2

if TYPE_CHECKING:
    from .protocols import ServicerHost

logger = logging.getLogger(__name__)


@dataclass
class _DiarizationJob:
    """Track background diarization job state."""

    job_id: str
    meeting_id: str
    status: int
    segments_updated: int = 0
    speaker_ids: list[str] = field(default_factory=list)
    error_message: str = ""
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    task: asyncio.Task[None] | None = None

    def to_proto(self) -> noteflow_pb2.DiarizationJobStatus:
        """Convert to protobuf message."""
        return noteflow_pb2.DiarizationJobStatus(
            job_id=self.job_id,
            status=self.status,
            segments_updated=self.segments_updated,
            speaker_ids=self.speaker_ids,
            error_message=self.error_message,
        )


class DiarizationMixin:
    """Mixin providing speaker diarization functionality.

    Requires host to implement ServicerHost protocol.
    """

    # Job retention constant
    DIARIZATION_JOB_TTL_SECONDS: float = 60 * 60  # 1 hour

    def _process_streaming_diarization(
        self: ServicerHost,
        meeting_id: str,
        audio: NDArray[np.float32],
    ) -> None:
        """Process an audio chunk for streaming diarization (best-effort)."""
        if self._diarization_engine is None:
            return
        if meeting_id in self._diarization_streaming_failed:
            return
        if audio.size == 0:
            return

        if not self._diarization_engine.is_streaming_loaded:
            try:
                self._diarization_engine.load_streaming_model()
            except (RuntimeError, ValueError) as exc:
                logger.warning(
                    "Streaming diarization disabled for meeting %s: %s",
                    meeting_id,
                    exc,
                )
                self._diarization_streaming_failed.add(meeting_id)
                return

        stream_time = self._diarization_stream_time.get(meeting_id, 0.0)
        duration = len(audio) / self.DEFAULT_SAMPLE_RATE

        try:
            turns = self._diarization_engine.process_chunk(
                audio,
                sample_rate=self.DEFAULT_SAMPLE_RATE,
            )
        except Exception as exc:
            logger.warning(
                "Streaming diarization failed for meeting %s: %s",
                meeting_id,
                exc,
            )
            self._diarization_streaming_failed.add(meeting_id)
            return

        diarization_turns = self._diarization_turns.setdefault(meeting_id, [])
        for turn in turns:
            diarization_turns.append(
                SpeakerTurn(
                    speaker=turn.speaker,
                    start=turn.start + stream_time,
                    end=turn.end + stream_time,
                    confidence=turn.confidence,
                )
            )

        self._diarization_stream_time[meeting_id] = stream_time + duration
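
        # Offset bookkeeping, by example (hypothetical numbers): if 12.0 s of
        # audio has already streamed and the engine reports a turn at 0.4-1.9 s
        # within this chunk, it is stored as 12.4-13.9 s in meeting time, and
        # the running offset then advances by the chunk duration.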

    def _maybe_assign_speaker(
        self: ServicerHost,
        meeting_id: str,
        segment: Segment,
    ) -> None:
        """Assign speaker to a segment using streaming diarization turns (best-effort)."""
        if self._diarization_engine is None:
            return
        if meeting_id in self._diarization_streaming_failed:
            return
        turns = self._diarization_turns.get(meeting_id)
        if not turns:
            return

        speaker_id, confidence = assign_speaker(
            segment.start_time,
            segment.end_time,
            turns,
        )
        if speaker_id is None:
            return

        segment.speaker_id = speaker_id
        segment.speaker_confidence = confidence

    def _prune_diarization_jobs(self: ServicerHost) -> None:
        """Remove completed diarization jobs older than retention window."""
        if not self._diarization_jobs:
            return
        now = time.time()
        terminal_statuses = {
            noteflow_pb2.JOB_STATUS_COMPLETED,
            noteflow_pb2.JOB_STATUS_FAILED,
        }
        expired = [
            job_id
            for job_id, job in self._diarization_jobs.items()
            if job.status in terminal_statuses
            and now - job.updated_at > self.DIARIZATION_JOB_TTL_SECONDS
        ]
        for job_id in expired:
            self._diarization_jobs.pop(job_id, None)

    async def RefineSpeakerDiarization(
        self: ServicerHost,
        request: noteflow_pb2.RefineSpeakerDiarizationRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
        """Run post-meeting speaker diarization refinement.

        Load the full meeting audio, run offline diarization, and update
        segment speaker assignments.
        """
        self._prune_diarization_jobs()

        if not self._diarization_refinement_enabled:
            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
            response.segments_updated = 0
            response.speaker_ids[:] = []
            response.error_message = "Diarization refinement disabled on server"
            response.job_id = ""
            response.status = noteflow_pb2.JOB_STATUS_FAILED
            return response

        if self._diarization_engine is None:
            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
            response.segments_updated = 0
            response.speaker_ids[:] = []
            response.error_message = "Diarization not enabled on server"
            response.job_id = ""
            response.status = noteflow_pb2.JOB_STATUS_FAILED
            return response

        try:
            meeting_uuid = UUID(request.meeting_id)
        except ValueError:
            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
            response.segments_updated = 0
            response.speaker_ids[:] = []
            response.error_message = "Invalid meeting_id"
            response.job_id = ""
            response.status = noteflow_pb2.JOB_STATUS_FAILED
            return response

        if self._use_database():
            async with self._create_uow() as uow:
                meeting = await uow.meetings.get(MeetingId(meeting_uuid))
        else:
            store = self._get_memory_store()
            meeting = store.get(request.meeting_id)
        if meeting is None:
            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
            response.segments_updated = 0
            response.speaker_ids[:] = []
            response.error_message = "Meeting not found"
            response.job_id = ""
            response.status = noteflow_pb2.JOB_STATUS_FAILED
            return response
        meeting_state = meeting.state
        if meeting_state in (
            MeetingState.UNSPECIFIED,
            MeetingState.CREATED,
            MeetingState.RECORDING,
            MeetingState.STOPPING,
        ):
            response = noteflow_pb2.RefineSpeakerDiarizationResponse()
            response.segments_updated = 0
            response.speaker_ids[:] = []
            response.error_message = (
                f"Meeting must be stopped before refinement (state: {meeting_state.name.lower()})"
            )
            response.job_id = ""
            response.status = noteflow_pb2.JOB_STATUS_FAILED
            return response

        num_speakers = request.num_speakers if request.num_speakers > 0 else None

        job_id = str(uuid4())
        job = _DiarizationJob(
            job_id=job_id,
            meeting_id=request.meeting_id,
            status=noteflow_pb2.JOB_STATUS_QUEUED,
        )
        self._diarization_jobs[job_id] = job

        # Task runs in background, no need to await
        task = asyncio.create_task(self._run_diarization_job(job_id, num_speakers))
        job.task = task

        response = noteflow_pb2.RefineSpeakerDiarizationResponse()
        response.segments_updated = 0
        response.speaker_ids[:] = []
        response.error_message = ""
        response.job_id = job_id
        response.status = noteflow_pb2.JOB_STATUS_QUEUED
        return response
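
    # Intended flow, as a hedged sketch: this RPC returns immediately with a
    # queued job; callers are expected to poll GetDiarizationJobStatus with the
    # returned job_id until the job reaches JOB_STATUS_COMPLETED or
    # JOB_STATUS_FAILED.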

    async def _run_diarization_job(
        self: ServicerHost,
        job_id: str,
        num_speakers: int | None,
    ) -> None:
        """Run background diarization job."""
        job = self._diarization_jobs.get(job_id)
        if job is None:
            return

        job.status = noteflow_pb2.JOB_STATUS_RUNNING
        job.updated_at = time.time()

        try:
            updated_count = await self.refine_speaker_diarization(
                meeting_id=job.meeting_id,
                num_speakers=num_speakers,
            )
            speaker_ids = await self._collect_speaker_ids(job.meeting_id)
            job.segments_updated = updated_count
            job.speaker_ids = speaker_ids
            job.status = noteflow_pb2.JOB_STATUS_COMPLETED
        except Exception as exc:
            logger.exception("Diarization failed for meeting %s", job.meeting_id)
            job.error_message = str(exc)
            job.status = noteflow_pb2.JOB_STATUS_FAILED
        finally:
            job.updated_at = time.time()

    async def refine_speaker_diarization(
        self: ServicerHost,
        meeting_id: str,
        num_speakers: int | None = None,
    ) -> int:
        """Run post-meeting speaker diarization refinement.

        Load the full meeting audio, run offline diarization, and update
        segment speaker assignments. This provides higher quality speaker
        labels than streaming diarization.

        Args:
            meeting_id: Meeting UUID string.
            num_speakers: Known number of speakers (None for auto-detect).

        Returns:
            Number of segments updated with speaker labels.

        Raises:
            RuntimeError: If diarization engine not available or meeting not found.
        """
        turns = await asyncio.to_thread(
            self._run_diarization_inference,
            meeting_id,
            num_speakers,
        )

        updated_count = await self._apply_diarization_turns(meeting_id, turns)

        logger.info(
            "Updated %d segments with speaker labels for meeting %s",
            updated_count,
            meeting_id,
        )

        return updated_count

    def _run_diarization_inference(
        self: ServicerHost,
        meeting_id: str,
        num_speakers: int | None,
    ) -> list[SpeakerTurn]:
        """Run offline diarization and return speaker turns (blocking)."""
        if self._diarization_engine is None:
            raise RuntimeError("Diarization engine not configured")

        if not self._diarization_engine.is_offline_loaded:
            logger.info("Loading offline diarization model for refinement...")
            self._diarization_engine.load_offline_model()

        audio_reader = MeetingAudioReader(self._crypto, self._meetings_dir)
        if not audio_reader.audio_exists(meeting_id):
            raise RuntimeError("No audio file found for meeting")

        logger.info("Loading audio for meeting %s", meeting_id)
        try:
            audio_chunks = audio_reader.load_meeting_audio(meeting_id)
        except (FileNotFoundError, ValueError) as exc:
            raise RuntimeError(f"Failed to load audio: {exc}") from exc

        if not audio_chunks:
            raise RuntimeError("No audio chunks loaded for meeting")

        sample_rate = audio_reader.sample_rate
        all_audio = np.concatenate([chunk.frames for chunk in audio_chunks])

        logger.info(
            "Running offline diarization on %.2f seconds of audio",
            len(all_audio) / sample_rate,
        )

        turns = self._diarization_engine.diarize_full(
            all_audio,
            sample_rate=sample_rate,
            num_speakers=num_speakers,
        )

        logger.info("Diarization found %d speaker turns", len(turns))
        return list(turns)

    async def _apply_diarization_turns(
        self: ServicerHost,
        meeting_id: str,
        turns: list[SpeakerTurn],
    ) -> int:
        """Apply diarization turns to segments and return updated count."""
        updated_count = 0

        if self._use_database():
            async with self._create_uow() as uow:
                segments = await uow.segments.get_by_meeting(MeetingId(UUID(meeting_id)))
                for segment in segments:
                    if segment.db_id is None:
                        continue
                    speaker_id, confidence = assign_speaker(
                        segment.start_time,
                        segment.end_time,
                        turns,
                    )
                    if speaker_id is None:
                        continue
                    await uow.segments.update_speaker(
                        segment.db_id,
                        speaker_id,
                        confidence,
                    )
                    updated_count += 1
                await uow.commit()
        else:
            store = self._get_memory_store()
            if meeting := store.get(meeting_id):
                for segment in meeting.segments:
                    speaker_id, confidence = assign_speaker(
                        segment.start_time,
                        segment.end_time,
                        turns,
                    )
                    if speaker_id is None:
                        continue
                    segment.speaker_id = speaker_id
                    segment.speaker_confidence = confidence
                    updated_count += 1

        return updated_count
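
    # Overlap assignment, by example (hypothetical numbers): a segment at
    # 10.0-12.0 s compared against turns S1@9.0-11.0 and S2@11.0-14.0 overlaps
    # each for 1.0 s; assign_speaker is expected to break the tie (e.g. by
    # larger overlap or confidence) or return None when no turn overlaps at
    # all, in which case the segment is left unlabeled.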

    async def _collect_speaker_ids(self: ServicerHost, meeting_id: str) -> list[str]:
        """Collect distinct speaker IDs for a meeting."""
        if self._use_database():
            async with self._create_uow() as uow:
                segments = await uow.segments.get_by_meeting(MeetingId(UUID(meeting_id)))
                return sorted({s.speaker_id for s in segments if s.speaker_id})
        store = self._get_memory_store()
        if meeting := store.get(meeting_id):
            return sorted({s.speaker_id for s in meeting.segments if s.speaker_id})
        return []

    async def RenameSpeaker(
        self: ServicerHost,
        request: noteflow_pb2.RenameSpeakerRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.RenameSpeakerResponse:
        """Rename a speaker ID in all segments of a meeting.

        Update all segments where speaker_id matches old_speaker_id
        to use new_speaker_name instead.
        """
        if not request.old_speaker_id or not request.new_speaker_name:
            await context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "old_speaker_id and new_speaker_name are required",
            )

        try:
            meeting_uuid = UUID(request.meeting_id)
        except ValueError:
            await context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "Invalid meeting_id",
            )

        updated_count = 0

        if self._use_database():
            async with self._create_uow() as uow:
                segments = await uow.segments.get_by_meeting(MeetingId(meeting_uuid))

                for segment in segments:
                    if segment.speaker_id == request.old_speaker_id and segment.db_id:
                        await uow.segments.update_speaker(
                            segment.db_id,
                            request.new_speaker_name,
                            segment.speaker_confidence,
                        )
                        updated_count += 1

                await uow.commit()
        else:
            store = self._get_memory_store()
            if meeting := store.get(request.meeting_id):
                for segment in meeting.segments:
                    if segment.speaker_id == request.old_speaker_id:
                        segment.speaker_id = request.new_speaker_name
                        updated_count += 1

        return noteflow_pb2.RenameSpeakerResponse(
            segments_updated=updated_count,
            success=updated_count > 0,
        )

    async def GetDiarizationJobStatus(
        self: ServicerHost,
        request: noteflow_pb2.GetDiarizationJobStatusRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.DiarizationJobStatus:
        """Return current status for a diarization job."""
        self._prune_diarization_jobs()
        job = self._diarization_jobs.get(request.job_id)
        if job is None:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                "Diarization job not found",
            )
        return job.to_proto()
68
src/noteflow/grpc/_mixins/export.py
Normal file
@@ -0,0 +1,68 @@
"""Export mixin for gRPC service."""

from __future__ import annotations

from typing import TYPE_CHECKING
from uuid import UUID

import grpc.aio

from noteflow.application.services.export_service import ExportFormat, ExportService
from noteflow.domain.value_objects import MeetingId

from ..proto import noteflow_pb2
from .converters import proto_to_export_format

if TYPE_CHECKING:
    from .protocols import ServicerHost


class ExportMixin:
    """Mixin providing export functionality.

    Requires host to implement ServicerHost protocol.
    Export requires database persistence.
    """

    async def ExportTranscript(
        self: ServicerHost,
        request: noteflow_pb2.ExportTranscriptRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.ExportTranscriptResponse:
        """Export meeting transcript to specified format."""
        if not self._use_database():
            await context.abort(
                grpc.StatusCode.UNIMPLEMENTED,
                "Export requires database persistence",
            )

        # Map proto format to ExportFormat
        fmt = proto_to_export_format(request.format)

        export_service = ExportService(self._create_uow())
        try:
            content = await export_service.export_transcript(
                MeetingId(UUID(request.meeting_id)),
                fmt,
            )
            exporter_info = export_service.get_supported_formats()
            fmt_name = ""
            fmt_ext = ""
            for name, ext in exporter_info:
                if fmt == ExportFormat.MARKDOWN and ext == ".md":
                    fmt_name, fmt_ext = name, ext
                    break
                if fmt == ExportFormat.HTML and ext == ".html":
                    fmt_name, fmt_ext = name, ext
                    break

            return noteflow_pb2.ExportTranscriptResponse(
                content=content,
                format_name=fmt_name,
                file_extension=fmt_ext,
            )
        except ValueError as e:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                str(e),
            )
190
src/noteflow/grpc/_mixins/meeting.py
Normal file
@@ -0,0 +1,190 @@
"""Meeting management mixin for gRPC service."""

from __future__ import annotations

from typing import TYPE_CHECKING
from uuid import UUID

import grpc.aio

from noteflow.domain.entities import Meeting
from noteflow.domain.value_objects import MeetingId, MeetingState

from ..proto import noteflow_pb2
from .converters import meeting_to_proto

if TYPE_CHECKING:
    from .protocols import ServicerHost


class MeetingMixin:
    """Mixin providing meeting CRUD functionality.

    Requires host to implement ServicerHost protocol.
    """

    async def CreateMeeting(
        self: ServicerHost,
        request: noteflow_pb2.CreateMeetingRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Meeting:
        """Create a new meeting."""
        metadata = dict(request.metadata) if request.metadata else {}

        if self._use_database():
            async with self._create_uow() as uow:
                meeting = Meeting.create(title=request.title, metadata=metadata)
                saved = await uow.meetings.create(meeting)
                await uow.commit()
                return meeting_to_proto(saved)
        else:
            store = self._get_memory_store()
            meeting = store.create(title=request.title, metadata=metadata)
            return meeting_to_proto(meeting)

    async def StopMeeting(
        self: ServicerHost,
        request: noteflow_pb2.StopMeetingRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Meeting:
        """Stop a meeting using graceful STOPPING -> STOPPED transition."""
        meeting_id = request.meeting_id

        # Close audio writer if open
        if meeting_id in self._audio_writers:
            self._close_audio_writer(meeting_id)

        if self._use_database():
            async with self._create_uow() as uow:
                meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
                if meeting is None:
                    await context.abort(
                        grpc.StatusCode.NOT_FOUND,
                        f"Meeting {meeting_id} not found",
                    )
                try:
                    # Graceful shutdown: RECORDING -> STOPPING -> STOPPED
                    meeting.begin_stopping()
                    meeting.stop_recording()
                except ValueError as e:
                    await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
                await uow.meetings.update(meeting)
                await uow.commit()
                return meeting_to_proto(meeting)
        store = self._get_memory_store()
        meeting = store.get(meeting_id)
        if meeting is None:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                f"Meeting {meeting_id} not found",
            )
        try:
            # Graceful shutdown: RECORDING -> STOPPING -> STOPPED
            meeting.begin_stopping()
            meeting.stop_recording()
        except ValueError as e:
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
        store.update(meeting)
        return meeting_to_proto(meeting)

    async def ListMeetings(
        self: ServicerHost,
        request: noteflow_pb2.ListMeetingsRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.ListMeetingsResponse:
        """List meetings."""
        limit = request.limit or 100
        offset = request.offset or 0
        sort_desc = request.sort_order != noteflow_pb2.SORT_ORDER_CREATED_ASC

        if self._use_database():
            states = [MeetingState(s) for s in request.states] if request.states else None
            async with self._create_uow() as uow:
                meetings, total = await uow.meetings.list_all(
                    states=states,
                    limit=limit,
                    offset=offset,
                    sort_desc=sort_desc,
                )
                return noteflow_pb2.ListMeetingsResponse(
                    meetings=[meeting_to_proto(m, include_segments=False) for m in meetings],
                    total_count=total,
                )
        else:
            store = self._get_memory_store()
            states = [MeetingState(s) for s in request.states] if request.states else None
            meetings, total = store.list_all(
                states=states,
                limit=limit,
                offset=offset,
                sort_desc=sort_desc,
            )
            return noteflow_pb2.ListMeetingsResponse(
                meetings=[meeting_to_proto(m, include_segments=False) for m in meetings],
                total_count=total,
            )

    async def GetMeeting(
        self: ServicerHost,
        request: noteflow_pb2.GetMeetingRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Meeting:
        """Get meeting details."""
        if self._use_database():
            async with self._create_uow() as uow:
                meeting = await uow.meetings.get(MeetingId(UUID(request.meeting_id)))
                if meeting is None:
                    await context.abort(
                        grpc.StatusCode.NOT_FOUND,
                        f"Meeting {request.meeting_id} not found",
                    )
                # Load segments if requested
                if request.include_segments:
                    segments = await uow.segments.get_by_meeting(meeting.id)
                    meeting.segments = list(segments)
                # Load summary if requested
                if request.include_summary:
                    summary = await uow.summaries.get_by_meeting(meeting.id)
                    meeting.summary = summary
                return meeting_to_proto(
                    meeting,
                    include_segments=request.include_segments,
                    include_summary=request.include_summary,
                )
        store = self._get_memory_store()
        meeting = store.get(request.meeting_id)
        if meeting is None:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                f"Meeting {request.meeting_id} not found",
            )
        return meeting_to_proto(
            meeting,
            include_segments=request.include_segments,
            include_summary=request.include_summary,
        )

    async def DeleteMeeting(
        self: ServicerHost,
        request: noteflow_pb2.DeleteMeetingRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.DeleteMeetingResponse:
        """Delete a meeting."""
        if self._use_database():
            async with self._create_uow() as uow:
                success = await uow.meetings.delete(MeetingId(UUID(request.meeting_id)))
                if success:
                    await uow.commit()
                    return noteflow_pb2.DeleteMeetingResponse(success=True)
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                f"Meeting {request.meeting_id} not found",
            )
        store = self._get_memory_store()
        success = store.delete(request.meeting_id)
        if not success:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                f"Meeting {request.meeting_id} not found",
            )
        return noteflow_pb2.DeleteMeetingResponse(success=True)
114
src/noteflow/grpc/_mixins/protocols.py
Normal file
@@ -0,0 +1,114 @@
"""Protocol contracts for gRPC service mixins."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Protocol

import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

    from noteflow.domain.entities import Meeting
    from noteflow.infrastructure.asr import FasterWhisperEngine, Segmenter, StreamingVad
    from noteflow.infrastructure.audio.writer import MeetingAudioWriter
    from noteflow.infrastructure.diarization import DiarizationEngine, SpeakerTurn
    from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
    from noteflow.infrastructure.security.crypto import AesGcmCryptoBox

    from ..meeting_store import MeetingStore


class ServicerHost(Protocol):
    """Protocol defining shared state and methods for service mixins.

    All mixins should type-hint `self` as `ServicerHost` to access these
    attributes and methods from the host NoteFlowServicer class.
    """

    # Configuration
    _session_factory: async_sessionmaker[AsyncSession] | None
    _memory_store: MeetingStore | None
    _meetings_dir: Path
    _crypto: AesGcmCryptoBox

    # Engines
    _asr_engine: FasterWhisperEngine | None
    _diarization_engine: DiarizationEngine | None
    _summarization_service: object | None
    _diarization_refinement_enabled: bool

    # Audio writers
    _audio_writers: dict[str, MeetingAudioWriter]
    _audio_write_failed: set[str]

    # VAD and segmentation state per meeting
    _vad_instances: dict[str, StreamingVad]
    _segmenters: dict[str, Segmenter]
    _was_speaking: dict[str, bool]
    _segment_counters: dict[str, int]
    _stream_formats: dict[str, tuple[int, int]]
    _active_streams: set[str]

    # Partial transcription state per meeting
    _partial_buffers: dict[str, list[NDArray[np.float32]]]
    _last_partial_time: dict[str, float]
    _last_partial_text: dict[str, str]

    # Streaming diarization state per meeting
    _diarization_turns: dict[str, list[SpeakerTurn]]
    _diarization_stream_time: dict[str, float]
    _diarization_streaming_failed: set[str]

    # Constants
    DEFAULT_SAMPLE_RATE: int
    SUPPORTED_SAMPLE_RATES: list[int]
    PARTIAL_CADENCE_SECONDS: float
    MIN_PARTIAL_AUDIO_SECONDS: float

    def _use_database(self) -> bool:
        """Check if database persistence is configured."""
        ...

    def _get_memory_store(self) -> MeetingStore:
        """Get the in-memory store, raising if not configured."""
        ...

    def _create_uow(self) -> SqlAlchemyUnitOfWork:
        """Create a new Unit of Work."""
        ...

    def _next_segment_id(self, meeting_id: str, fallback: int = 0) -> int:
        """Get and increment the next segment id for a meeting."""
        ...

    def _init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None:
        """Initialize VAD, Segmenter, speaking state, and partial buffers."""
        ...

    def _cleanup_streaming_state(self, meeting_id: str) -> None:
        """Clean up streaming state for a meeting."""
        ...

    def _ensure_meeting_dek(self, meeting: Meeting) -> tuple[bytes, bytes, bool]:
        """Ensure meeting has a DEK, generating one if needed."""
        ...

    def _start_meeting_if_needed(self, meeting: Meeting) -> tuple[bool, str | None]:
        """Start recording on meeting if not already recording."""
        ...

    def _open_meeting_audio_writer(
        self,
        meeting_id: str,
        dek: bytes,
        wrapped_dek: bytes,
    ) -> None:
        """Open audio writer for a meeting."""
        ...

    def _close_audio_writer(self, meeting_id: str) -> None:
        """Close and remove the audio writer for a meeting."""
        ...
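
# Usage sketch (hedged, following the convention documented in the class
# docstring): a mixin method annotates `self` as ServicerHost so a type
# checker can verify access to the host's shared state, e.g.:
#
#     class ExampleMixin:
#         def _example(self: ServicerHost, meeting_id: str) -> bool:
#             return meeting_id in self._active_streams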
576
src/noteflow/grpc/_mixins/streaming.py
Normal file
@@ -0,0 +1,576 @@
"""Streaming audio transcription mixin for gRPC service."""

from __future__ import annotations

import logging
import struct
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID

import grpc.aio
import numpy as np
from numpy.typing import NDArray

from noteflow.domain.value_objects import MeetingId

from ..proto import noteflow_pb2
from .converters import create_segment_from_asr, create_vad_update, segment_to_proto_update

if TYPE_CHECKING:
    from .protocols import ServicerHost

logger = logging.getLogger(__name__)


@dataclass
class _StreamSessionInit:
    """Result of stream session initialization."""

    next_segment_id: int
    error_code: int | None = None
    error_message: str | None = None

    @property
    def success(self) -> bool:
        """Check if initialization succeeded."""
        return self.error_code is None


class StreamingMixin:
    """Mixin providing streaming transcription functionality.

    Requires host to implement ServicerHost protocol.
    """

    async def StreamTranscription(
        self: ServicerHost,
        request_iterator: AsyncIterator[noteflow_pb2.AudioChunk],
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
        """Handle bidirectional audio streaming with persistence.

        Receive audio chunks from client, process through ASR,
        persist segments, and yield transcript updates.
        """
        if self._asr_engine is None or not self._asr_engine.is_loaded:
            await context.abort(
                grpc.StatusCode.FAILED_PRECONDITION,
                "ASR engine not loaded",
            )

        current_meeting_id: str | None = None

        try:
            async for chunk in request_iterator:
                meeting_id = chunk.meeting_id
                if not meeting_id:
                    await context.abort(
                        grpc.StatusCode.INVALID_ARGUMENT,
                        "meeting_id required",
                    )

                # Initialize stream on first chunk
                if current_meeting_id is None:
                    init_result = await self._init_stream_for_meeting(meeting_id, context)
                    if init_result is None:
                        return  # Error already sent via context.abort
                    current_meeting_id = meeting_id
                elif meeting_id != current_meeting_id:
                    await context.abort(
                        grpc.StatusCode.INVALID_ARGUMENT,
                        "Stream may only contain a single meeting_id",
                    )

                # Process audio chunk
                async for update in self._process_stream_chunk(current_meeting_id, chunk, context):
                    yield update

            # Flush any remaining audio from segmenter
            if current_meeting_id and current_meeting_id in self._segmenters:
                async for update in self._flush_segmenter(current_meeting_id):
                    yield update
        finally:
            if current_meeting_id:
                self._cleanup_streaming_state(current_meeting_id)
                self._close_audio_writer(current_meeting_id)
                self._active_streams.discard(current_meeting_id)

    async def _init_stream_for_meeting(
        self: ServicerHost,
        meeting_id: str,
        context: grpc.aio.ServicerContext,
    ) -> _StreamSessionInit | None:
        """Initialize streaming for a meeting.

        Args:
            meeting_id: Meeting ID string.
            context: gRPC context for error handling.

        Returns:
            Initialization result, or None if error was sent.
        """
        if meeting_id in self._active_streams:
            await context.abort(
                grpc.StatusCode.FAILED_PRECONDITION,
                f"Meeting {meeting_id} already streaming",
            )

        self._active_streams.add(meeting_id)

        if self._use_database():
            init_result = await self._init_stream_session_db(meeting_id)
        else:
            init_result = self._init_stream_session_memory(meeting_id)

        if not init_result.success:
            self._active_streams.discard(meeting_id)
            await context.abort(init_result.error_code, init_result.error_message or "")

        return init_result

    async def _init_stream_session_db(
        self: ServicerHost,
        meeting_id: str,
    ) -> _StreamSessionInit:
        """Initialize stream session using database persistence.

        Args:
            meeting_id: Meeting ID string.

        Returns:
            Stream session initialization result.
        """
        async with self._create_uow() as uow:
            meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
            if meeting is None:
                return _StreamSessionInit(
                    next_segment_id=0,
                    error_code=grpc.StatusCode.NOT_FOUND,
                    error_message=f"Meeting {meeting_id} not found",
                )

            dek, wrapped_dek, dek_updated = self._ensure_meeting_dek(meeting)
            recording_updated, error_msg = self._start_meeting_if_needed(meeting)

            if error_msg:
                return _StreamSessionInit(
                    next_segment_id=0,
                    error_code=grpc.StatusCode.INVALID_ARGUMENT,
                    error_message=error_msg,
                )

            if dek_updated or recording_updated:
                await uow.meetings.update(meeting)
                await uow.commit()

            next_segment_id = await uow.segments.get_next_segment_id(meeting.id)
        self._open_meeting_audio_writer(meeting_id, dek, wrapped_dek)
        self._init_streaming_state(meeting_id, next_segment_id)

        return _StreamSessionInit(next_segment_id=next_segment_id)

    def _init_stream_session_memory(
        self: ServicerHost,
        meeting_id: str,
    ) -> _StreamSessionInit:
        """Initialize stream session using in-memory store.

        Args:
            meeting_id: Meeting ID string.

        Returns:
            Stream session initialization result.
        """
        store = self._get_memory_store()
        meeting = store.get(meeting_id)
        if meeting is None:
            return _StreamSessionInit(
                next_segment_id=0,
                error_code=grpc.StatusCode.NOT_FOUND,
                error_message=f"Meeting {meeting_id} not found",
            )

        dek, wrapped_dek, dek_updated = self._ensure_meeting_dek(meeting)
        recording_updated, error_msg = self._start_meeting_if_needed(meeting)

        if error_msg:
            return _StreamSessionInit(
                next_segment_id=0,
                error_code=grpc.StatusCode.INVALID_ARGUMENT,
                error_message=error_msg,
            )

        if dek_updated or recording_updated:
            store.update(meeting)

        next_segment_id = meeting.next_segment_id
        self._open_meeting_audio_writer(meeting_id, dek, wrapped_dek)
        self._init_streaming_state(meeting_id, next_segment_id)

        return _StreamSessionInit(next_segment_id=next_segment_id)

    async def _process_stream_chunk(
        self: ServicerHost,
        meeting_id: str,
        chunk: noteflow_pb2.AudioChunk,
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
        """Process a single audio chunk from the stream.

        Args:
            meeting_id: Meeting ID string.
            chunk: Audio chunk from client.
            context: gRPC context for error handling.

        Yields:
            Transcript updates from processing.
        """
        try:
            sample_rate, channels = self._normalize_stream_format(
                meeting_id,
                chunk.sample_rate,
                chunk.channels,
            )
        except ValueError as e:
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))

        audio = self._decode_audio_chunk(chunk)
        if audio is None:
            return

        try:
            audio = self._convert_audio_format(audio, sample_rate, channels)
        except ValueError as e:
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))

        # Write to encrypted audio file
        self._write_audio_chunk_safe(meeting_id, audio)

        # VAD-driven segmentation
        async for update in self._process_audio_with_vad(meeting_id, audio):
            yield update

    def _normalize_stream_format(
        self: ServicerHost,
        meeting_id: str,
        sample_rate: int,
        channels: int,
    ) -> tuple[int, int]:
        """Validate and persist stream audio format for a meeting."""
        normalized_rate = sample_rate or self.DEFAULT_SAMPLE_RATE
        normalized_channels = channels or 1

        if normalized_rate not in self.SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                "Unsupported sample_rate "
                f"{normalized_rate}; supported: {self.SUPPORTED_SAMPLE_RATES}"
            )
        if normalized_channels < 1:
            raise ValueError("channels must be >= 1")

        existing = self._stream_formats.get(meeting_id)
        if existing and existing != (normalized_rate, normalized_channels):
            raise ValueError("Stream audio format cannot change mid-stream")

        self._stream_formats.setdefault(meeting_id, (normalized_rate, normalized_channels))
        return normalized_rate, normalized_channels

    def _convert_audio_format(
        self: ServicerHost,
        audio: NDArray[np.float32],
        sample_rate: int,
        channels: int,
    ) -> NDArray[np.float32]:
        """Downmix/resample audio to the server's expected format."""
        if channels > 1:
            if audio.size % channels != 0:
                raise ValueError("Audio buffer size is not divisible by channel count")
            audio = audio.reshape(-1, channels).mean(axis=1)

        if sample_rate != self.DEFAULT_SAMPLE_RATE:
            audio = self._resample_audio(audio, sample_rate, self.DEFAULT_SAMPLE_RATE)

        return audio

    @staticmethod
    def _resample_audio(
        audio: NDArray[np.float32],
        src_rate: int,
        dst_rate: int,
    ) -> NDArray[np.float32]:
        """Resample audio using linear interpolation."""
        if src_rate == dst_rate or audio.size == 0:
            return audio

        ratio = dst_rate / src_rate
        new_length = round(audio.shape[0] * ratio)
        if new_length <= 0:
            return np.array([], dtype=np.float32)

        old_indices = np.arange(audio.shape[0])
        new_indices = np.arange(new_length) / ratio
        return np.interp(new_indices, old_indices, audio).astype(np.float32)
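
    # Worked example (hypothetical numbers): resampling a 480-sample chunk
    # from 48000 Hz to 16000 Hz gives ratio 1/3 and round(480 * 1/3) = 160
    # output samples; each output index i samples the source at position
    # i / ratio, i.e. 3 * i, with linear interpolation between neighbors.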

    def _decode_audio_chunk(
        self: ServicerHost,
        chunk: noteflow_pb2.AudioChunk,
    ) -> NDArray[np.float32] | None:
        """Decode audio chunk from protobuf to numpy array."""
        if not chunk.audio_data:
            return None
        try:
            return np.frombuffer(chunk.audio_data, dtype=np.float32)
        except (ValueError, struct.error) as e:
            logger.warning("Failed to decode audio chunk: %s", e)
            return None

    def _write_audio_chunk_safe(
        self: ServicerHost,
        meeting_id: str,
        audio: NDArray[np.float32],
    ) -> None:
        """Write audio chunk to encrypted file, logging errors without raising.

        Args:
            meeting_id: Meeting ID string.
            audio: Audio samples to write.
        """
        if meeting_id not in self._audio_writers:
            return
        if meeting_id in self._audio_write_failed:
            return  # Already failed, skip to avoid log spam
        try:
            self._audio_writers[meeting_id].write_chunk(audio)
        except Exception as e:
            logger.error(
                "Audio write failed for meeting %s: %s. Recording may be incomplete.",
                meeting_id,
                e,
            )
            self._audio_write_failed.add(meeting_id)

    async def _process_audio_with_vad(
        self: ServicerHost,
        meeting_id: str,
        audio: NDArray[np.float32],
    ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
        """Process audio chunk through VAD and Segmenter.

        Args:
            meeting_id: Meeting identifier.
            audio: Audio samples (float32, mono).

        Yields:
            TranscriptUpdates for VAD events, partials, and finals.
        """
        vad = self._vad_instances.get(meeting_id)
        segmenter = self._segmenters.get(meeting_id)

        if vad is None or segmenter is None:
            return

        # Get VAD decision
        is_speech = vad.process_chunk(audio)

        # Streaming diarization (optional) - call mixin method if available
        if hasattr(self, "_process_streaming_diarization"):
            self._process_streaming_diarization(meeting_id, audio)

        # Emit VAD state change events
        was_speaking = self._was_speaking.get(meeting_id, False)
        if is_speech and not was_speaking:
            # Speech started
            yield create_vad_update(meeting_id, noteflow_pb2.UPDATE_TYPE_VAD_START)
            self._was_speaking[meeting_id] = True
        elif not is_speech and was_speaking:
            # Speech ended
            yield create_vad_update(meeting_id, noteflow_pb2.UPDATE_TYPE_VAD_END)
            self._was_speaking[meeting_id] = False

        # Buffer audio for partial transcription
        if is_speech:
            if meeting_id in self._partial_buffers:
                self._partial_buffers[meeting_id].append(audio.copy())

            # Check if we should emit a partial
            partial_update = await self._maybe_emit_partial(meeting_id)
            if partial_update is not None:
                yield partial_update

        # Process through segmenter
        for audio_segment in segmenter.process_audio(audio, is_speech):
            # Clear partial buffer when we get a final segment
            self._clear_partial_buffer(meeting_id)
            async for update in self._process_audio_segment(
                meeting_id,
                audio_segment.audio,
                audio_segment.start_time,
            ):
                yield update

    async def _maybe_emit_partial(
        self: ServicerHost,
        meeting_id: str,
    ) -> noteflow_pb2.TranscriptUpdate | None:
        """Check if it's time to emit a partial and generate if so.

        Args:
            meeting_id: Meeting identifier.

        Returns:
            TranscriptUpdate with partial text, or None if not time yet.
        """
        if self._asr_engine is None or not self._asr_engine.is_loaded:
            return None

        last_time = self._last_partial_time.get(meeting_id, 0)
        now = time.time()

        # Check if enough time has passed since last partial
        if now - last_time < self.PARTIAL_CADENCE_SECONDS:
            return None

        # Check if we have enough audio
        buffer = self._partial_buffers.get(meeting_id, [])
        if not buffer:
            return None

        # Concatenate buffered audio
        combined = np.concatenate(buffer)
        audio_seconds = len(combined) / self.DEFAULT_SAMPLE_RATE

        if audio_seconds < self.MIN_PARTIAL_AUDIO_SECONDS:
            return None

        # Run inference on buffered audio (async to avoid blocking event loop)
        results = await self._asr_engine.transcribe_async(combined)
        partial_text = " ".join(result.text for result in results)

        # Clear buffer after inference to keep partials incremental and bounded
        self._partial_buffers[meeting_id] = []

        # Only emit if text changed (debounce)
        last_text = self._last_partial_text.get(meeting_id, "")
        if partial_text and partial_text != last_text:
            self._last_partial_time[meeting_id] = now
            self._last_partial_text[meeting_id] = partial_text
            return noteflow_pb2.TranscriptUpdate(
                meeting_id=meeting_id,
                update_type=noteflow_pb2.UPDATE_TYPE_PARTIAL,
                partial_text=partial_text,
                server_timestamp=now,
            )

        self._last_partial_time[meeting_id] = now
        return None
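
    # Cadence example (hypothetical constant values): with
    # PARTIAL_CADENCE_SECONDS = 1.0 and MIN_PARTIAL_AUDIO_SECONDS = 0.5, a
    # partial is attempted at most once per second and only after at least
    # 0.5 s of speech has buffered; unchanged text is debounced so clients do
    # not repaint identical partials.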

    def _clear_partial_buffer(self: ServicerHost, meeting_id: str) -> None:
        """Clear the partial buffer and reset state after a final is emitted.

        Args:
            meeting_id: Meeting identifier.
        """
        if meeting_id in self._partial_buffers:
            self._partial_buffers[meeting_id] = []
        if meeting_id in self._last_partial_text:
            self._last_partial_text[meeting_id] = ""
        if meeting_id in self._last_partial_time:
            self._last_partial_time[meeting_id] = time.time()

    async def _flush_segmenter(
        self: ServicerHost,
        meeting_id: str,
    ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
        """Flush remaining audio from segmenter at stream end.

        Args:
            meeting_id: Meeting identifier.

        Yields:
            TranscriptUpdates for final segment.
        """
        segmenter = self._segmenters.get(meeting_id)
        if segmenter is None:
            return

        # Clear partial buffer since we're flushing to final
        self._clear_partial_buffer(meeting_id)

        final_segment = segmenter.flush()
        if final_segment is not None:
            async for update in self._process_audio_segment(
                meeting_id,
                final_segment.audio,
                final_segment.start_time,
            ):
                yield update

    async def _process_audio_segment(
        self: ServicerHost,
        meeting_id: str,
        audio: NDArray[np.float32],
        segment_start_time: float,
    ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
        """Process a complete audio segment through ASR.

        Args:
            meeting_id: Meeting identifier.
            audio: Complete audio segment.
            segment_start_time: Segment start time in stream seconds.

        Yields:
            TranscriptUpdates for transcribed segments.
        """
        if len(audio) == 0 or self._asr_engine is None:
            return

        if self._use_database():
            async with self._create_uow() as uow:
                meeting = await uow.meetings.get(MeetingId(UUID(meeting_id)))
                if meeting is None:
                    return

                results = await self._asr_engine.transcribe_async(audio)
                for result in results:
                    segment_id = self._next_segment_id(
                        meeting_id,
                        fallback=meeting.next_segment_id,
                    )
                    segment = create_segment_from_asr(
                        meeting.id,
                        segment_id,
                        result,
                        segment_start_time,
                    )
                    # Call diarization mixin method if available
                    if hasattr(self, "_maybe_assign_speaker"):
                        self._maybe_assign_speaker(meeting_id, segment)
                    meeting.add_segment(segment)
                    await uow.segments.add(meeting.id, segment)
                    await uow.commit()
                    yield segment_to_proto_update(meeting_id, segment)
        else:
            store = self._get_memory_store()
            meeting = store.get(meeting_id)
            if meeting is None:
                return
            results = await self._asr_engine.transcribe_async(audio)
            for result in results:
                segment_id = self._next_segment_id(
                    meeting_id,
                    fallback=meeting.next_segment_id,
                )
                segment = create_segment_from_asr(
                    meeting.id,
                    segment_id,
                    result,
                    segment_start_time,
                )
                # Call diarization mixin method if available
                if hasattr(self, "_maybe_assign_speaker"):
                    self._maybe_assign_speaker(meeting_id, segment)
                store.add_segment(meeting_id, segment)
                yield segment_to_proto_update(meeting_id, segment)
149
src/noteflow/grpc/_mixins/summarization.py
Normal file
@@ -0,0 +1,149 @@
"""Summarization mixin for gRPC service."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from uuid import UUID

import grpc.aio

from noteflow.domain.entities import Segment, Summary
from noteflow.domain.summarization import ProviderUnavailableError
from noteflow.domain.value_objects import MeetingId

from ..proto import noteflow_pb2
from .converters import summary_to_proto

if TYPE_CHECKING:
    from noteflow.application.services.summarization_service import SummarizationService

    from .protocols import ServicerHost

logger = logging.getLogger(__name__)


class SummarizationMixin:
    """Mixin providing summarization functionality.

    Requires host to implement ServicerHost protocol.
    """

    _summarization_service: SummarizationService | None

    async def GenerateSummary(
        self: ServicerHost,
        request: noteflow_pb2.GenerateSummaryRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Summary:
        """Generate meeting summary using SummarizationService with fallback."""
        if self._use_database():
            return await self._generate_summary_db(request, context)

        return await self._generate_summary_memory(request, context)

    async def _generate_summary_db(
        self: ServicerHost,
        request: noteflow_pb2.GenerateSummaryRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Summary:
        """Generate summary for a meeting stored in the database.

        The potentially slow summarization step is executed outside the UoW to
        avoid holding database connections while waiting on LLMs.
        """
        meeting_id = MeetingId(UUID(request.meeting_id))

        # 1) Load meeting, existing summary, and segments inside a short UoW
        async with self._create_uow() as uow:
            meeting = await uow.meetings.get(meeting_id)
            if meeting is None:
                await context.abort(
                    grpc.StatusCode.NOT_FOUND,
                    f"Meeting {request.meeting_id} not found",
                )

            existing = await uow.summaries.get_by_meeting(meeting.id)
            if existing and not request.force_regenerate:
                return summary_to_proto(existing)

            segments = list(await uow.segments.get_by_meeting(meeting.id))

        # 2) Run summarization outside DB transaction
        summary = await self._summarize_or_placeholder(meeting_id, segments)

        # 3) Persist in a fresh UoW
        async with self._create_uow() as uow:
            saved = await uow.summaries.save(summary)
            await uow.commit()

        return summary_to_proto(saved)

    async def _generate_summary_memory(
        self: ServicerHost,
        request: noteflow_pb2.GenerateSummaryRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.Summary:
        """Generate summary for meetings held in the in-memory store."""
        store = self._get_memory_store()
        meeting = store.get(request.meeting_id)
        if meeting is None:
            await context.abort(
                grpc.StatusCode.NOT_FOUND,
                f"Meeting {request.meeting_id} not found",
            )

        if meeting.summary and not request.force_regenerate:
            return summary_to_proto(meeting.summary)

        summary = await self._summarize_or_placeholder(meeting.id, meeting.segments)
        store.set_summary(request.meeting_id, summary)
        return summary_to_proto(summary)

    async def _summarize_or_placeholder(
        self: ServicerHost,
        meeting_id: MeetingId,
        segments: list[Segment],
    ) -> Summary:
        """Try to summarize via service, fallback to placeholder on failure."""
        if self._summarization_service is None:
            logger.warning("SummarizationService not configured; using placeholder summary")
            return self._generate_placeholder_summary(meeting_id, segments)

        try:
            result = await self._summarization_service.summarize(
                meeting_id=meeting_id,
                segments=segments,
            )
            logger.info(
                "Generated summary using %s (fallback=%s)",
                result.provider_used,
                result.fallback_used,
            )
            return result.summary
        except ProviderUnavailableError as exc:
            logger.warning("Summarization provider unavailable; using placeholder: %s", exc)
        except (TimeoutError, RuntimeError, ValueError) as exc:
            logger.exception(
                "Summarization failed (%s); using placeholder summary", type(exc).__name__
            )

        return self._generate_placeholder_summary(meeting_id, segments)

    def _generate_placeholder_summary(
        self: ServicerHost,
        meeting_id: MeetingId,
        segments: list[Segment],
    ) -> Summary:
        """Generate a lightweight placeholder summary when summarization fails."""
        full_text = " ".join(s.text for s in segments)
        executive = f"{full_text[:200]}..." if len(full_text) > 200 else full_text
        executive = executive or "No transcript available."

        return Summary(
            meeting_id=meeting_id,
            executive_summary=executive,
            key_points=[],
            action_items=[],
            model_version="placeholder-v0",
        )
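
    # Truncation example (illustrative): for a 500-character transcript the
    # placeholder's executive summary is the first 200 characters plus "...",
    # while an empty transcript yields "No transcript available."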
@@ -12,6 +12,8 @@ from typing import TYPE_CHECKING, Final

import grpc

from noteflow.config.constants import DEFAULT_SAMPLE_RATE

from .proto import noteflow_pb2, noteflow_pb2_grpc

if TYPE_CHECKING:
@@ -88,6 +90,35 @@
    file_extension: str


@dataclass
class DiarizationResult:
    """Result of speaker diarization refinement."""

    job_id: str
    status: str
    segments_updated: int
    speaker_ids: list[str]
    error_message: str = ""

    @property
    def success(self) -> bool:
        """Check if diarization succeeded."""
        return self.status == "completed" and not self.error_message

    @property
    def is_terminal(self) -> bool:
        """Check if job reached a terminal state."""
        return self.status in {"completed", "failed"}


@dataclass
class RenameSpeakerResult:
    """Result of speaker rename operation."""

    segments_updated: int
    success: bool


# Callback types
TranscriptCallback = Callable[[TranscriptSegment], None]
ConnectionCallback = Callable[[bool, str], None]
@@ -146,25 +177,7 @@
            True if connected successfully.
        """
        try:
-            self._channel = grpc.insecure_channel(
-                self._server_address,
-                options=[
-                    ("grpc.max_send_message_length", 100 * 1024 * 1024),
-                    ("grpc.max_receive_message_length", 100 * 1024 * 1024),
-                ],
-            )
-
-            # Wait for channel to be ready
-            grpc.channel_ready_future(self._channel).result(timeout=timeout)
-
-            self._stub = noteflow_pb2_grpc.NoteFlowServiceStub(self._channel)
-            self._connected = True
-
-            logger.info("Connected to server at %s", self._server_address)
-            self._notify_connection(True, "Connected")
-
-            return True
+            return self._extracted_from_connect_11(timeout)
        except grpc.FutureTimeoutError:
            logger.error("Connection timeout: %s", self._server_address)
            self._notify_connection(False, "Connection timeout")
@@ -174,6 +187,27 @@ class NoteFlowClient:
|
||||
self._notify_connection(False, str(e))
|
||||
return False
|
||||
|
||||
# TODO Rename this here and in `connect`
|
||||
def _extracted_from_connect_11(self, timeout):
|
||||
self._channel = grpc.insecure_channel(
|
||||
self._server_address,
|
||||
options=[
|
||||
("grpc.max_send_message_length", 100 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 100 * 1024 * 1024),
|
||||
],
|
||||
)
|
||||
|
||||
# Wait for channel to be ready
|
||||
grpc.channel_ready_future(self._channel).result(timeout=timeout)
|
||||
|
||||
self._stub = noteflow_pb2_grpc.NoteFlowServiceStub(self._channel)
|
||||
self._connected = True
|
||||
|
||||
logger.info("Connected to server at %s", self._server_address)
|
||||
self._notify_connection(True, "Connected")
|
||||
|
||||
return True
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the server."""
|
||||
self.stop_streaming()
|
||||
@@ -427,7 +461,7 @@ class NoteFlowClient:
|
||||
meeting_id=meeting_id,
|
||||
audio_data=audio.tobytes(),
|
||||
timestamp=timestamp,
|
||||
sample_rate=16000,
|
||||
sample_rate=DEFAULT_SAMPLE_RATE,
|
||||
channels=1,
|
||||
)
|
||||
except queue.Empty:
|
||||
@@ -788,3 +822,110 @@ class NoteFlowClient:
|
||||
"html": noteflow_pb2.EXPORT_FORMAT_HTML,
|
||||
}
|
||||
return format_map.get(format_name.lower(), noteflow_pb2.EXPORT_FORMAT_MARKDOWN)
|
||||
|
||||
@staticmethod
|
||||
def _job_status_to_str(status: int) -> str:
|
||||
"""Convert job status enum to string."""
|
||||
# JobStatus enum values extend int, so they work as dictionary keys
|
||||
status_map = {
|
||||
noteflow_pb2.JOB_STATUS_UNSPECIFIED: "unspecified",
|
||||
noteflow_pb2.JOB_STATUS_QUEUED: "queued",
|
||||
noteflow_pb2.JOB_STATUS_RUNNING: "running",
|
||||
noteflow_pb2.JOB_STATUS_COMPLETED: "completed",
|
||||
noteflow_pb2.JOB_STATUS_FAILED: "failed",
|
||||
}
|
||||
return status_map.get(status, "unspecified") # type: ignore[arg-type]
|
||||
|
||||
# =========================================================================
|
||||
# Speaker Diarization Methods
|
||||
# =========================================================================
|
||||
|
||||
def refine_speaker_diarization(
|
||||
self,
|
||||
meeting_id: str,
|
||||
num_speakers: int | None = None,
|
||||
) -> DiarizationResult | None:
|
||||
"""Run post-meeting speaker diarization refinement.
|
||||
|
||||
Requests the server to run offline diarization on the meeting audio
|
||||
as a background job and update segment speaker assignments.
|
||||
|
||||
Args:
|
||||
meeting_id: Meeting ID.
|
||||
num_speakers: Optional known number of speakers (auto-detect if None).
|
||||
|
||||
Returns:
|
||||
DiarizationResult with job status or None if request fails.
|
||||
"""
|
||||
if not self._stub:
|
||||
return None
|
||||
|
||||
try:
|
||||
request = noteflow_pb2.RefineSpeakerDiarizationRequest(
|
||||
meeting_id=meeting_id,
|
||||
num_speakers=num_speakers or 0,
|
||||
)
|
||||
response = self._stub.RefineSpeakerDiarization(request)
|
||||
return DiarizationResult(
|
||||
job_id=response.job_id,
|
||||
status=self._job_status_to_str(response.status),
|
||||
segments_updated=response.segments_updated,
|
||||
speaker_ids=list(response.speaker_ids),
|
||||
error_message=response.error_message,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.error("Failed to refine speaker diarization: %s", e)
|
||||
return None
|
||||
|
||||
def get_diarization_job_status(self, job_id: str) -> DiarizationResult | None:
|
||||
"""Get status for a diarization background job."""
|
||||
if not self._stub:
|
||||
return None
|
||||
|
||||
try:
|
||||
request = noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id)
|
||||
response = self._stub.GetDiarizationJobStatus(request)
|
||||
return DiarizationResult(
|
||||
job_id=response.job_id,
|
||||
status=self._job_status_to_str(response.status),
|
||||
segments_updated=response.segments_updated,
|
||||
speaker_ids=list(response.speaker_ids),
|
||||
error_message=response.error_message,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.error("Failed to get diarization job status: %s", e)
|
||||
return None
|
||||
|
||||
def rename_speaker(
|
||||
self,
|
||||
meeting_id: str,
|
||||
old_speaker_id: str,
|
||||
new_speaker_name: str,
|
||||
) -> RenameSpeakerResult | None:
|
||||
"""Rename a speaker in all segments of a meeting.
|
||||
|
||||
Args:
|
||||
meeting_id: Meeting ID.
|
||||
old_speaker_id: Current speaker ID (e.g., "SPEAKER_00").
|
||||
new_speaker_name: New speaker name (e.g., "Alice").
|
||||
|
||||
Returns:
|
||||
RenameSpeakerResult or None if request fails.
|
||||
"""
|
||||
if not self._stub:
|
||||
return None
|
||||
|
||||
try:
|
||||
request = noteflow_pb2.RenameSpeakerRequest(
|
||||
meeting_id=meeting_id,
|
||||
old_speaker_id=old_speaker_id,
|
||||
new_speaker_name=new_speaker_name,
|
||||
)
|
||||
response = self._stub.RenameSpeaker(request)
|
||||
return RenameSpeakerResult(
|
||||
segments_updated=response.segments_updated,
|
||||
success=response.success,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.error("Failed to rename speaker: %s", e)
|
||||
return None
|
||||
|
||||
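Reviewer note: these client methods are designed for poll-until-terminal usage via DiarizationResult.is_terminal. A minimal sketch (constructor and argument names are assumptions based on this diff):

import time

client = NoteFlowClient("localhost:50051")  # address/ctor are assumptions
if client.connect(timeout=5.0):
    result = client.refine_speaker_diarization("meeting-123", num_speakers=2)
    while result is not None and not result.is_terminal:
        time.sleep(1.0)
        result = client.get_diarization_job_status(result.job_id)
    if result is not None and result.success:
        client.rename_speaker("meeting-123", "SPEAKER_00", "Alice")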
@@ -14,6 +14,7 @@ from noteflow.domain.value_objects import MeetingState

if TYPE_CHECKING:
    from collections.abc import Sequence
    from datetime import datetime


class MeetingStore:
@@ -142,6 +143,57 @@ class MeetingStore:
            meeting.summary = summary
            return meeting

    def update_state(self, meeting_id: str, state: MeetingState) -> bool:
        """Atomically update meeting state.

        Args:
            meeting_id: Meeting ID.
            state: New state.

        Returns:
            True if updated, False if meeting not found.
        """
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None:
                return False
            meeting.state = state
            return True

    def update_title(self, meeting_id: str, title: str) -> bool:
        """Atomically update meeting title.

        Args:
            meeting_id: Meeting ID.
            title: New title.

        Returns:
            True if updated, False if meeting not found.
        """
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None:
                return False
            meeting.title = title
            return True

    def update_end_time(self, meeting_id: str, end_time: datetime) -> bool:
        """Atomically update meeting end time.

        Args:
            meeting_id: Meeting ID.
            end_time: New end time.

        Returns:
            True if updated, False if meeting not found.
        """
        with self._lock:
            meeting = self._meetings.get(meeting_id)
            if meeting is None:
                return False
            meeting.end_time = end_time
            return True

    def delete(self, meeting_id: str) -> bool:
        """Delete a meeting.
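Reviewer note: each update_* helper holds the store lock for the whole read-modify-write, so concurrent callers cannot interleave. A usage sketch (construction details and the MeetingState member name are assumptions):

from datetime import datetime, timezone

store = MeetingStore()
# Assumes meeting "m1" was created elsewhere; STOPPED is a hypothetical member.
if not store.update_state("m1", MeetingState.STOPPED):
    logger.warning("Meeting m1 vanished before state update")
store.update_title("m1", "Weekly sync")
store.update_end_time("m1", datetime.now(timezone.utc))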
@@ -33,6 +33,11 @@ service NoteFlowService {
  // Export functionality
  rpc ExportTranscript(ExportTranscriptRequest) returns (ExportTranscriptResponse);

  // Speaker diarization
  rpc RefineSpeakerDiarization(RefineSpeakerDiarizationRequest) returns (RefineSpeakerDiarizationResponse);
  rpc RenameSpeaker(RenameSpeakerRequest) returns (RenameSpeakerResponse);
  rpc GetDiarizationJobStatus(GetDiarizationJobStatusRequest) returns (DiarizationJobStatus);

  // Server health and capabilities
  rpc GetServerInfo(ServerInfoRequest) returns (ServerInfo);
}
@@ -438,6 +443,14 @@ enum ExportFormat {
  EXPORT_FORMAT_HTML = 2;
}

enum JobStatus {
  JOB_STATUS_UNSPECIFIED = 0;
  JOB_STATUS_QUEUED = 1;
  JOB_STATUS_RUNNING = 2;
  JOB_STATUS_COMPLETED = 3;
  JOB_STATUS_FAILED = 4;
}

message ExportTranscriptRequest {
  // Meeting ID to export
  string meeting_id = 1;
@@ -456,3 +469,73 @@ message ExportTranscriptResponse {
  // Suggested file extension
  string file_extension = 3;
}

// =============================================================================
// Speaker Diarization Messages
// =============================================================================

message RefineSpeakerDiarizationRequest {
  // Meeting ID to run diarization on
  string meeting_id = 1;

  // Optional known number of speakers (auto-detect if not set or 0)
  int32 num_speakers = 2;
}

message RefineSpeakerDiarizationResponse {
  // Number of segments updated with speaker labels
  int32 segments_updated = 1;

  // Distinct speaker IDs found
  repeated string speaker_ids = 2;

  // Error message if diarization failed
  string error_message = 3;

  // Background job identifier (empty if request failed)
  string job_id = 4;

  // Current job status
  JobStatus status = 5;
}

message RenameSpeakerRequest {
  // Meeting ID
  string meeting_id = 1;

  // Original speaker ID (e.g., "SPEAKER_00")
  string old_speaker_id = 2;

  // New speaker name (e.g., "Alice")
  string new_speaker_name = 3;
}

message RenameSpeakerResponse {
  // Number of segments updated
  int32 segments_updated = 1;

  // Success flag
  bool success = 2;
}

message GetDiarizationJobStatusRequest {
  // Job ID returned by RefineSpeakerDiarization
  string job_id = 1;
}

message DiarizationJobStatus {
  // Job ID
  string job_id = 1;

  // Current status
  JobStatus status = 2;

  // Number of segments updated (when completed)
  int32 segments_updated = 3;

  // Distinct speaker IDs found (when completed)
  repeated string speaker_ids = 4;

  // Error message if failed
  string error_message = 5;
}
File diff suppressed because one or more lines are too long
@@ -36,6 +36,28 @@ class Priority(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
    PRIORITY_LOW: _ClassVar[Priority]
    PRIORITY_MEDIUM: _ClassVar[Priority]
    PRIORITY_HIGH: _ClassVar[Priority]

class AnnotationType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
    __slots__ = ()
    ANNOTATION_TYPE_UNSPECIFIED: _ClassVar[AnnotationType]
    ANNOTATION_TYPE_ACTION_ITEM: _ClassVar[AnnotationType]
    ANNOTATION_TYPE_DECISION: _ClassVar[AnnotationType]
    ANNOTATION_TYPE_NOTE: _ClassVar[AnnotationType]
    ANNOTATION_TYPE_RISK: _ClassVar[AnnotationType]

class ExportFormat(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
    __slots__ = ()
    EXPORT_FORMAT_UNSPECIFIED: _ClassVar[ExportFormat]
    EXPORT_FORMAT_MARKDOWN: _ClassVar[ExportFormat]
    EXPORT_FORMAT_HTML: _ClassVar[ExportFormat]

class JobStatus(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
    __slots__ = ()
    JOB_STATUS_UNSPECIFIED: _ClassVar[JobStatus]
    JOB_STATUS_QUEUED: _ClassVar[JobStatus]
    JOB_STATUS_RUNNING: _ClassVar[JobStatus]
    JOB_STATUS_COMPLETED: _ClassVar[JobStatus]
    JOB_STATUS_FAILED: _ClassVar[JobStatus]
UPDATE_TYPE_UNSPECIFIED: UpdateType
UPDATE_TYPE_PARTIAL: UpdateType
UPDATE_TYPE_FINAL: UpdateType
@@ -54,6 +76,19 @@ PRIORITY_UNSPECIFIED: Priority
PRIORITY_LOW: Priority
PRIORITY_MEDIUM: Priority
PRIORITY_HIGH: Priority
ANNOTATION_TYPE_UNSPECIFIED: AnnotationType
ANNOTATION_TYPE_ACTION_ITEM: AnnotationType
ANNOTATION_TYPE_DECISION: AnnotationType
ANNOTATION_TYPE_NOTE: AnnotationType
ANNOTATION_TYPE_RISK: AnnotationType
EXPORT_FORMAT_UNSPECIFIED: ExportFormat
EXPORT_FORMAT_MARKDOWN: ExportFormat
EXPORT_FORMAT_HTML: ExportFormat
JOB_STATUS_UNSPECIFIED: JobStatus
JOB_STATUS_QUEUED: JobStatus
JOB_STATUS_RUNNING: JobStatus
JOB_STATUS_COMPLETED: JobStatus
JOB_STATUS_FAILED: JobStatus

class AudioChunk(_message.Message):
    __slots__ = ("meeting_id", "audio_data", "timestamp", "sample_rate", "channels")
@@ -290,3 +325,167 @@ class ServerInfo(_message.Message):
    diarization_enabled: bool
    diarization_ready: bool
    def __init__(self, version: _Optional[str] = ..., asr_model: _Optional[str] = ..., asr_ready: bool = ..., supported_sample_rates: _Optional[_Iterable[int]] = ..., max_chunk_size: _Optional[int] = ..., uptime_seconds: _Optional[float] = ..., active_meetings: _Optional[int] = ..., diarization_enabled: bool = ..., diarization_ready: bool = ...) -> None: ...

class Annotation(_message.Message):
    __slots__ = ("id", "meeting_id", "annotation_type", "text", "start_time", "end_time", "segment_ids", "created_at")
    ID_FIELD_NUMBER: _ClassVar[int]
    MEETING_ID_FIELD_NUMBER: _ClassVar[int]
    ANNOTATION_TYPE_FIELD_NUMBER: _ClassVar[int]
    TEXT_FIELD_NUMBER: _ClassVar[int]
    START_TIME_FIELD_NUMBER: _ClassVar[int]
    END_TIME_FIELD_NUMBER: _ClassVar[int]
    SEGMENT_IDS_FIELD_NUMBER: _ClassVar[int]
    CREATED_AT_FIELD_NUMBER: _ClassVar[int]
    id: str
    meeting_id: str
    annotation_type: AnnotationType
    text: str
    start_time: float
    end_time: float
    segment_ids: _containers.RepeatedScalarFieldContainer[int]
    created_at: float
    def __init__(self, id: _Optional[str] = ..., meeting_id: _Optional[str] = ..., annotation_type: _Optional[_Union[AnnotationType, str]] = ..., text: _Optional[str] = ..., start_time: _Optional[float] = ..., end_time: _Optional[float] = ..., segment_ids: _Optional[_Iterable[int]] = ..., created_at: _Optional[float] = ...) -> None: ...

class AddAnnotationRequest(_message.Message):
    __slots__ = ("meeting_id", "annotation_type", "text", "start_time", "end_time", "segment_ids")
    MEETING_ID_FIELD_NUMBER: _ClassVar[int]
    ANNOTATION_TYPE_FIELD_NUMBER: _ClassVar[int]
    TEXT_FIELD_NUMBER: _ClassVar[int]
    START_TIME_FIELD_NUMBER: _ClassVar[int]
    END_TIME_FIELD_NUMBER: _ClassVar[int]
    SEGMENT_IDS_FIELD_NUMBER: _ClassVar[int]
    meeting_id: str
    annotation_type: AnnotationType
    text: str
    start_time: float
    end_time: float
    segment_ids: _containers.RepeatedScalarFieldContainer[int]
    def __init__(self, meeting_id: _Optional[str] = ..., annotation_type: _Optional[_Union[AnnotationType, str]] = ..., text: _Optional[str] = ..., start_time: _Optional[float] = ..., end_time: _Optional[float] = ..., segment_ids: _Optional[_Iterable[int]] = ...) -> None: ...

class GetAnnotationRequest(_message.Message):
    __slots__ = ("annotation_id",)
    ANNOTATION_ID_FIELD_NUMBER: _ClassVar[int]
    annotation_id: str
    def __init__(self, annotation_id: _Optional[str] = ...) -> None: ...

class ListAnnotationsRequest(_message.Message):
    __slots__ = ("meeting_id", "start_time", "end_time")
    MEETING_ID_FIELD_NUMBER: _ClassVar[int]
    START_TIME_FIELD_NUMBER: _ClassVar[int]
    END_TIME_FIELD_NUMBER: _ClassVar[int]
    meeting_id: str
    start_time: float
    end_time: float
    def __init__(self, meeting_id: _Optional[str] = ..., start_time: _Optional[float] = ..., end_time: _Optional[float] = ...) -> None: ...

class ListAnnotationsResponse(_message.Message):
    __slots__ = ("annotations",)
    ANNOTATIONS_FIELD_NUMBER: _ClassVar[int]
    annotations: _containers.RepeatedCompositeFieldContainer[Annotation]
    def __init__(self, annotations: _Optional[_Iterable[_Union[Annotation, _Mapping]]] = ...) -> None: ...

class UpdateAnnotationRequest(_message.Message):
    __slots__ = ("annotation_id", "annotation_type", "text", "start_time", "end_time", "segment_ids")
    ANNOTATION_ID_FIELD_NUMBER: _ClassVar[int]
    ANNOTATION_TYPE_FIELD_NUMBER: _ClassVar[int]
    TEXT_FIELD_NUMBER: _ClassVar[int]
    START_TIME_FIELD_NUMBER: _ClassVar[int]
    END_TIME_FIELD_NUMBER: _ClassVar[int]
    SEGMENT_IDS_FIELD_NUMBER: _ClassVar[int]
    annotation_id: str
    annotation_type: AnnotationType
    text: str
    start_time: float
    end_time: float
    segment_ids: _containers.RepeatedScalarFieldContainer[int]
    def __init__(self, annotation_id: _Optional[str] = ..., annotation_type: _Optional[_Union[AnnotationType, str]] = ..., text: _Optional[str] = ..., start_time: _Optional[float] = ..., end_time: _Optional[float] = ..., segment_ids: _Optional[_Iterable[int]] = ...) -> None: ...

class DeleteAnnotationRequest(_message.Message):
    __slots__ = ("annotation_id",)
    ANNOTATION_ID_FIELD_NUMBER: _ClassVar[int]
    annotation_id: str
    def __init__(self, annotation_id: _Optional[str] = ...) -> None: ...

class DeleteAnnotationResponse(_message.Message):
    __slots__ = ("success",)
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    success: bool
    def __init__(self, success: bool = ...) -> None: ...

class ExportTranscriptRequest(_message.Message):
    __slots__ = ("meeting_id", "format")
    MEETING_ID_FIELD_NUMBER: _ClassVar[int]
    FORMAT_FIELD_NUMBER: _ClassVar[int]
    meeting_id: str
    format: ExportFormat
    def __init__(self, meeting_id: _Optional[str] = ..., format: _Optional[_Union[ExportFormat, str]] = ...) -> None: ...

class ExportTranscriptResponse(_message.Message):
    __slots__ = ("content", "format_name", "file_extension")
    CONTENT_FIELD_NUMBER: _ClassVar[int]
    FORMAT_NAME_FIELD_NUMBER: _ClassVar[int]
    FILE_EXTENSION_FIELD_NUMBER: _ClassVar[int]
    content: str
    format_name: str
    file_extension: str
    def __init__(self, content: _Optional[str] = ..., format_name: _Optional[str] = ..., file_extension: _Optional[str] = ...) -> None: ...

class RefineSpeakerDiarizationRequest(_message.Message):
    __slots__ = ("meeting_id", "num_speakers")
    MEETING_ID_FIELD_NUMBER: _ClassVar[int]
    NUM_SPEAKERS_FIELD_NUMBER: _ClassVar[int]
    meeting_id: str
    num_speakers: int
    def __init__(self, meeting_id: _Optional[str] = ..., num_speakers: _Optional[int] = ...) -> None: ...

class RefineSpeakerDiarizationResponse(_message.Message):
    __slots__ = ("segments_updated", "speaker_ids", "error_message", "job_id", "status")
    SEGMENTS_UPDATED_FIELD_NUMBER: _ClassVar[int]
    SPEAKER_IDS_FIELD_NUMBER: _ClassVar[int]
    ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
    JOB_ID_FIELD_NUMBER: _ClassVar[int]
    STATUS_FIELD_NUMBER: _ClassVar[int]
    segments_updated: int
    speaker_ids: _containers.RepeatedScalarFieldContainer[str]
    error_message: str
    job_id: str
    status: JobStatus
    def __init__(self, segments_updated: _Optional[int] = ..., speaker_ids: _Optional[_Iterable[str]] = ..., error_message: _Optional[str] = ..., job_id: _Optional[str] = ..., status: _Optional[JobStatus] = ...) -> None: ...

class RenameSpeakerRequest(_message.Message):
    __slots__ = ("meeting_id", "old_speaker_id", "new_speaker_name")
    MEETING_ID_FIELD_NUMBER: _ClassVar[int]
    OLD_SPEAKER_ID_FIELD_NUMBER: _ClassVar[int]
    NEW_SPEAKER_NAME_FIELD_NUMBER: _ClassVar[int]
    meeting_id: str
    old_speaker_id: str
    new_speaker_name: str
    def __init__(self, meeting_id: _Optional[str] = ..., old_speaker_id: _Optional[str] = ..., new_speaker_name: _Optional[str] = ...) -> None: ...

class RenameSpeakerResponse(_message.Message):
    __slots__ = ("segments_updated", "success")
    SEGMENTS_UPDATED_FIELD_NUMBER: _ClassVar[int]
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    segments_updated: int
    success: bool
    def __init__(self, segments_updated: _Optional[int] = ..., success: bool = ...) -> None: ...

class GetDiarizationJobStatusRequest(_message.Message):
    __slots__ = ("job_id",)
    JOB_ID_FIELD_NUMBER: _ClassVar[int]
    job_id: str
    def __init__(self, job_id: _Optional[str] = ...) -> None: ...

class DiarizationJobStatus(_message.Message):
    __slots__ = ("job_id", "status", "segments_updated", "speaker_ids", "error_message")
    JOB_ID_FIELD_NUMBER: _ClassVar[int]
    STATUS_FIELD_NUMBER: _ClassVar[int]
    SEGMENTS_UPDATED_FIELD_NUMBER: _ClassVar[int]
    SPEAKER_IDS_FIELD_NUMBER: _ClassVar[int]
    ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
    job_id: str
    status: JobStatus
    segments_updated: int
    speaker_ids: _containers.RepeatedScalarFieldContainer[str]
    error_message: str
    def __init__(self, job_id: _Optional[str] = ..., status: _Optional[JobStatus] = ..., segments_updated: _Optional[int] = ..., speaker_ids: _Optional[_Iterable[str]] = ..., error_message: _Optional[str] = ...) -> None: ...
@@ -103,6 +103,21 @@ class NoteFlowServiceStub(object):
                request_serializer=noteflow__pb2.ExportTranscriptRequest.SerializeToString,
                response_deserializer=noteflow__pb2.ExportTranscriptResponse.FromString,
                _registered_method=True)
        self.RefineSpeakerDiarization = channel.unary_unary(
                '/noteflow.NoteFlowService/RefineSpeakerDiarization',
                request_serializer=noteflow__pb2.RefineSpeakerDiarizationRequest.SerializeToString,
                response_deserializer=noteflow__pb2.RefineSpeakerDiarizationResponse.FromString,
                _registered_method=True)
        self.RenameSpeaker = channel.unary_unary(
                '/noteflow.NoteFlowService/RenameSpeaker',
                request_serializer=noteflow__pb2.RenameSpeakerRequest.SerializeToString,
                response_deserializer=noteflow__pb2.RenameSpeakerResponse.FromString,
                _registered_method=True)
        self.GetDiarizationJobStatus = channel.unary_unary(
                '/noteflow.NoteFlowService/GetDiarizationJobStatus',
                request_serializer=noteflow__pb2.GetDiarizationJobStatusRequest.SerializeToString,
                response_deserializer=noteflow__pb2.DiarizationJobStatus.FromString,
                _registered_method=True)
        self.GetServerInfo = channel.unary_unary(
                '/noteflow.NoteFlowService/GetServerInfo',
                request_serializer=noteflow__pb2.ServerInfoRequest.SerializeToString,
@@ -200,6 +215,25 @@ class NoteFlowServiceServicer(object):
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def RefineSpeakerDiarization(self, request, context):
        """Speaker diarization
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def RenameSpeaker(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetDiarizationJobStatus(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetServerInfo(self, request, context):
        """Server health and capabilities
        """
@@ -275,6 +309,21 @@ def add_NoteFlowServiceServicer_to_server(servicer, server):
            request_deserializer=noteflow__pb2.ExportTranscriptRequest.FromString,
            response_serializer=noteflow__pb2.ExportTranscriptResponse.SerializeToString,
    ),
    'RefineSpeakerDiarization': grpc.unary_unary_rpc_method_handler(
            servicer.RefineSpeakerDiarization,
            request_deserializer=noteflow__pb2.RefineSpeakerDiarizationRequest.FromString,
            response_serializer=noteflow__pb2.RefineSpeakerDiarizationResponse.SerializeToString,
    ),
    'RenameSpeaker': grpc.unary_unary_rpc_method_handler(
            servicer.RenameSpeaker,
            request_deserializer=noteflow__pb2.RenameSpeakerRequest.FromString,
            response_serializer=noteflow__pb2.RenameSpeakerResponse.SerializeToString,
    ),
    'GetDiarizationJobStatus': grpc.unary_unary_rpc_method_handler(
            servicer.GetDiarizationJobStatus,
            request_deserializer=noteflow__pb2.GetDiarizationJobStatusRequest.FromString,
            response_serializer=noteflow__pb2.DiarizationJobStatus.SerializeToString,
    ),
    'GetServerInfo': grpc.unary_unary_rpc_method_handler(
            servicer.GetServerInfo,
            request_deserializer=noteflow__pb2.ServerInfoRequest.FromString,
@@ -646,6 +695,87 @@ class NoteFlowService(object):
            metadata,
            _registered_method=True)

    @staticmethod
    def RefineSpeakerDiarization(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/noteflow.NoteFlowService/RefineSpeakerDiarization',
            noteflow__pb2.RefineSpeakerDiarizationRequest.SerializeToString,
            noteflow__pb2.RefineSpeakerDiarizationResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def RenameSpeaker(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/noteflow.NoteFlowService/RenameSpeaker',
            noteflow__pb2.RenameSpeakerRequest.SerializeToString,
            noteflow__pb2.RenameSpeakerResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetDiarizationJobStatus(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/noteflow.NoteFlowService/GetDiarizationJobStatus',
            noteflow__pb2.GetDiarizationJobStatusRequest.SerializeToString,
            noteflow__pb2.DiarizationJobStatus.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetServerInfo(request,
            target,
@@ -7,15 +7,17 @@ import asyncio
import logging
import signal
import time
-from typing import TYPE_CHECKING, Final
+from typing import TYPE_CHECKING, Any, Final

import grpc.aio
from pydantic import ValidationError

from noteflow.application.services import RecoveryService
from noteflow.application.services.summarization_service import SummarizationService
from noteflow.config.settings import get_settings
from noteflow.infrastructure.asr import FasterWhisperEngine
from noteflow.infrastructure.asr.engine import VALID_MODEL_SIZES
from noteflow.infrastructure.diarization import DiarizationEngine
from noteflow.infrastructure.persistence.database import create_async_session_factory
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
from noteflow.infrastructure.summarization import create_summarization_service
@@ -43,6 +45,8 @@ class NoteFlowServer:
        asr_compute_type: str = "int8",
        session_factory: async_sessionmaker[AsyncSession] | None = None,
        summarization_service: SummarizationService | None = None,
        diarization_engine: DiarizationEngine | None = None,
        diarization_refinement_enabled: bool = True,
    ) -> None:
        """Initialize the server.

@@ -53,6 +57,8 @@ class NoteFlowServer:
            asr_compute_type: ASR compute type.
            session_factory: Optional async session factory for database.
            summarization_service: Optional summarization service for generating summaries.
            diarization_engine: Optional diarization engine for speaker identification.
            diarization_refinement_enabled: Whether to allow diarization refinement RPCs.
        """
        self._port = port
        self._asr_model = asr_model
@@ -60,6 +66,8 @@ class NoteFlowServer:
        self._asr_compute_type = asr_compute_type
        self._session_factory = session_factory
        self._summarization_service = summarization_service
        self._diarization_engine = diarization_engine
        self._diarization_refinement_enabled = diarization_refinement_enabled
        self._server: grpc.aio.Server | None = None
        self._servicer: NoteFlowServicer | None = None

@@ -90,11 +98,13 @@ class NoteFlowServer:
            self._summarization_service = create_summarization_service()
            logger.info("Summarization service initialized (default factory)")

-       # Create servicer with session factory and summarization service
+       # Create servicer with session factory, summarization, and diarization
        self._servicer = NoteFlowServicer(
            asr_engine=asr_engine,
            session_factory=self._session_factory,
            summarization_service=self._summarization_service,
            diarization_engine=self._diarization_engine,
            diarization_refinement_enabled=self._diarization_refinement_enabled,
        )

        # Create async gRPC server
@@ -142,6 +152,13 @@ async def run_server(
    asr_device: str,
    asr_compute_type: str,
    database_url: str | None = None,
    diarization_enabled: bool = False,
    diarization_hf_token: str | None = None,
    diarization_device: str = "auto",
    diarization_streaming_latency: float | None = None,
    diarization_min_speakers: int | None = None,
    diarization_max_speakers: int | None = None,
    diarization_refinement_enabled: bool = True,
) -> None:
    """Run the async gRPC server.

@@ -151,6 +168,13 @@ async def run_server(
        asr_device: Device for ASR.
        asr_compute_type: ASR compute type.
        database_url: Optional database URL for persistence.
        diarization_enabled: Whether to enable speaker diarization.
        diarization_hf_token: HuggingFace token for pyannote models.
        diarization_device: Device for diarization ("auto", "cpu", "cuda", "mps").
        diarization_streaming_latency: Streaming diarization latency in seconds.
        diarization_min_speakers: Minimum expected speakers for offline diarization.
        diarization_max_speakers: Maximum expected speakers for offline diarization.
        diarization_refinement_enabled: Whether to allow diarization refinement RPCs.
    """
    # Create session factory if database URL provided
    session_factory = None
@@ -173,6 +197,29 @@ async def run_server(
    summarization_service = create_summarization_service()
    logger.info("Summarization service initialized")

    # Create diarization engine if enabled
    diarization_engine: DiarizationEngine | None = None
    if diarization_enabled:
        if not diarization_hf_token:
            logger.warning(
                "Diarization enabled but no HuggingFace token provided. "
                "Set NOTEFLOW_DIARIZATION_HF_TOKEN or --diarization-hf-token."
            )
        else:
            logger.info("Initializing diarization engine on %s...", diarization_device)
            diarization_kwargs: dict[str, Any] = {
                "device": diarization_device,
                "hf_token": diarization_hf_token,
            }
            if diarization_streaming_latency is not None:
                diarization_kwargs["streaming_latency"] = diarization_streaming_latency
            if diarization_min_speakers is not None:
                diarization_kwargs["min_speakers"] = diarization_min_speakers
            if diarization_max_speakers is not None:
                diarization_kwargs["max_speakers"] = diarization_max_speakers
            diarization_engine = DiarizationEngine(**diarization_kwargs)
            logger.info("Diarization engine initialized (models loaded on demand)")

    server = NoteFlowServer(
        port=port,
        asr_model=asr_model,
@@ -180,6 +227,8 @@ async def run_server(
        asr_compute_type=asr_compute_type,
        session_factory=session_factory,
        summarization_service=summarization_service,
        diarization_engine=diarization_engine,
        diarization_refinement_enabled=diarization_refinement_enabled,
    )

    # Set up graceful shutdown
@@ -201,6 +250,10 @@ async def run_server(
        print("Database: Connected")
    else:
        print("Database: Not configured (in-memory mode)")
    if diarization_engine:
        print(f"Diarization: Enabled ({diarization_device})")
    else:
        print("Diarization: Disabled")
    print("Press Ctrl+C to stop\n")

    # Wait for shutdown signal or server termination
@@ -255,6 +308,24 @@ def main() -> None:
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--diarization",
        action="store_true",
        help="Enable speaker diarization (requires pyannote.audio)",
    )
    parser.add_argument(
        "--diarization-hf-token",
        type=str,
        default=None,
        help="HuggingFace token for pyannote models (overrides NOTEFLOW_DIARIZATION_HF_TOKEN)",
    )
    parser.add_argument(
        "--diarization-device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda", "mps"],
        help="Device for diarization (default: auto)",
    )
    args = parser.parse_args()

    # Configure logging
@@ -264,14 +335,39 @@ def main() -> None:
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    # Get settings
    try:
        settings = get_settings()
    except (OSError, ValueError, ValidationError) as exc:
        logger.warning("Failed to load settings: %s", exc)
        settings = None

    # Get database URL from args or settings
    database_url = args.database_url
+   if not database_url and settings:
+       database_url = str(settings.database_url)
    if not database_url:
-       try:
-           settings = get_settings()
-           database_url = str(settings.database_url)
-       except Exception:
-           logger.warning("No database URL configured, running in-memory mode")
+       logger.warning("No database URL configured, running in-memory mode")

    # Get diarization config from args or settings
    diarization_enabled = args.diarization
    diarization_hf_token = args.diarization_hf_token
    diarization_device = args.diarization_device
    diarization_streaming_latency: float | None = None
    diarization_min_speakers: int | None = None
    diarization_max_speakers: int | None = None
    diarization_refinement_enabled = True
    if settings and not diarization_enabled:
        diarization_enabled = settings.diarization_enabled
    if settings and not diarization_hf_token:
        diarization_hf_token = settings.diarization_hf_token
    if settings and diarization_device == "auto":
        diarization_device = settings.diarization_device
    if settings:
        diarization_streaming_latency = settings.diarization_streaming_latency
        diarization_min_speakers = settings.diarization_min_speakers
        diarization_max_speakers = settings.diarization_max_speakers
        diarization_refinement_enabled = settings.diarization_refinement_enabled

    # Run server
    asyncio.run(
@@ -281,6 +377,13 @@ def main() -> None:
            asr_device=args.device,
            asr_compute_type=args.compute_type,
            database_url=database_url,
            diarization_enabled=diarization_enabled,
            diarization_hf_token=diarization_hf_token,
            diarization_device=diarization_device,
            diarization_streaming_latency=diarization_streaming_latency,
            diarization_min_speakers=diarization_min_speakers,
            diarization_max_speakers=diarization_max_speakers,
            diarization_refinement_enabled=diarization_refinement_enabled,
        )
    )
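Reviewer note: the resolution order implemented above is CLI flag, then settings, then default. A minimal headless invocation sketch (parameter order and the "base" model name are assumptions; only the keyword names are taken from this diff):

import asyncio

asyncio.run(
    run_server(
        port=50051,
        asr_model="base",
        asr_device="cpu",
        asr_compute_type="int8",
        database_url=None,          # in-memory mode
        diarization_enabled=False,  # no HuggingFace token required
    )
)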
File diff suppressed because it is too large
@@ -5,8 +5,10 @@ Provides Whisper-based transcription with word-level timestamps.

from __future__ import annotations

import asyncio
import logging
from collections.abc import Iterator
from functools import partial
from typing import TYPE_CHECKING, Final

if TYPE_CHECKING:
@@ -151,6 +153,29 @@ class FasterWhisperEngine:
                no_speech_prob=segment.no_speech_prob,
            )

    async def transcribe_async(
        self,
        audio: NDArray[np.float32],
        language: str | None = None,
    ) -> list[AsrResult]:
        """Transcribe audio asynchronously using executor.

        Offloads blocking transcription to a thread pool executor to avoid
        blocking the asyncio event loop.

        Args:
            audio: Audio samples as float32 array (16kHz mono, normalized).
            language: Optional language code (e.g., "en").

        Returns:
            List of AsrResult segments with word-level timestamps.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            partial(lambda a, lang: list(self.transcribe(a, lang)), audio, language),
        )

    @property
    def is_loaded(self) -> bool:
        """Return True if model is loaded."""
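Reviewer note: because transcribe_async only wraps the blocking transcribe call in the default executor, it can run alongside other coroutines. A usage sketch (engine construction is out of scope here and assumed done elsewhere):

import asyncio
import numpy as np

async def demo(engine: FasterWhisperEngine) -> None:
    audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
    segments = await engine.transcribe_async(audio, language="en")
    for seg in segments:
        print(seg.text)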
@@ -12,6 +12,8 @@ from typing import TYPE_CHECKING
import numpy as np
from numpy.typing import NDArray

from noteflow.config.constants import DEFAULT_SAMPLE_RATE

if TYPE_CHECKING:
    from collections.abc import Iterator

@@ -37,7 +39,7 @@ class SegmenterConfig:
    # Leading audio to include before speech starts (seconds)
    leading_buffer: float = 0.2
    # Sample rate for audio processing
-   sample_rate: int = 16000
+   sample_rate: int = DEFAULT_SAMPLE_RATE


@dataclass
@@ -8,6 +8,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.audio import compute_rms

if TYPE_CHECKING:
@@ -111,7 +112,7 @@ class StreamingVad:
    """

    engine: VadEngine = field(default_factory=EnergyVad)
-   sample_rate: int = 16000
+   sample_rate: int = DEFAULT_SAMPLE_RATE

    def process_chunk(self, audio: NDArray[np.float32]) -> bool:
        """Process audio chunk through VAD engine.
@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING
import numpy as np
import sounddevice as sd

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.audio.dto import AudioDeviceInfo, AudioFrameCallback

if TYPE_CHECKING:
@@ -32,7 +33,7 @@ class SoundDeviceCapture:
        self._stream: sd.InputStream | None = None
        self._callback: AudioFrameCallback | None = None
        self._device_id: int | None = None
-       self._sample_rate: int = 16000
+       self._sample_rate: int = DEFAULT_SAMPLE_RATE
        self._channels: int = 1

    def list_devices(self) -> list[AudioDeviceInfo]:
@@ -79,7 +80,7 @@ class SoundDeviceCapture:
        self,
        device_id: int | None,
        on_frames: AudioFrameCallback,
-       sample_rate: int = 16000,
+       sample_rate: int = DEFAULT_SAMPLE_RATE,
        channels: int = 1,
        chunk_duration_ms: int = 100,
    ) -> None:
@@ -7,6 +7,8 @@ from __future__ import annotations

import logging
import threading
from collections.abc import Callable
from enum import Enum, auto
from typing import TYPE_CHECKING

@@ -14,6 +15,8 @@ import numpy as np
import sounddevice as sd
from numpy.typing import NDArray

from noteflow.config.constants import DEFAULT_SAMPLE_RATE, POSITION_UPDATE_INTERVAL

if TYPE_CHECKING:
    from noteflow.infrastructure.audio.dto import TimestampedAudio

@@ -35,16 +38,29 @@ class SoundDevicePlayback:
    Thread-safe for UI callbacks.
    """

-   def __init__(self, sample_rate: int = 16000, channels: int = 1) -> None:
+   def __init__(
+       self,
+       sample_rate: int = DEFAULT_SAMPLE_RATE,
+       channels: int = 1,
+       on_position_update: Callable[[float], None] | None = None,
+   ) -> None:
        """Initialize the playback instance.

        Args:
            sample_rate: Sample rate in Hz (default 16kHz for ASR audio).
            channels: Number of channels (default 1 for mono).
            on_position_update: Optional callback for position updates during playback.
                Called at ~100ms intervals with current position in seconds.
                Runs in the audio thread, so keep the callback minimal.
        """
        self._sample_rate = sample_rate
        self._channels = channels

        # Position update callbacks (can have multiple subscribers)
        self._position_callbacks: list[Callable[[float], None]] = []
        if on_position_update is not None:
            self._position_callbacks.append(on_position_update)

        # Playback state
        self._state = PlaybackState.STOPPED
        self._lock = threading.Lock()
@@ -54,6 +70,10 @@ class SoundDevicePlayback:
        self._total_samples: int = 0
        self._current_sample: int = 0

        # Position callback tracking
        self._callback_interval_samples = int(sample_rate * POSITION_UPDATE_INTERVAL)
        self._last_callback_sample: int = 0

        # Stream
        self._stream: sd.OutputStream | None = None

@@ -76,6 +96,7 @@ class SoundDevicePlayback:
        self._audio_data = np.concatenate(frames).astype(np.float32)
        self._total_samples = len(self._audio_data)
        self._current_sample = 0
        self._last_callback_sample = 0

        # Create and start stream
        self._start_stream()
@@ -114,9 +135,15 @@ class SoundDevicePlayback:

        Safe to call even if not playing.
        """
+       position = 0.0
        with self._lock:
+           if self._audio_data is not None:
+               position = self._current_sample / self._sample_rate
            self._stop_internal()

+       # Notify callbacks so UI can react to stop even if no final tick fired.
+       self._notify_position_callbacks(position)
+
    def _stop_internal(self) -> None:
        """Internal stop without lock (caller must hold lock)."""
        if self._stream is not None:
@@ -132,6 +159,7 @@ class SoundDevicePlayback:
        self._current_sample = 0
        self._audio_data = None
        self._total_samples = 0
        self._last_callback_sample = 0
        logger.debug("Stopped playback")

    def _start_stream(self) -> None:
@@ -149,6 +177,9 @@ class SoundDevicePlayback:
            if status:
                logger.warning("Playback stream status: %s", status)

            fire_callback = False
            position = 0.0

            with self._lock:
                if self._audio_data is None or self._state != PlaybackState.PLAYING:
                    # Output silence
@@ -170,11 +201,22 @@ class SoundDevicePlayback:
                if to_copy < frames:
                    outdata[to_copy:] = 0

                # Check if we should fire position update callback
                elapsed = self._current_sample - self._last_callback_sample
                if elapsed >= self._callback_interval_samples:
                    fire_callback = True
                    position = self._current_sample / self._sample_rate
                    self._last_callback_sample = self._current_sample

                # Check if playback is complete
                if self._current_sample >= self._total_samples:
                    # Schedule stop on another thread to avoid deadlock
                    threading.Thread(target=self._on_playback_complete, daemon=True).start()

            # Fire callbacks outside lock to avoid potential deadlocks
            if fire_callback:
                self._notify_position_callbacks(position)

        try:
            self._stream = sd.OutputStream(
                channels=self._channels,
@@ -214,13 +256,19 @@ class SoundDevicePlayback:

            # Convert to sample position
            self._current_sample = int(clamped_position * self._sample_rate)
            # Reset callback sample so updates resume immediately after seek
            self._last_callback_sample = self._current_sample

            logger.debug(
                "Seeked to %.2f seconds (sample %d)",
                clamped_position,
                self._current_sample,
            )
-           return True
+           position_seconds = clamped_position
+
+       # Notify callbacks to update UI/highlights immediately after seek
+       self._notify_position_callbacks(position_seconds)
+       return True

    def is_playing(self) -> bool:
        """Check if currently playing audio.

@@ -258,3 +306,58 @@ class SoundDevicePlayback:
    def channels(self) -> int:
        """Number of channels."""
        return self._channels

    def add_position_callback(
        self,
        callback: Callable[[float], None],
    ) -> None:
        """Add a position update callback.

        Multiple callbacks can be registered. Each receives the current
        position in seconds during playback.

        Args:
            callback: Callback receiving current position in seconds.
        """
        if callback not in self._position_callbacks:
            self._position_callbacks.append(callback)

    def _notify_position_callbacks(self, position: float) -> None:
        """Notify all registered position callbacks.

        Runs without holding the playback lock to avoid deadlocks.
        """
        for callback in list(self._position_callbacks):
            try:
                callback(position)
            except Exception as e:
                logger.debug("Position update callback error: %s", e)

    def remove_position_callback(
        self,
        callback: Callable[[float], None],
    ) -> None:
        """Remove a position update callback.

        Args:
            callback: Previously registered callback to remove.
        """
        if callback in self._position_callbacks:
            self._position_callbacks.remove(callback)

    def set_position_callback(
        self,
        callback: Callable[[float], None] | None,
    ) -> None:
        """Set or clear the position update callback (replaces all callbacks).

        For backwards compatibility. Use add_position_callback/remove_position_callback
        for multiple subscribers.

        Args:
            callback: Callback receiving current position in seconds,
                or None to clear all callbacks.
        """
        self._position_callbacks.clear()
        if callback is not None:
            self._position_callbacks.append(callback)
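Reviewer note: position callbacks fire from the audio thread at roughly POSITION_UPDATE_INTERVAL, so subscribers should hand work off rather than block. A usage sketch of the new subscriber API:

playback = SoundDevicePlayback()

def on_tick(position: float) -> None:
    # Keep this cheap; it runs on the audio thread.
    print(f"position: {position:.2f}s")

playback.add_position_callback(on_tick)
# ... load and play audio ...
playback.remove_position_callback(on_tick)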
@@ -7,6 +7,8 @@ from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

from noteflow.config.constants import DEFAULT_SAMPLE_RATE

if TYPE_CHECKING:
    import numpy as np
    from numpy.typing import NDArray
@@ -37,7 +39,7 @@ class AudioCapture(Protocol):
        self,
        device_id: int | None,
        on_frames: AudioFrameCallback,
-       sample_rate: int = 16000,
+       sample_rate: int = DEFAULT_SAMPLE_RATE,
        channels: int = 1,
        chunk_duration_ms: int = 100,
    ) -> None:
@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING

import numpy as np

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.audio.dto import TimestampedAudio
from noteflow.infrastructure.security.crypto import ChunkedAssetReader

@@ -48,7 +49,7 @@ class MeetingAudioReader:
        self._crypto = crypto
        self._meetings_dir = meetings_dir
        self._meeting_dir: Path | None = None
-       self._sample_rate: int = 16000
+       self._sample_rate: int = DEFAULT_SAMPLE_RATE

    def load_meeting_audio(
        self,
@@ -77,7 +78,7 @@ class MeetingAudioReader:
            raise FileNotFoundError(f"Manifest not found: {manifest_path}")

        manifest = json.loads(manifest_path.read_text())
-       self._sample_rate = manifest.get("sample_rate", 16000)
+       self._sample_rate = manifest.get("sample_rate", DEFAULT_SAMPLE_RATE)
        wrapped_dek_hex = manifest.get("wrapped_dek")

        if not wrapped_dek_hex:
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING

import numpy as np

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.security.crypto import ChunkedAssetWriter

if TYPE_CHECKING:
@@ -47,7 +48,7 @@ class MeetingAudioWriter:
        self._meetings_dir = meetings_dir
        self._asset_writer: ChunkedAssetWriter | None = None
        self._meeting_dir: Path | None = None
-       self._sample_rate: int = 16000
+       self._sample_rate: int = DEFAULT_SAMPLE_RATE
        self._chunk_count: int = 0

    def open(
@@ -55,7 +56,7 @@ class MeetingAudioWriter:
        meeting_id: str,
        dek: bytes,
        wrapped_dek: bytes,
-       sample_rate: int = 16000,
+       sample_rate: int = DEFAULT_SAMPLE_RATE,
    ) -> None:
        """Open meeting for audio writing.
@@ -11,6 +11,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.dto import SpeakerTurn

if TYPE_CHECKING:
@@ -170,7 +171,7 @@ class DiarizationEngine:
    def process_chunk(
        self,
        audio: NDArray[np.float32],
-       sample_rate: int = 16000,
+       sample_rate: int = DEFAULT_SAMPLE_RATE,
    ) -> Sequence[SpeakerTurn]:
        """Process an audio chunk for streaming diarization.

@@ -212,7 +213,7 @@ class DiarizationEngine:
    def diarize_full(
        self,
        audio: NDArray[np.float32],
-       sample_rate: int = 16000,
+       sample_rate: int = DEFAULT_SAMPLE_RATE,
        num_speakers: int | None = None,
    ) -> Sequence[SpeakerTurn]:
        """Diarize a complete audio recording.
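Reviewer note: diarize_full now defaults to DEFAULT_SAMPLE_RATE, so callers only pass sample_rate for non-16 kHz audio. A sketch (constructor kwargs follow the run_server wiring in this commit; audio loading is out of scope):

import numpy as np

engine = DiarizationEngine(device="cpu", hf_token="hf_...")  # token is a placeholder
audio = np.zeros(DEFAULT_SAMPLE_RATE * 60, dtype=np.float32)  # 60 s placeholder
turns = engine.diarize_full(audio, num_speakers=2)  # sample_rate defaults to 16 kHz
for turn in turns:
    print(turn)  # SpeakerTurn DTO; field layout defined in diarization.dto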
@@ -28,7 +28,14 @@ def upgrade() -> None:
    op.execute("CREATE SCHEMA IF NOT EXISTS noteflow")

    # Enable pgvector extension
-   op.execute("CREATE EXTENSION IF NOT EXISTS vector")
+   try:
+       op.execute("CREATE EXTENSION IF NOT EXISTS vector")
+   except sa.exc.ProgrammingError as e:
+       raise RuntimeError(
+           f"Failed to create pgvector extension: {e}. "
+           "Ensure the database user has CREATE EXTENSION privileges, or "
+           "install pgvector manually: CREATE EXTENSION vector;"
+       ) from e

    # Create meetings table
    op.create_table(
@@ -6,7 +6,9 @@ Provides secure master key storage using OS credential stores.
from __future__ import annotations

import base64
import binascii
import logging
import os
import secrets
from typing import Final

@@ -18,6 +20,7 @@ logger = logging.getLogger(__name__)
KEY_SIZE: Final[int] = 32  # 256-bit key
SERVICE_NAME: Final[str] = "noteflow"
KEY_NAME: Final[str] = "master_key"
ENV_VAR_NAME: Final[str] = "NOTEFLOW_MASTER_KEY"


class KeyringKeyStore:
@@ -46,17 +49,35 @@ class KeyringKeyStore:
    def get_or_create_master_key(self) -> bytes:
        """Retrieve or generate the master encryption key.

        Checks for an environment variable first (for headless/container deployments),
        then falls back to the OS keyring.

        Returns:
            32-byte master key.

        Raises:
-           RuntimeError: If keychain is unavailable.
+           RuntimeError: If keychain is unavailable and no env var is set.
        """
        # Check environment variable first (for headless/container deployments)
        if env_key := os.environ.get(ENV_VAR_NAME):
            logger.debug("Using master key from environment variable")
            try:
                decoded = base64.b64decode(env_key, validate=True)
            except (binascii.Error, ValueError) as exc:
                raise RuntimeError(
                    f"{ENV_VAR_NAME} must be base64-encoded {KEY_SIZE}-byte key"
                ) from exc
            if len(decoded) != KEY_SIZE:
                raise RuntimeError(
                    f"{ENV_VAR_NAME} must decode to {KEY_SIZE} bytes, got {len(decoded)}"
                )
            return decoded

        try:
-           # Try to retrieve existing key
+           # Try to retrieve existing key from keyring
            stored = keyring.get_password(self._service_name, self._key_name)
            if stored is not None:
-               logger.debug("Retrieved existing master key")
+               logger.debug("Retrieved existing master key from keyring")
                return base64.b64decode(stored)

            # Generate new key
@@ -65,11 +86,14 @@ class KeyringKeyStore:

            # Store in keyring
            keyring.set_password(self._service_name, self._key_name, encoded)
-           logger.info("Generated and stored new master key")
+           logger.info("Generated and stored new master key in keyring")
            return new_key

        except keyring.errors.KeyringError as e:
-           raise RuntimeError(f"Keyring unavailable: {e}") from e
+           raise RuntimeError(
+               f"Keyring unavailable: {e}. "
+               f"Set {ENV_VAR_NAME} environment variable for headless mode."
+           ) from e

    def delete_master_key(self) -> None:
        """Delete the master key from the keychain.
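Reviewer note: the env var must decode to exactly KEY_SIZE (32) bytes. One way to mint a valid NOTEFLOW_MASTER_KEY for headless or container deployments (sketch using only the standard library):

import base64
import secrets

# Prints a base64 value suitable for NOTEFLOW_MASTER_KEY (32 raw bytes).
print(base64.b64encode(secrets.token_bytes(32)).decode("ascii"))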
@@ -77,7 +77,7 @@ class OllamaSummarizer:
            # Try to list models to verify connectivity
            client.list()
            return True
-       except Exception:
+       except (ConnectionError, TimeoutError, RuntimeError, OSError):
            return False

    @property
@@ -3,18 +3,24 @@
Provide signal providers for meeting detection triggers.
"""

from noteflow.infrastructure.triggers.app_audio import AppAudioProvider, AppAudioSettings
from noteflow.infrastructure.triggers.audio_activity import (
    AudioActivityProvider,
    AudioActivitySettings,
)
from noteflow.infrastructure.triggers.calendar import CalendarProvider, CalendarSettings
from noteflow.infrastructure.triggers.foreground_app import (
    ForegroundAppProvider,
    ForegroundAppSettings,
)

__all__ = [
    "AppAudioProvider",
    "AppAudioSettings",
    "AudioActivityProvider",
    "AudioActivitySettings",
    "CalendarProvider",
    "CalendarSettings",
    "ForegroundAppProvider",
    "ForegroundAppSettings",
]
280
src/noteflow/infrastructure/triggers/app_audio.py
Normal file
280
src/noteflow/infrastructure/triggers/app_audio.py
Normal file
@@ -0,0 +1,280 @@
"""App audio activity provider.

Detects audio activity from system output while whitelisted meeting apps are active.
This is a best-effort heuristic: it combines (a) system output activity and
(b) presence of whitelisted app windows to infer a likely meeting.
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.domain.triggers.entities import TriggerSignal, TriggerSource
from noteflow.infrastructure.audio.levels import RmsLevelProvider
from noteflow.infrastructure.triggers.audio_activity import (
    AudioActivityProvider,
    AudioActivitySettings,
)

if TYPE_CHECKING:
    import numpy as np
    from numpy.typing import NDArray

logger = logging.getLogger(__name__)


@dataclass
class AppAudioSettings:
    """Configuration for app audio detection.

    Attributes:
        enabled: Whether app audio detection is enabled.
        threshold_db: Minimum dB level to consider as activity.
        window_seconds: Time window for sustained activity detection.
        min_active_ratio: Minimum ratio of active samples in window.
        min_samples: Minimum samples required before evaluation.
        max_history: Maximum samples retained in history.
        weight: Confidence weight contributed by this provider.
        meeting_apps: Set of app name substrings to match (lowercase).
        suppressed_apps: App substrings to ignore even if matched.
        sample_rate: Sample rate for system output capture.
        sample_duration_seconds: Duration of each sampling read.
        chunk_duration_seconds: Duration of sub-chunks for activity history updates.
    """

    enabled: bool
    threshold_db: float
    window_seconds: float
    min_active_ratio: float
    min_samples: int
    max_history: int
    weight: float
    meeting_apps: set[str] = field(default_factory=set)
    suppressed_apps: set[str] = field(default_factory=set)
    sample_rate: int = DEFAULT_SAMPLE_RATE
    sample_duration_seconds: float = 0.5
    chunk_duration_seconds: float = 0.1

    def __post_init__(self) -> None:
        self.meeting_apps = {app.lower() for app in self.meeting_apps}
        self.suppressed_apps = {app.lower() for app in self.suppressed_apps}

class _SystemOutputSampler:
    """Best-effort system output sampler using sounddevice."""

    def __init__(self, sample_rate: int, channels: int = 1) -> None:
        self._sample_rate = sample_rate
        self._channels = channels
        self._stream = None
        self._extra_settings = None
        self._device = None
        self._available: bool | None = None

    def _select_device(self) -> None:
        try:
            import sounddevice as sd
        except ImportError:
            self._disable("sounddevice not available - app audio detection disabled")
            return

        # Default to output device and WASAPI loopback when available (Windows)
        try:
            default_output = sd.default.device[1]
        except (TypeError, IndexError):
            default_output = None

        try:
            hostapi_index = sd.default.hostapi
            hostapi = sd.query_hostapis(hostapi_index) if hostapi_index is not None else None
        except Exception:
            hostapi = None

        if hostapi and hostapi.get("type") == "Windows WASAPI" and default_output is not None:
            # On WASAPI, loopback devices appear as separate input devices.
            # Fall through to monitor/loopback device detection below.
            pass

        # Fallback: look for monitor/loopback devices (Linux/PulseAudio)
        try:
            devices = sd.query_devices()
        except Exception:
            self._disable("Failed to query audio devices for app audio detection")
            return
        for idx, dev in enumerate(devices):
            name = str(dev.get("name", "")).lower()
            if int(dev.get("max_input_channels", 0)) <= 0:
                continue
            if "monitor" in name or "loopback" in name:
                self._use_device(idx)
                return
        self._available = False
        logger.warning("No loopback audio device found - app audio detection disabled")

    def _use_device(self, index: int) -> None:
        """Record the selected loopback device and mark sampling available."""
        self._device = index
        self._available = True

    def _disable(self, reason: str) -> None:
        """Mark sampling unavailable and log why."""
        self._available = False
        logger.warning(reason)

    def _ensure_stream(self) -> bool:
        if self._available is False:
            return False

        if self._available is None:
            self._select_device()
            if self._available is False:
                return False

        if self._stream is not None:
            return True

        try:
            import sounddevice as sd

            self._stream = sd.InputStream(
                device=self._device,
                channels=self._channels,
                samplerate=self._sample_rate,
                dtype="float32",
                extra_settings=self._extra_settings,
            )
            self._stream.start()
            return True
        except Exception as exc:
            logger.warning("Failed to start system output capture: %s", exc)
            self._stream = None
            self._available = False
            return False

    def read_frames(self, duration_seconds: float) -> NDArray[np.float32] | None:
        if not self._ensure_stream():
            return None

        if self._stream is None:
            return None

        frames = max(1, int(self._sample_rate * duration_seconds))
        try:
            data, _ = self._stream.read(frames)
        except Exception as exc:
            logger.debug("System output read failed: %s", exc)
            return None

        return data.reshape(-1).astype("float32")

    def close(self) -> None:
        if self._stream is None:
            return
        try:
            self._stream.stop()
            self._stream.close()
        except Exception:
            logger.debug("Failed to close system output stream", exc_info=True)
        finally:
            self._stream = None
class AppAudioProvider:
    """Detect app audio activity from whitelisted meeting apps."""

    def __init__(self, settings: AppAudioSettings) -> None:
        self._settings = settings
        self._sampler = _SystemOutputSampler(sample_rate=settings.sample_rate)
        self._level_provider = RmsLevelProvider()
        self._audio_activity = AudioActivityProvider(
            self._level_provider,
            AudioActivitySettings(
                enabled=settings.enabled,
                threshold_db=settings.threshold_db,
                window_seconds=settings.window_seconds,
                min_active_ratio=settings.min_active_ratio,
                min_samples=settings.min_samples,
                max_history=settings.max_history,
                weight=settings.weight,
            ),
        )

    @property
    def source(self) -> TriggerSource:
        return TriggerSource.AUDIO_ACTIVITY

    @property
    def max_weight(self) -> float:
        return self._settings.weight

    def is_enabled(self) -> bool:
        return self._settings.enabled

    def get_signal(self) -> TriggerSignal | None:
        if not self.is_enabled():
            return None
        if not self._settings.meeting_apps:
            return None

        app_title = self._detect_meeting_app()
        if not app_title:
            return None

        frames = self._sampler.read_frames(self._settings.sample_duration_seconds)
        if frames is None or frames.size == 0:
            return None

        self._update_activity_history(frames)
        if self._audio_activity.get_signal() is None:
            return None

        return TriggerSignal(
            source=self.source,
            weight=self.max_weight,
            app_name=app_title,
        )

    def _update_activity_history(self, frames: NDArray[np.float32]) -> None:
        chunk_size = max(1, int(self._settings.sample_rate * self._settings.chunk_duration_seconds))
        now = time.monotonic()
        for offset in range(0, len(frames), chunk_size):
            chunk = frames[offset : offset + chunk_size]
            if chunk.size == 0:
                continue
            self._audio_activity.update(chunk, now)

    def _detect_meeting_app(self) -> str | None:
        try:
            import pywinctl
        except ImportError:
            return None

        titles: list[str] = []
        try:
            if hasattr(pywinctl, "getAllWindows"):
                windows = pywinctl.getAllWindows()
                titles = [w.title for w in windows if getattr(w, "title", None)]
            elif hasattr(pywinctl, "getAllTitles"):
                titles = [t for t in pywinctl.getAllTitles() if t]
        except Exception as exc:
            logger.debug("Failed to list windows for app detection: %s", exc)
            return None

        for title in titles:
            title_lower = title.lower()
            if any(suppressed in title_lower for suppressed in self._settings.suppressed_apps):
                continue
            if any(app in title_lower for app in self._settings.meeting_apps):
                return title

        return None

    def close(self) -> None:
        """Release system audio resources."""
        self._sampler.close()
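A polling sketch for the provider above; all values are illustrative, not recommended defaults, and a real detector would call get_signal() on a timer:

# Sketch: polling AppAudioProvider from a detection loop.
settings = AppAudioSettings(
    enabled=True,
    threshold_db=-45.0,
    window_seconds=5.0,
    min_active_ratio=0.4,
    min_samples=10,
    max_history=100,
    weight=0.5,
    meeting_apps={"Zoom", "Teams"},  # lowercased by __post_init__
)
provider = AppAudioProvider(settings)
if (signal := provider.get_signal()) is not None:
    print(f"Likely meeting in: {signal.app_name}")
provider.close()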
150
src/noteflow/infrastructure/triggers/calendar.py
Normal file
@@ -0,0 +1,150 @@
"""Calendar trigger provider.

Best-effort calendar integration using configured event windows.
"""

from __future__ import annotations

import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

from noteflow.domain.triggers.entities import TriggerSignal, TriggerSource

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class CalendarEvent:
    """Simple calendar event window."""

    start: datetime
    end: datetime
    title: str | None = None


@dataclass
class CalendarSettings:
    """Configuration for calendar trigger detection."""

    enabled: bool
    weight: float
    lookahead_minutes: int
    lookbehind_minutes: int
    events: list[CalendarEvent]


class CalendarProvider:
    """Provide trigger signal based on calendar proximity."""

    def __init__(self, settings: CalendarSettings) -> None:
        self._settings = settings

    @property
    def source(self) -> TriggerSource:
        return TriggerSource.CALENDAR

    @property
    def max_weight(self) -> float:
        return self._settings.weight

    def is_enabled(self) -> bool:
        return self._settings.enabled

    def get_signal(self) -> TriggerSignal | None:
        if not self.is_enabled():
            return None

        if not self._settings.events:
            return None

        now = datetime.now(timezone.utc)
        window_start = now - timedelta(minutes=self._settings.lookbehind_minutes)
        window_end = now + timedelta(minutes=self._settings.lookahead_minutes)

        return next(
            (
                TriggerSignal(
                    source=self.source,
                    weight=self.max_weight,
                    app_name=event.title,
                )
                for event in self._settings.events
                if self._event_overlaps_window(event, window_start, window_end)
            ),
            None,
        )

    @staticmethod
    def _event_overlaps_window(
        event: CalendarEvent,
        window_start: datetime,
        window_end: datetime,
    ) -> bool:
        event_start = _ensure_tz(event.start)
        event_end = _ensure_tz(event.end)
        return event_start <= window_end and event_end >= window_start


def parse_calendar_events(raw_events: object) -> list[CalendarEvent]:
    """Parse calendar events from config/env payloads."""
    if raw_events is None:
        return []

    if isinstance(raw_events, str):
        raw_events = _load_events_from_json(raw_events)

    if isinstance(raw_events, dict):
        raw_events = [raw_events]

    if not isinstance(raw_events, Iterable):
        return []

    events: list[CalendarEvent] = []
    for item in raw_events:
        if isinstance(item, CalendarEvent):
            events.append(item)
            continue
        if isinstance(item, dict):
            start = _parse_datetime(item.get("start"))
            end = _parse_datetime(item.get("end"))
            if start and end:
                events.append(CalendarEvent(start=start, end=end, title=item.get("title")))
    return events


def _load_events_from_json(raw: str) -> list[dict[str, object]]:
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        logger.debug("Failed to parse calendar events JSON")
        return []
    if isinstance(parsed, list):
        return [item for item in parsed if isinstance(item, dict)]
    return [parsed] if isinstance(parsed, dict) else []


def _parse_datetime(value: object) -> datetime | None:
    if isinstance(value, datetime):
        return value
    if not isinstance(value, str) or not value:
        return None
    cleaned = value.strip()
    if cleaned.endswith("Z"):
        cleaned = f"{cleaned[:-1]}+00:00"
    try:
        return datetime.fromisoformat(cleaned)
    except ValueError:
        return None


def _ensure_tz(value: datetime) -> datetime:
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)
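How the parsing path above behaves end to end; this follows directly from _load_events_from_json and _parse_datetime (Z suffix rewritten to +00:00, malformed JSON degrading to an empty list):

# Sketch: a JSON payload (e.g., from config/env) round-trips through the parser.
raw = '[{"start": "2024-01-01T10:00:00Z", "end": "2024-01-01T11:00:00Z", "title": "Standup"}]'
events = parse_calendar_events(raw)
assert events[0].title == "Standup"
assert parse_calendar_events("not json") == []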
@@ -26,7 +26,10 @@ class TestRetentionServiceProperties:
    def test_is_enabled_reflects_init(self) -> None:
        """is_enabled should reflect constructor parameter."""
        uow = MagicMock()
-       factory = lambda: uow
+
+       def factory() -> MagicMock:
+           return uow

        enabled_service = RetentionService(factory, retention_days=30, enabled=True)
        disabled_service = RetentionService(factory, retention_days=30, enabled=False)
@@ -72,11 +72,25 @@ def mock_optional_extras() -> None:
    )
    sys.modules["ollama"] = ollama_module

+   # pywinctl depends on pymonctl, which may fail in headless environments
+   # Mock both if not already present
+   if "pymonctl" not in sys.modules:
+       try:
+           import pymonctl as _pymonctl  # noqa: F401
+       except Exception:
+           # Mock pymonctl for headless environments (Xlib.error.DisplayNameError, etc.)
+           pymonctl_module = types.ModuleType("pymonctl")
+           pymonctl_module.getAllMonitors = lambda: []
+           sys.modules["pymonctl"] = pymonctl_module
+
    if "pywinctl" not in sys.modules:
        try:
            import pywinctl as _pywinctl  # noqa: F401
        except Exception:
-           # In headless environments pywinctl import may fail (e.g., missing DISPLAY)
+           # ImportError: package not installed
+           # OSError/Xlib errors: pywinctl may fail in headless environments
            pywinctl_module = types.ModuleType("pywinctl")
            pywinctl_module.getActiveWindow = lambda: None
            pywinctl_module.getAllWindows = lambda: []
            pywinctl_module.getAllTitles = lambda: []
            sys.modules["pywinctl"] = pywinctl_module
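The stubbing pattern above in isolation; registering a module object in sys.modules before the first import makes later imports resolve to the stub:

# Sketch: pre-registering a stand-in for an optional dependency.
import sys
import types

stub = types.ModuleType("pywinctl")
stub.getActiveWindow = lambda: None
sys.modules.setdefault("pywinctl", stub)

import pywinctl

assert pywinctl.getActiveWindow() is None  # holds when the stub was installed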
35
tests/grpc/test_diarization_refine.py
Normal file
@@ -0,0 +1,35 @@
"""Tests for RefineSpeakerDiarization RPC guards."""

from __future__ import annotations

import pytest

from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer


class _DummyContext:
    """Minimal gRPC context that raises if abort is invoked."""

    async def abort(self, code, details):  # type: ignore[override]
        raise AssertionError(f"abort called: {code} - {details}")


@pytest.mark.asyncio
async def test_refine_speaker_diarization_rejects_active_meeting() -> None:
    """Refinement should be blocked while a meeting is still recording."""
    servicer = NoteFlowServicer(diarization_engine=object())
    store = servicer._get_memory_store()

    meeting = store.create("Active meeting")
    meeting.start_recording()
    store.update(meeting)

    response = await servicer.RefineSpeakerDiarization(
        noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
        _DummyContext(),
    )

    assert response.segments_updated == 0
    assert response.error_message
    assert "stopped" in response.error_message.lower()
@@ -37,7 +37,14 @@ def _create_mock_asr_engine(transcribe_results: list[str] | None = None) -> MagicMock:
    def _transcribe(_audio: NDArray[np.float32]) -> list[MockAsrResult]:
        return [MockAsrResult(text=text) for text in results]

+   async def _transcribe_async(
+       _audio: NDArray[np.float32],
+       _language: str | None = None,
+   ) -> list[MockAsrResult]:
+       return [MockAsrResult(text=text) for text in results]
+
    engine.transcribe = _transcribe
+   engine.transcribe_async = _transcribe_async
    return engine
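With the async variant added, the helper named in the hunk above satisfies both call styles; a minimal usage sketch, assuming the helper is importable from the test module:

# Sketch: the mock engine serves sync and async transcription paths.
import asyncio

import numpy as np

engine = _create_mock_asr_engine(["hello"])
audio = np.zeros(1600, dtype=np.float32)
assert engine.transcribe(audio)[0].text == "hello"
assert asyncio.run(engine.transcribe_async(audio))[0].text == "hello"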
196
tests/infrastructure/test_diarization.py
Normal file
@@ -0,0 +1,196 @@
"""Tests for speaker diarization infrastructure.

Tests the SpeakerTurn DTO and speaker assignment utilities.
"""

from __future__ import annotations

import pytest

from noteflow.infrastructure.diarization import SpeakerTurn, assign_speaker, assign_speakers_batch


class TestSpeakerTurn:
    """Tests for the SpeakerTurn dataclass."""

    def test_create_valid_turn(self) -> None:
        """Create a valid speaker turn."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0)
        assert turn.speaker == "SPEAKER_00"
        assert turn.start == 0.0
        assert turn.end == 5.0
        assert turn.confidence == 1.0

    def test_create_turn_with_confidence(self) -> None:
        """Create a turn with custom confidence."""
        turn = SpeakerTurn(speaker="SPEAKER_01", start=10.0, end=15.0, confidence=0.85)
        assert turn.confidence == 0.85

    def test_invalid_end_before_start_raises(self) -> None:
        """End time before start time raises ValueError."""
        with pytest.raises(ValueError, match=r"end.*<.*start"):
            SpeakerTurn(speaker="SPEAKER_00", start=10.0, end=5.0)

    def test_invalid_confidence_negative_raises(self) -> None:
        """Negative confidence raises ValueError."""
        with pytest.raises(ValueError, match=r"Confidence must be 0\.0-1\.0"):
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0, confidence=-0.1)

    def test_invalid_confidence_above_one_raises(self) -> None:
        """Confidence above 1.0 raises ValueError."""
        with pytest.raises(ValueError, match=r"Confidence must be 0\.0-1\.0"):
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0, confidence=1.5)

    def test_duration_property(self) -> None:
        """Duration property calculates correctly."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=2.5, end=7.5)
        assert turn.duration == 5.0

    def test_overlaps_returns_true_for_overlap(self) -> None:
        """overlaps() returns True when ranges overlap."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)
        assert turn.overlaps(3.0, 7.0)
        assert turn.overlaps(7.0, 12.0)
        assert turn.overlaps(5.0, 10.0)
        assert turn.overlaps(0.0, 15.0)

    def test_overlaps_returns_false_for_no_overlap(self) -> None:
        """overlaps() returns False when ranges don't overlap."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)
        assert not turn.overlaps(0.0, 5.0)
        assert not turn.overlaps(10.0, 15.0)
        assert not turn.overlaps(0.0, 3.0)
        assert not turn.overlaps(12.0, 20.0)

    def test_overlap_duration_full_overlap(self) -> None:
        """overlap_duration() for full overlap returns turn duration."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)
        assert turn.overlap_duration(0.0, 15.0) == 5.0

    def test_overlap_duration_partial_overlap_left(self) -> None:
        """overlap_duration() for partial overlap on left side."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)
        assert turn.overlap_duration(3.0, 7.0) == 2.0

    def test_overlap_duration_partial_overlap_right(self) -> None:
        """overlap_duration() for partial overlap on right side."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)
        assert turn.overlap_duration(8.0, 15.0) == 2.0

    def test_overlap_duration_contained(self) -> None:
        """overlap_duration() when range is contained within turn."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=20.0)
        assert turn.overlap_duration(5.0, 10.0) == 5.0

    def test_overlap_duration_no_overlap(self) -> None:
        """overlap_duration() returns 0.0 when no overlap."""
        turn = SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)
        assert turn.overlap_duration(0.0, 3.0) == 0.0
        assert turn.overlap_duration(12.0, 20.0) == 0.0


class TestAssignSpeaker:
    """Tests for the assign_speaker function."""

    def test_empty_turns_returns_none(self) -> None:
        """Empty turns list returns None with 0 confidence."""
        speaker, confidence = assign_speaker(0.0, 5.0, [])
        assert speaker is None
        assert confidence == 0.0

    def test_zero_duration_segment_returns_none(self) -> None:
        """Zero duration segment returns None."""
        turns = [SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=10.0)]
        speaker, confidence = assign_speaker(5.0, 5.0, turns)
        assert speaker is None
        assert confidence == 0.0

    def test_single_turn_full_overlap(self) -> None:
        """Single turn with full overlap returns high confidence."""
        turns = [SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=10.0)]
        speaker, confidence = assign_speaker(2.0, 8.0, turns)
        assert speaker == "SPEAKER_00"
        assert confidence == 1.0

    def test_single_turn_partial_overlap(self) -> None:
        """Single turn with partial overlap returns proportional confidence."""
        turns = [SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)]
        speaker, confidence = assign_speaker(0.0, 10.0, turns)
        assert speaker == "SPEAKER_00"
        assert confidence == 0.5

    def test_multiple_turns_chooses_dominant_speaker(self) -> None:
        """Multiple turns chooses speaker with most overlap."""
        turns = [
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=3.0),
            SpeakerTurn(speaker="SPEAKER_01", start=3.0, end=10.0),
        ]
        speaker, confidence = assign_speaker(0.0, 10.0, turns)
        assert speaker == "SPEAKER_01"
        assert confidence == 0.7

    def test_no_overlap_returns_none(self) -> None:
        """No overlapping turns returns None."""
        turns = [
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0),
            SpeakerTurn(speaker="SPEAKER_01", start=10.0, end=15.0),
        ]
        speaker, confidence = assign_speaker(6.0, 9.0, turns)
        assert speaker is None
        assert confidence == 0.0

    def test_equal_overlap_chooses_first_encountered(self) -> None:
        """Equal overlap chooses first speaker encountered."""
        turns = [
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0),
            SpeakerTurn(speaker="SPEAKER_01", start=5.0, end=10.0),
        ]
        speaker, confidence = assign_speaker(3.0, 7.0, turns)
        # SPEAKER_00: overlap 2.0, SPEAKER_01: overlap 2.0
        # First one wins since > not >=
        assert speaker == "SPEAKER_00"
        assert confidence == 0.5


class TestAssignSpeakersBatch:
    """Tests for the assign_speakers_batch function."""

    def test_empty_segments(self) -> None:
        """Empty segments list returns empty results."""
        turns = [SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=10.0)]
        results = assign_speakers_batch([], turns)
        assert results == []

    def test_empty_turns(self) -> None:
        """Empty turns returns all None speakers."""
        segments = [(0.0, 5.0), (5.0, 10.0)]
        results = assign_speakers_batch(segments, [])
        assert len(results) == 2
        assert all(speaker is None for speaker, _ in results)
        assert all(conf == 0.0 for _, conf in results)

    def test_batch_assignment(self) -> None:
        """Batch assignment processes all segments."""
        turns = [
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0),
            SpeakerTurn(speaker="SPEAKER_01", start=5.0, end=10.0),
            SpeakerTurn(speaker="SPEAKER_00", start=10.0, end=15.0),
        ]
        segments = [(0.0, 5.0), (5.0, 10.0), (10.0, 15.0)]
        results = assign_speakers_batch(segments, turns)
        assert len(results) == 3
        assert results[0] == ("SPEAKER_00", 1.0)
        assert results[1] == ("SPEAKER_01", 1.0)
        assert results[2] == ("SPEAKER_00", 1.0)

    def test_batch_with_gaps(self) -> None:
        """Batch assignment handles gaps between turns."""
        turns = [
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=3.0),
            SpeakerTurn(speaker="SPEAKER_01", start=7.0, end=10.0),
        ]
        segments = [(0.0, 3.0), (3.0, 7.0), (7.0, 10.0)]
        results = assign_speakers_batch(segments, turns)
        assert results[0] == ("SPEAKER_00", 1.0)
        assert results[1] == (None, 0.0)
        assert results[2] == ("SPEAKER_01", 1.0)
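The rule these tests encode, stated once: the winning speaker is the one with the largest summed overlap, and confidence is that overlap divided by segment duration. A minimal arithmetic sketch (the helper below is illustrative, not part of the library):

# Sketch: confidence = dominant overlap / segment duration.
def expected_confidence(overlap_seconds: float, seg_start: float, seg_end: float) -> float:
    return overlap_seconds / (seg_end - seg_start)

assert expected_confidence(7.0, 0.0, 10.0) == 0.7  # matches the dominant-speaker case above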
@@ -17,6 +17,7 @@ if TYPE_CHECKING:

    from noteflow.infrastructure.persistence.models import Base


+# Store container reference at module level to reuse
class PgTestContainer:
    """Minimal Postgres testcontainer wrapper with custom readiness wait."""
5
tests/stress/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Stress and fuzz tests for NoteFlow.

These tests detect race conditions, infrastructure defects, and logic bugs
through concurrent execution, file corruption simulation, and state machine fuzzing.
"""
245
tests/stress/conftest.py
Normal file
@@ -0,0 +1,245 @@
"""Pytest fixtures for stress and fuzz tests."""

from __future__ import annotations

import time
from dataclasses import dataclass
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
from urllib.parse import quote

import numpy as np
import pytest

from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from noteflow.infrastructure.security.keystore import InMemoryKeyStore

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator
    from typing import Self

    from numpy.typing import NDArray
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


@dataclass
class MockAsrResult:
    """Mock ASR transcription result."""

    text: str
    start: float = 0.0
    end: float = 1.0
    language: str = "en"
    language_probability: float = 0.99
    avg_logprob: float = -0.5
    no_speech_prob: float = 0.01


# Store container reference at module level to reuse in stress tests
class PgTestContainer:
    """Minimal Postgres testcontainer wrapper with custom readiness wait."""

    def __init__(
        self,
        image: str = "pgvector/pgvector:pg16",
        username: str = "test",
        password: str = "test",
        dbname: str = "noteflow_test",
        port: int = 5432,
    ) -> None:
        self.username = username
        self.password = password
        self.dbname = dbname
        self.port = port

        container_module = import_module("testcontainers.core.container")
        docker_container_cls = container_module.DockerContainer
        self._container = (
            docker_container_cls(image)
            .with_env("POSTGRES_USER", username)
            .with_env("POSTGRES_PASSWORD", password)
            .with_env("POSTGRES_DB", dbname)
            .with_exposed_ports(port)
        )

    def start(self) -> Self:
        """Start the container."""
        self._container.start()
        self._wait_until_ready()
        return self

    def stop(self) -> None:
        """Stop the container."""
        self._container.stop()

    def get_connection_url(self) -> str:
        """Return a SQLAlchemy-style connection URL."""
        host = self._container.get_container_host_ip()
        port = self._container._get_exposed_port(self.port)
        quoted_password = quote(self.password, safe=" +")
        return f"postgresql+psycopg2://{self.username}:{quoted_password}@{host}:{port}/{self.dbname}"

    def _wait_until_ready(self, timeout: float = 30.0, interval: float = 0.5) -> None:
        """Wait for Postgres to accept connections by running a simple query."""
        start_time = time.time()
        escaped_password = self.password.replace("'", "'\"'\"'")
        cmd = [
            "sh",
            "-c",
            (
                f"PGPASSWORD='{escaped_password}' "
                f"psql --username {self.username} --dbname {self.dbname} --host 127.0.0.1 "
                "-c 'select 1;'"
            ),
        ]
        last_error: str | None = None

        while True:
            result = self._container.exec(cmd)
            if result.exit_code == 0:
                return
            if result.output:
                last_error = result.output.decode(errors="ignore")
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    "Postgres container did not become ready in time"
                    + (f": {last_error}" if last_error else "")
                )
            time.sleep(interval)


_container: PgTestContainer | None = None
_database_url: str | None = None


def get_or_create_container() -> tuple[PgTestContainer, str]:
    """Get or create the PostgreSQL container for stress tests."""
    global _container, _database_url

    if _container is None:
        container = PgTestContainer().start()
        _container = container
        url = container.get_connection_url()
        _database_url = url.replace("postgresql+psycopg2://", "postgresql+asyncpg://")

    assert _container is not None, "Container should be initialized"
    assert _database_url is not None, "Database URL should be initialized"
    return _container, _database_url


def create_mock_asr_engine(transcribe_results: list[str] | None = None) -> MagicMock:
    """Create mock ASR engine with configurable transcription results.

    Args:
        transcribe_results: List of transcription texts to return.

    Returns:
        Mock ASR engine with sync and async transcribe methods.
    """
    engine = MagicMock()
    engine.is_loaded = True
    engine.model_size = "base"

    results = transcribe_results or ["Test transcription"]

    def _transcribe(_audio: NDArray[np.float32]) -> list[MockAsrResult]:
        return [MockAsrResult(text=text) for text in results]

    async def _transcribe_async(
        _audio: NDArray[np.float32],
        _language: str | None = None,
    ) -> list[MockAsrResult]:
        return [MockAsrResult(text=text) for text in results]

    engine.transcribe = _transcribe
    engine.transcribe_async = _transcribe_async
    return engine


@pytest.fixture
def in_memory_keystore() -> InMemoryKeyStore:
    """Create an in-memory keystore for testing."""
    return InMemoryKeyStore()


@pytest.fixture
def crypto(in_memory_keystore: InMemoryKeyStore) -> AesGcmCryptoBox:
    """Create crypto box with in-memory keystore for testing."""
    return AesGcmCryptoBox(in_memory_keystore)


@pytest.fixture
def meetings_dir(tmp_path: Path) -> Path:
    """Create temporary meetings directory."""
    meetings = tmp_path / "meetings"
    meetings.mkdir(parents=True)
    return meetings


@pytest.fixture
def mock_asr_engine() -> MagicMock:
    """Create default mock ASR engine."""
    return create_mock_asr_engine()


@pytest.fixture
def memory_servicer(mock_asr_engine: MagicMock, tmp_path: Path) -> NoteFlowServicer:
    """Create NoteFlowServicer with in-memory MeetingStore backend.

    Uses memory store (no database) for fast unit testing of
    concurrency and state management.
    """
    return NoteFlowServicer(
        asr_engine=mock_asr_engine,
        session_factory=None,
        meetings_dir=tmp_path / "meetings",
    )


# PostgreSQL-backed session factory for stress tests; heavy dependencies are
# imported lazily so testcontainers is only required when this fixture is used.
@pytest.fixture
async def postgres_session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
    """Create PostgreSQL session factory using testcontainers.

    Uses a local container helper to avoid importing test modules.
    """
    # Import here to avoid requiring testcontainers for all stress tests
    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

    from noteflow.infrastructure.persistence.models import Base

    _, database_url = get_or_create_container()

    engine = create_async_engine(database_url, echo=False)

    async with engine.begin() as conn:
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
        await conn.execute(text("CREATE SCHEMA noteflow"))
        await conn.run_sync(Base.metadata.create_all)

    yield async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
    )

    async with engine.begin() as conn:
        await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))

    await engine.dispose()


def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
    """Cleanup container after stress tests complete."""
    global _container
    if _container is not None:
        _container.stop()
        _container = None
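A usage sketch for the PostgreSQL-backed fixture above; the test body is a placeholder, and real stress tests would exercise repositories through the session:

# Sketch: consuming postgres_session_factory from a stress test.
import pytest

@pytest.mark.stress
@pytest.mark.asyncio
async def test_persists_meeting(postgres_session_factory) -> None:
    async with postgres_session_factory() as session:
        assert session is not None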
493
tests/stress/test_audio_integrity.py
Normal file
@@ -0,0 +1,493 @@
"""Stress tests for encrypted audio file format (NFAE) resilience.

Tests truncation recovery, missing manifest handling, and corruption detection.
"""

from __future__ import annotations

import json
import struct
from pathlib import Path
from uuid import uuid4

import numpy as np
import pytest
from numpy.typing import NDArray

from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
from noteflow.infrastructure.security.crypto import (
    FILE_MAGIC,
    FILE_VERSION,
    AesGcmCryptoBox,
    ChunkedAssetReader,
    ChunkedAssetWriter,
)
from noteflow.infrastructure.security.keystore import InMemoryKeyStore


@pytest.fixture
def crypto() -> AesGcmCryptoBox:
    """Create crypto with in-memory keystore."""
    return AesGcmCryptoBox(InMemoryKeyStore())


@pytest.fixture
def meetings_dir(tmp_path: Path) -> Path:
    """Create temporary meetings directory."""
    return tmp_path / "meetings"


def make_audio(samples: int = 1600) -> NDArray[np.float32]:
    """Create test audio with random values."""
    return np.random.uniform(-0.5, 0.5, samples).astype(np.float32)


class TestTruncatedWriteRecovery:
    """Test behavior when audio file is truncated (power loss simulation)."""

    @pytest.mark.stress
    def test_truncated_header_partial_magic(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """Truncated file (only partial magic bytes) raises on read."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        audio_path = meeting_dir / "audio.enc"
        audio_path.write_bytes(FILE_MAGIC[:2])

        reader = ChunkedAssetReader(crypto)
        dek = crypto.generate_dek()

        with pytest.raises(ValueError, match="Invalid file format"):
            reader.open(audio_path, dek)

    @pytest.mark.stress
    def test_truncated_header_missing_version(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """File with magic but truncated before version byte."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        audio_path = meeting_dir / "audio.enc"
        audio_path.write_bytes(FILE_MAGIC)

        reader = ChunkedAssetReader(crypto)
        dek = crypto.generate_dek()

        with pytest.raises((struct.error, ValueError)):
            reader.open(audio_path, dek)

    @pytest.mark.stress
    def test_truncated_chunk_length_partial(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """File with complete header but truncated chunk length."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        audio_path = meeting_dir / "audio.enc"
        with audio_path.open("wb") as f:
            f.write(FILE_MAGIC)
            f.write(struct.pack("B", FILE_VERSION))
            f.write(struct.pack(">I", 1000)[:2])

        dek = crypto.generate_dek()
        reader = ChunkedAssetReader(crypto)
        reader.open(audio_path, dek)

        chunks = list(reader.read_chunks())
        assert not chunks
        reader.close()

    @pytest.mark.stress
    def test_truncated_chunk_data_raises(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """File with chunk length but truncated data raises ValueError."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        audio_path = meeting_dir / "audio.enc"
        with audio_path.open("wb") as f:
            f.write(FILE_MAGIC)
            f.write(struct.pack("B", FILE_VERSION))
            f.write(struct.pack(">I", 100))
            f.write(b"short")

        dek = crypto.generate_dek()
        reader = ChunkedAssetReader(crypto)
        reader.open(audio_path, dek)

        with pytest.raises(ValueError, match="Truncated chunk"):
            list(reader.read_chunks())
        reader.close()

    @pytest.mark.stress
    def test_valid_chunks_before_truncation_preserved(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """Valid chunks before truncation can still be read."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        audio_path = meeting_dir / "audio.enc"
        dek = crypto.generate_dek()

        writer = ChunkedAssetWriter(crypto)
        writer.open(audio_path, dek)
        test_data = b"valid audio chunk data 1"
        writer.write_chunk(test_data)
        writer.close()

        with audio_path.open("ab") as f:
            f.write(struct.pack(">I", 500))
            f.write(b"truncated")

        reader = ChunkedAssetReader(crypto)
        reader.open(audio_path, dek)

        chunks = []
        try:
            chunks.extend(reader.read_chunks())
        except ValueError:
            pass
        finally:
            reader.close()

        assert len(chunks) == 1
        assert chunks[0] == test_data


class TestMissingManifest:
    """Test behavior when manifest.json is missing."""

    @pytest.mark.stress
    def test_audio_exists_false_without_manifest(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """audio_exists returns False when only audio.enc exists."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        (meeting_dir / "audio.enc").write_bytes(FILE_MAGIC + bytes([FILE_VERSION]))

        reader = MeetingAudioReader(crypto, meetings_dir)
        assert reader.audio_exists(meeting_id) is False

    @pytest.mark.stress
    def test_audio_exists_false_without_audio(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """audio_exists returns False when only manifest exists."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)
        manifest = {
            "meeting_id": meeting_id,
            "sample_rate": 16000,
            "wrapped_dek": wrapped_dek.hex(),
        }
        (meeting_dir / "manifest.json").write_text(json.dumps(manifest))

        reader = MeetingAudioReader(crypto, meetings_dir)
        assert reader.audio_exists(meeting_id) is False

    @pytest.mark.stress
    def test_audio_exists_true_when_both_exist(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """audio_exists returns True when both manifest and audio exist."""
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer = MeetingAudioWriter(crypto, meetings_dir)
        writer.open(meeting_id, dek, wrapped_dek)
        writer.write_chunk(make_audio())
        writer.close()

        reader = MeetingAudioReader(crypto, meetings_dir)
        assert reader.audio_exists(meeting_id) is True

    @pytest.mark.stress
    def test_load_audio_raises_without_manifest(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """load_meeting_audio raises FileNotFoundError without manifest."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        (meeting_dir / "audio.enc").write_bytes(FILE_MAGIC + bytes([FILE_VERSION]))

        reader = MeetingAudioReader(crypto, meetings_dir)
        with pytest.raises(FileNotFoundError, match="Manifest not found"):
            reader.load_meeting_audio(meeting_id)


class TestCorruptedCiphertextDetection:
    """Test corrupted ciphertext/tag detection."""

    @pytest.mark.stress
    def test_bit_flip_in_ciphertext_detected(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """Single bit flip in ciphertext causes decryption failure."""
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer = MeetingAudioWriter(crypto, meetings_dir)
        writer.open(meeting_id, dek, wrapped_dek)
        writer.write_chunk(make_audio(1600))
        writer.close()

        audio_path = meetings_dir / meeting_id / "audio.enc"
        data = bytearray(audio_path.read_bytes())

        header_size = 5
        length_size = 4
        nonce_size = 12
        corrupt_offset = header_size + length_size + nonce_size + 5

        if len(data) > corrupt_offset:
            data[corrupt_offset] ^= 0x01
            audio_path.write_bytes(bytes(data))

        reader = ChunkedAssetReader(crypto)
        reader.open(audio_path, dek)

        with pytest.raises(ValueError, match="Chunk decryption failed"):
            list(reader.read_chunks())
        reader.close()

    @pytest.mark.stress
    def test_bit_flip_in_tag_detected(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Bit flip in authentication tag causes decryption failure."""
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer = MeetingAudioWriter(crypto, meetings_dir)
        writer.open(meeting_id, dek, wrapped_dek)
        writer.write_chunk(make_audio(1600))
        writer.close()

        audio_path = meetings_dir / meeting_id / "audio.enc"
        data = bytearray(audio_path.read_bytes())

        data[-5] ^= 0x01
        audio_path.write_bytes(bytes(data))

        reader = ChunkedAssetReader(crypto)
        reader.open(audio_path, dek)

        with pytest.raises(ValueError, match="Chunk decryption failed"):
            list(reader.read_chunks())
        reader.close()

    @pytest.mark.stress
    def test_wrong_dek_detected(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Using wrong DEK fails decryption."""
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrong_dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer = MeetingAudioWriter(crypto, meetings_dir)
        writer.open(meeting_id, dek, wrapped_dek)
        writer.write_chunk(make_audio(1600))
        writer.close()

        audio_path = meetings_dir / meeting_id / "audio.enc"
        reader = ChunkedAssetReader(crypto)
        reader.open(audio_path, wrong_dek)

        with pytest.raises(ValueError, match="Chunk decryption failed"):
            list(reader.read_chunks())
        reader.close()


class TestInvalidManifest:
    """Test handling of invalid manifest.json content."""

    @pytest.mark.stress
    def test_missing_wrapped_dek_raises(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Manifest without wrapped_dek raises ValueError."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        manifest = {"meeting_id": meeting_id, "sample_rate": 16000}
        (meeting_dir / "manifest.json").write_text(json.dumps(manifest))
        (meeting_dir / "audio.enc").write_bytes(FILE_MAGIC + bytes([FILE_VERSION]))

        reader = MeetingAudioReader(crypto, meetings_dir)
        with pytest.raises(ValueError, match="missing wrapped_dek"):
            reader.load_meeting_audio(meeting_id)

    @pytest.mark.stress
    def test_invalid_wrapped_dek_hex_raises(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """Invalid hex string in wrapped_dek raises ValueError."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        manifest = {
            "meeting_id": meeting_id,
            "sample_rate": 16000,
            "wrapped_dek": "not_valid_hex_!!!",
        }
        (meeting_dir / "manifest.json").write_text(json.dumps(manifest))
        (meeting_dir / "audio.enc").write_bytes(FILE_MAGIC + bytes([FILE_VERSION]))

        reader = MeetingAudioReader(crypto, meetings_dir)
        with pytest.raises(ValueError):
            reader.load_meeting_audio(meeting_id)

    @pytest.mark.stress
    def test_corrupted_wrapped_dek_raises(
        self, crypto: AesGcmCryptoBox, meetings_dir: Path
    ) -> None:
        """Corrupted wrapped_dek (valid hex but invalid content) raises."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)
        corrupted = bytearray(wrapped_dek)
        corrupted[10] ^= 0xFF

        manifest = {
            "meeting_id": meeting_id,
            "sample_rate": 16000,
            "wrapped_dek": bytes(corrupted).hex(),
        }
        (meeting_dir / "manifest.json").write_text(json.dumps(manifest))
        (meeting_dir / "audio.enc").write_bytes(FILE_MAGIC + bytes([FILE_VERSION]))

        reader = MeetingAudioReader(crypto, meetings_dir)
        with pytest.raises(ValueError, match="unwrap failed"):
            reader.load_meeting_audio(meeting_id)


class TestWriterReaderRoundTrip:
    """Test write-read round trip integrity."""

    @pytest.mark.stress
    def test_single_chunk_roundtrip(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Single chunk write and read preserves data."""
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        original_audio = make_audio(1600)

        writer = MeetingAudioWriter(crypto, meetings_dir)
        writer.open(meeting_id, dek, wrapped_dek)
        writer.write_chunk(original_audio)
        writer.close()

        reader = MeetingAudioReader(crypto, meetings_dir)
        chunks = reader.load_meeting_audio(meeting_id)

        assert len(chunks) == 1
        np.testing.assert_array_almost_equal(chunks[0].frames, original_audio, decimal=4)

    @pytest.mark.stress
    def test_multiple_chunks_roundtrip(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Multiple chunk write and read preserves data."""
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        original_chunks = [make_audio(1600) for _ in range(10)]

        writer = MeetingAudioWriter(crypto, meetings_dir)
        writer.open(meeting_id, dek, wrapped_dek)
        for chunk in original_chunks:
            writer.write_chunk(chunk)
        writer.close()

        reader = MeetingAudioReader(crypto, meetings_dir)
        loaded_chunks = reader.load_meeting_audio(meeting_id)

        assert len(loaded_chunks) == len(original_chunks)
        for original, loaded in zip(original_chunks, loaded_chunks, strict=True):
            np.testing.assert_array_almost_equal(loaded.frames, original, decimal=4)

    @pytest.mark.stress
    @pytest.mark.slow
    def test_large_audio_roundtrip(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Large audio file (1000 chunks) write and read succeeds."""
        meeting_id = str(uuid4())
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer = MeetingAudioWriter(crypto, meetings_dir)
        writer.open(meeting_id, dek, wrapped_dek)

        np.random.seed(42)
        chunk_count = 1000
        for _ in range(chunk_count):
            writer.write_chunk(make_audio(1600))
        writer.close()

        reader = MeetingAudioReader(crypto, meetings_dir)
        chunks = reader.load_meeting_audio(meeting_id)

        assert len(chunks) == chunk_count
        total_duration = sum(c.duration for c in chunks)
        expected_duration = chunk_count * (1600 / 16000)
        assert abs(total_duration - expected_duration) < 0.01


class TestFileVersionHandling:
    """Test file version validation."""

    @pytest.mark.stress
    def test_unsupported_version_raises(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Unsupported file version raises ValueError."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        audio_path = meeting_dir / "audio.enc"
        with audio_path.open("wb") as f:
            f.write(FILE_MAGIC)
            f.write(struct.pack("B", 99))

        dek = crypto.generate_dek()
        reader = ChunkedAssetReader(crypto)

        with pytest.raises(ValueError, match="Unsupported file version"):
            reader.open(audio_path, dek)

    @pytest.mark.stress
    def test_wrong_magic_raises(self, crypto: AesGcmCryptoBox, meetings_dir: Path) -> None:
        """Wrong magic bytes raises ValueError."""
        meeting_id = str(uuid4())
        meeting_dir = meetings_dir / meeting_id
        meeting_dir.mkdir(parents=True)

        audio_path = meeting_dir / "audio.enc"
        audio_path.write_bytes(b"XXXX" + bytes([FILE_VERSION]))

        dek = crypto.generate_dek()
        reader = ChunkedAssetReader(crypto)

        with pytest.raises(ValueError, match="Invalid file format"):
            reader.open(audio_path, dek)
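A sketch of the on-disk NFAE layout these tests imply; the sizes are inferred from the corruption offsets used above, not from a published spec:

# Inferred layout:
#   [4B magic][1B version] then repeated:
#   [4B big-endian chunk length][12B nonce][ciphertext + 16B GCM tag]
header_size = 4 + 1       # FILE_MAGIC + FILE_VERSION
chunk_prefix = 4 + 12     # length field + AES-GCM nonce
assert header_size + chunk_prefix == 21  # base of corrupt_offset in the tests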
333
tests/stress/test_concurrency_stress.py
Normal file
@@ -0,0 +1,333 @@
"""Stress tests for NoteFlowServicer concurrent stream handling.

Detects race conditions when multiple clients stream simultaneously.
Verifies _cleanup_streaming_state prevents memory leaks.
"""

from __future__ import annotations

import asyncio
from uuid import uuid4

import pytest

from noteflow.grpc.service import NoteFlowServicer


class TestStreamingStateInitialization:
    """Test streaming state initialization correctness."""

    @pytest.mark.stress
    def test_init_streaming_state_creates_all_state(
        self, memory_servicer: NoteFlowServicer
    ) -> None:
        """Initialize streaming state creates entries in all state dictionaries."""
        meeting_id = str(uuid4())

        memory_servicer._init_streaming_state(meeting_id, next_segment_id=0)

        assert meeting_id in memory_servicer._partial_buffers
        assert meeting_id in memory_servicer._vad_instances
        assert meeting_id in memory_servicer._segmenters
        assert meeting_id in memory_servicer._was_speaking
        assert meeting_id in memory_servicer._segment_counters
        assert meeting_id in memory_servicer._last_partial_time
        assert meeting_id in memory_servicer._last_partial_text
        assert meeting_id in memory_servicer._diarization_turns
        assert meeting_id in memory_servicer._diarization_stream_time

        memory_servicer._cleanup_streaming_state(meeting_id)

    @pytest.mark.stress
    def test_init_with_different_segment_ids(self, memory_servicer: NoteFlowServicer) -> None:
        """Initialize with different segment IDs sets counter correctly."""
        meeting_id1 = str(uuid4())
        meeting_id2 = str(uuid4())

        memory_servicer._init_streaming_state(meeting_id1, next_segment_id=0)
        memory_servicer._init_streaming_state(meeting_id2, next_segment_id=42)

        assert memory_servicer._segment_counters[meeting_id1] == 0
        assert memory_servicer._segment_counters[meeting_id2] == 42

        memory_servicer._cleanup_streaming_state(meeting_id1)
        memory_servicer._cleanup_streaming_state(meeting_id2)


class TestCleanupStreamingState:
    """Test _cleanup_streaming_state removes all per-meeting resources."""

    @pytest.mark.stress
    def test_cleanup_removes_all_state_dictionaries(
        self, memory_servicer: NoteFlowServicer
    ) -> None:
        """Verify cleanup removes entries from all state dictionaries."""
        meeting_id = str(uuid4())

        memory_servicer._init_streaming_state(meeting_id, next_segment_id=0)
        memory_servicer._active_streams.add(meeting_id)

        assert meeting_id in memory_servicer._partial_buffers
        assert meeting_id in memory_servicer._vad_instances
        assert meeting_id in memory_servicer._segmenters

        memory_servicer._cleanup_streaming_state(meeting_id)
        memory_servicer._active_streams.discard(meeting_id)

        assert meeting_id not in memory_servicer._partial_buffers
        assert meeting_id not in memory_servicer._vad_instances
        assert meeting_id not in memory_servicer._segmenters
        assert meeting_id not in memory_servicer._was_speaking
        assert meeting_id not in memory_servicer._segment_counters
        assert meeting_id not in memory_servicer._stream_formats
        assert meeting_id not in memory_servicer._last_partial_time
        assert meeting_id not in memory_servicer._last_partial_text
        assert meeting_id not in memory_servicer._diarization_turns
        assert meeting_id not in memory_servicer._diarization_stream_time
        assert meeting_id not in memory_servicer._diarization_streaming_failed
        assert meeting_id not in memory_servicer._active_streams

    @pytest.mark.stress
    def test_cleanup_idempotent(self, memory_servicer: NoteFlowServicer) -> None:
        """Cleanup is idempotent - multiple calls don't raise."""
        meeting_id = str(uuid4())

        memory_servicer._init_streaming_state(meeting_id, next_segment_id=0)

        memory_servicer._cleanup_streaming_state(meeting_id)
        memory_servicer._cleanup_streaming_state(meeting_id)
        memory_servicer._cleanup_streaming_state(meeting_id)

    @pytest.mark.stress
    def test_cleanup_nonexistent_meeting_no_error(self, memory_servicer: NoteFlowServicer) -> None:
        """Cleanup of non-existent meeting ID doesn't raise."""
        nonexistent_id = str(uuid4())

        memory_servicer._cleanup_streaming_state(nonexistent_id)


class TestConcurrentStreamInitialization:
    """Test concurrent stream initialization."""

    @pytest.mark.stress
    @pytest.mark.asyncio
    async def test_concurrent_init_different_meetings(
        self, memory_servicer: NoteFlowServicer
    ) -> None:
        """Multiple concurrent init calls for different meetings succeed."""
        meeting_ids = [str(uuid4()) for _ in range(20)]

        async def init_meeting(meeting_id: str, segment_id: int) -> None:
            await asyncio.sleep(0.001)
            memory_servicer._init_streaming_state(meeting_id, segment_id)

        tasks = [asyncio.create_task(init_meeting(mid, idx)) for idx, mid in enumerate(meeting_ids)]
        await asyncio.gather(*tasks)

        assert len(memory_servicer._vad_instances) == len(meeting_ids)
        assert len(memory_servicer._segmenters) == len(meeting_ids)
        assert len(memory_servicer._partial_buffers) == len(meeting_ids)

        for mid in meeting_ids:
            memory_servicer._cleanup_streaming_state(mid)

    @pytest.mark.stress
    @pytest.mark.asyncio
    async def test_concurrent_cleanup_different_meetings(
        self, memory_servicer: NoteFlowServicer
    ) -> None:
        """Multiple concurrent cleanup calls for different meetings succeed."""
        meeting_ids = [str(uuid4()) for _ in range(20)]

        for idx, mid in enumerate(meeting_ids):
            memory_servicer._init_streaming_state(mid, idx)

        async def cleanup_meeting(meeting_id: str) -> None:
            await asyncio.sleep(0.001)
            memory_servicer._cleanup_streaming_state(meeting_id)

        tasks = [asyncio.create_task(cleanup_meeting(mid)) for mid in meeting_ids]
        await asyncio.gather(*tasks)

        assert len(memory_servicer._vad_instances) == 0
        assert len(memory_servicer._segmenters) == 0
        assert len(memory_servicer._partial_buffers) == 0


class TestNoMemoryLeaksUnderLoad:
    """Test no memory leaks after many stream cycles."""

    @pytest.mark.stress
    def test_stream_cycles_cleanup_completely(self, memory_servicer: NoteFlowServicer) -> None:
        """Many init/cleanup cycles leave no leaked state."""
        for _ in range(100):
            meeting_id = str(uuid4())
            memory_servicer._init_streaming_state(meeting_id, next_segment_id=0)
            memory_servicer._active_streams.add(meeting_id)
memory_servicer._active_streams.add(meeting_id)
|
||||
memory_servicer._cleanup_streaming_state(meeting_id)
|
||||
memory_servicer._active_streams.discard(meeting_id)
|
||||
|
||||
assert len(memory_servicer._active_streams) == 0
|
||||
assert len(memory_servicer._partial_buffers) == 0
|
||||
assert len(memory_servicer._vad_instances) == 0
|
||||
assert len(memory_servicer._segmenters) == 0
|
||||
assert len(memory_servicer._was_speaking) == 0
|
||||
assert len(memory_servicer._segment_counters) == 0
|
||||
assert len(memory_servicer._last_partial_time) == 0
|
||||
assert len(memory_servicer._last_partial_text) == 0
|
||||
assert len(memory_servicer._diarization_turns) == 0
|
||||
assert len(memory_servicer._diarization_stream_time) == 0
|
||||
assert len(memory_servicer._diarization_streaming_failed) == 0
|
||||
|
||||
@pytest.mark.stress
|
||||
@pytest.mark.slow
|
||||
def test_many_concurrent_meetings_no_leak(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""Many meetings initialized then cleaned up leave no state."""
|
||||
meeting_ids = [str(uuid4()) for _ in range(500)]
|
||||
|
||||
for idx, mid in enumerate(meeting_ids):
|
||||
memory_servicer._init_streaming_state(mid, idx)
|
||||
memory_servicer._active_streams.add(mid)
|
||||
|
||||
assert len(memory_servicer._vad_instances) == 500
|
||||
assert len(memory_servicer._segmenters) == 500
|
||||
|
||||
for mid in meeting_ids:
|
||||
memory_servicer._cleanup_streaming_state(mid)
|
||||
memory_servicer._active_streams.discard(mid)
|
||||
|
||||
assert len(memory_servicer._active_streams) == 0
|
||||
assert len(memory_servicer._vad_instances) == 0
|
||||
assert len(memory_servicer._segmenters) == 0
|
||||
|
||||
@pytest.mark.stress
|
||||
@pytest.mark.asyncio
|
||||
async def test_interleaved_init_cleanup(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""Interleaved init and cleanup doesn't leak or corrupt."""
|
||||
for _ in range(50):
|
||||
meeting_ids = [str(uuid4()) for _ in range(10)]
|
||||
|
||||
for idx, mid in enumerate(meeting_ids):
|
||||
memory_servicer._init_streaming_state(mid, idx)
|
||||
|
||||
for mid in meeting_ids[:5]:
|
||||
memory_servicer._cleanup_streaming_state(mid)
|
||||
|
||||
for mid in meeting_ids[5:]:
|
||||
memory_servicer._cleanup_streaming_state(mid)
|
||||
|
||||
assert len(memory_servicer._vad_instances) == 0
|
||||
assert len(memory_servicer._segmenters) == 0
|
||||
|
||||
|
||||
class TestActiveStreamsTracking:
|
||||
"""Test _active_streams set behavior."""
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_active_streams_tracks_active_meetings(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""_active_streams correctly tracks active meeting IDs."""
|
||||
meeting_ids = [str(uuid4()) for _ in range(5)]
|
||||
|
||||
for mid in meeting_ids:
|
||||
memory_servicer._active_streams.add(mid)
|
||||
|
||||
assert len(memory_servicer._active_streams) == 5
|
||||
for mid in meeting_ids:
|
||||
assert mid in memory_servicer._active_streams
|
||||
|
||||
for mid in meeting_ids:
|
||||
memory_servicer._active_streams.discard(mid)
|
||||
|
||||
assert len(memory_servicer._active_streams) == 0
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_discard_nonexistent_no_error(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""Discarding non-existent meeting ID doesn't raise."""
|
||||
nonexistent = str(uuid4())
|
||||
memory_servicer._active_streams.discard(nonexistent)
|
||||
|
||||
|
||||
class TestDiarizationStateCleanup:
|
||||
"""Test diarization-related state cleanup."""
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_diarization_failed_set_cleaned(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""_diarization_streaming_failed set is cleaned on cleanup."""
|
||||
meeting_id = str(uuid4())
|
||||
|
||||
memory_servicer._init_streaming_state(meeting_id, 0)
|
||||
memory_servicer._diarization_streaming_failed.add(meeting_id)
|
||||
|
||||
assert meeting_id in memory_servicer._diarization_streaming_failed
|
||||
|
||||
memory_servicer._cleanup_streaming_state(meeting_id)
|
||||
|
||||
assert meeting_id not in memory_servicer._diarization_streaming_failed
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_diarization_turns_cleaned(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""_diarization_turns dict is cleaned on cleanup."""
|
||||
meeting_id = str(uuid4())
|
||||
|
||||
memory_servicer._init_streaming_state(meeting_id, 0)
|
||||
|
||||
assert meeting_id in memory_servicer._diarization_turns
|
||||
assert memory_servicer._diarization_turns[meeting_id] == []
|
||||
|
||||
memory_servicer._cleanup_streaming_state(meeting_id)
|
||||
|
||||
assert meeting_id not in memory_servicer._diarization_turns
|
||||
|
||||
|
||||
class TestServicerInstantiation:
|
||||
"""Test NoteFlowServicer instantiation patterns."""
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_servicer_starts_with_empty_state(self) -> None:
|
||||
"""New servicer has empty state dictionaries."""
|
||||
servicer = NoteFlowServicer()
|
||||
|
||||
assert len(servicer._active_streams) == 0
|
||||
assert len(servicer._partial_buffers) == 0
|
||||
assert len(servicer._vad_instances) == 0
|
||||
assert len(servicer._segmenters) == 0
|
||||
assert len(servicer._was_speaking) == 0
|
||||
assert len(servicer._segment_counters) == 0
|
||||
assert len(servicer._audio_writers) == 0
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_multiple_servicers_independent(self) -> None:
|
||||
"""Multiple servicer instances have independent state."""
|
||||
servicer1 = NoteFlowServicer()
|
||||
servicer2 = NoteFlowServicer()
|
||||
|
||||
meeting_id = str(uuid4())
|
||||
servicer1._init_streaming_state(meeting_id, 0)
|
||||
|
||||
assert meeting_id in servicer1._vad_instances
|
||||
assert meeting_id not in servicer2._vad_instances
|
||||
|
||||
servicer1._cleanup_streaming_state(meeting_id)
|
||||
|
||||
|
||||
class TestMemoryStoreAccess:
|
||||
"""Test memory store access patterns."""
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_get_memory_store_returns_store(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""_get_memory_store returns MeetingStore when configured."""
|
||||
store = memory_servicer._get_memory_store()
|
||||
assert store is not None
|
||||
|
||||
@pytest.mark.stress
|
||||
def test_memory_store_create_meeting(self, memory_servicer: NoteFlowServicer) -> None:
|
||||
"""Memory store can create and retrieve meetings."""
|
||||
store = memory_servicer._get_memory_store()
|
||||
|
||||
meeting = store.create(title="Test Meeting")
|
||||
assert meeting is not None
|
||||
assert meeting.title == "Test Meeting"
|
||||
|
||||
retrieved = store.get(str(meeting.id))
|
||||
assert retrieved is not None
|
||||
assert retrieved.title == "Test Meeting"
|
||||
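Note for reviewers: the memory_servicer fixture these tests consume comes from the suite's conftest, which is outside this hunk. A minimal sketch of what such a fixture could look like, assuming (as TestServicerInstantiation suggests) that a no-argument NoteFlowServicer is backed by an in-memory MeetingStore; the committed fixture may configure more than this:

# conftest.py (hypothetical sketch; the committed fixture lives elsewhere in this commit)
import pytest

from noteflow.grpc.service import NoteFlowServicer


@pytest.fixture
def memory_servicer() -> NoteFlowServicer:
    # Assumption: the bare constructor wires up the in-memory MeetingStore
    # that _get_memory_store() returns in TestMemoryStoreAccess.
    return NoteFlowServicer()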
538
tests/stress/test_segmenter_fuzz.py
Normal file
@@ -0,0 +1,538 @@
"""Fuzz tests for Segmenter state machine.

Tests edge cases with rapid VAD transitions and random input sequences.
Verifies invariants hold under stress conditions.
"""

from __future__ import annotations

import random

import numpy as np
import pytest
from numpy.typing import NDArray

from noteflow.infrastructure.asr.segmenter import (
    AudioSegment,
    Segmenter,
    SegmenterConfig,
    SegmenterState,
)


def make_audio(duration: float, sample_rate: int = 16000) -> NDArray[np.float32]:
    """Create test audio of specified duration with random values."""
    samples = int(duration * sample_rate)
    return np.random.uniform(-1.0, 1.0, samples).astype(np.float32)


def make_silence(duration: float, sample_rate: int = 16000) -> NDArray[np.float32]:
    """Create silent audio of specified duration."""
    samples = int(duration * sample_rate)
    return np.zeros(samples, dtype=np.float32)


class TestSegmenterInvariants:
    """Verify segmenter invariants hold under various inputs."""

    @pytest.mark.stress
    @pytest.mark.parametrize("sample_rate", [16000, 44100, 48000])
    def test_segment_duration_positive(self, sample_rate: int) -> None:
        """All emitted segments have positive duration."""
        config = SegmenterConfig(
            sample_rate=sample_rate,
            min_speech_duration=0.0,
            trailing_silence=0.1,
        )
        segmenter = Segmenter(config=config)

        random.seed(42)
        segments: list[AudioSegment] = []

        for _ in range(100):
            audio = make_audio(0.1, sample_rate)
            is_speech = random.random() > 0.5
            segments.extend(segmenter.process_audio(audio, is_speech))

        if final := segmenter.flush():
            segments.append(final)

        for seg in segments:
            assert seg.duration > 0, f"Segment duration must be positive: {seg.duration}"
            assert seg.end_time > seg.start_time

    @pytest.mark.stress
    def test_segment_audio_length_matches_duration(self) -> None:
        """Segment audio length matches (end_time - start_time) * sample_rate."""
        sample_rate = 16000
        config = SegmenterConfig(
            sample_rate=sample_rate,
            min_speech_duration=0.0,
            trailing_silence=0.1,
            leading_buffer=0.0,
        )
        segmenter = Segmenter(config=config)

        speech = make_audio(0.5, sample_rate)
        silence = make_silence(0.2, sample_rate)

        list(segmenter.process_audio(speech, is_speech=True))
        segments = list(segmenter.process_audio(silence, is_speech=False))

        for seg in segments:
            expected_samples = int(seg.duration * sample_rate)
            actual_samples = len(seg.audio)
            assert abs(actual_samples - expected_samples) <= 1, (
                f"Audio length {actual_samples} != expected {expected_samples}"
            )

    @pytest.mark.stress
    def test_segments_strictly_sequential(self) -> None:
        """Emitted segments have non-overlapping, sequential time ranges."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.05,
        )
        segmenter = Segmenter(config=config)

        random.seed(123)
        segments: list[AudioSegment] = []

        for _ in range(50):
            audio = make_audio(0.05)
            is_speech = random.random() > 0.3
            segments.extend(segmenter.process_audio(audio, is_speech))

        if final := segmenter.flush():
            segments.append(final)

        for i in range(1, len(segments)):
            prev_end = segments[i - 1].end_time
            curr_start = segments[i].start_time
            assert curr_start >= prev_end, (
                f"Segment overlap: prev_end={prev_end}, curr_start={curr_start}"
            )

    @pytest.mark.stress
    def test_all_segments_have_audio(self) -> None:
        """All emitted segments contain audio data."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.1,
        )
        segmenter = Segmenter(config=config)

        random.seed(456)
        segments: list[AudioSegment] = []

        for _ in range(100):
            audio = make_audio(0.05)
            is_speech = random.random() > 0.4
            segments.extend(segmenter.process_audio(audio, is_speech))

        if final := segmenter.flush():
            segments.append(final)

        for seg in segments:
            assert len(seg.audio) > 0, "Segment must contain audio data"


class TestRapidVadTransitions:
    """Test rapid VAD state transitions (chattering)."""

    @pytest.mark.stress
    def test_rapid_speech_silence_alternation(self) -> None:
        """Rapid alternation between speech and silence doesn't crash."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.05,
        )
        segmenter = Segmenter(config=config)

        for i in range(1000):
            audio = make_silence(0.01)  # 10ms at 16kHz
            is_speech = i % 2 == 0
            list(segmenter.process_audio(audio, is_speech))

        assert segmenter.state in (
            SegmenterState.IDLE,
            SegmenterState.SPEECH,
            SegmenterState.TRAILING,
        )

    @pytest.mark.stress
    def test_single_sample_chunks(self) -> None:
        """Processing single-sample chunks doesn't crash or produce invalid segments."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.01,
        )
        segmenter = Segmenter(config=config)

        random.seed(789)
        segments: list[AudioSegment] = []

        for i in range(1000):
            audio = np.array([random.uniform(-1, 1)], dtype=np.float32)
            is_speech = i % 10 < 5
            segments.extend(segmenter.process_audio(audio, is_speech))

        for seg in segments:
            assert seg.duration >= 0
            assert len(seg.audio) > 0

    @pytest.mark.stress
    def test_very_short_speech_bursts(self) -> None:
        """Very short speech bursts are handled correctly."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.02,
        )
        segmenter = Segmenter(config=config)

        segments: list[AudioSegment] = []

        for _ in range(100):
            speech = make_audio(0.01)
            silence = make_silence(0.05)

            segments.extend(segmenter.process_audio(speech, is_speech=True))
            segments.extend(segmenter.process_audio(silence, is_speech=False))

        if final := segmenter.flush():
            segments.append(final)

        for seg in segments:
            assert seg.duration > 0
            assert seg.end_time > seg.start_time


class TestEdgeCaseConfigurations:
    """Test edge case segmenter configurations."""

    @pytest.mark.stress
    @pytest.mark.parametrize("min_speech", [0.0, 0.001, 0.01, 0.1, 1.0])
    def test_various_min_speech_durations(self, min_speech: float) -> None:
        """Various min_speech_duration values work correctly."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=min_speech,
            trailing_silence=0.1,
        )
        segmenter = Segmenter(config=config)

        speech = make_audio(1.0)
        silence = make_silence(0.2)

        segments_speech = list(segmenter.process_audio(speech, is_speech=True))
        segments_silence = list(segmenter.process_audio(silence, is_speech=False))

        all_segments = segments_speech + segments_silence

        for seg in all_segments:
            assert seg.duration > 0, f"Segment duration must be positive: {seg.duration}"

    @pytest.mark.stress
    def test_zero_trailing_silence(self) -> None:
        """Zero trailing_silence immediately emits on silence."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.0,
        )
        segmenter = Segmenter(config=config)

        speech = make_audio(0.1)
        silence = make_silence(0.01)

        list(segmenter.process_audio(speech, is_speech=True))
        segments = list(segmenter.process_audio(silence, is_speech=False))

        assert len(segments) == 1

    @pytest.mark.stress
    def test_max_duration_forced_split(self) -> None:
        """Segments are force-split at max_segment_duration."""
        config = SegmenterConfig(
            sample_rate=16000,
            max_segment_duration=0.5,
            min_speech_duration=0.0,
        )
        segmenter = Segmenter(config=config)

        segments: list[AudioSegment] = []

        for _ in range(20):
            audio = make_audio(0.1)
            segments.extend(segmenter.process_audio(audio, is_speech=True))

        assert len(segments) >= 3, f"Expected at least 3 splits, got {len(segments)}"

        for seg in segments:
            assert seg.duration <= config.max_segment_duration + 0.2

    @pytest.mark.stress
    def test_zero_leading_buffer(self) -> None:
        """Zero leading_buffer doesn't include pre-speech audio."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.1,
            leading_buffer=0.0,
        )
        segmenter = Segmenter(config=config)

        silence = make_silence(0.5)
        speech = make_audio(0.3)
        more_silence = make_silence(0.2)

        list(segmenter.process_audio(silence, is_speech=False))
        list(segmenter.process_audio(speech, is_speech=True))
        segments = list(segmenter.process_audio(more_silence, is_speech=False))

        assert len(segments) == 1
        seg = segments[0]
        expected_duration = 0.3 + 0.2
        assert abs(seg.duration - expected_duration) < 0.05

    @pytest.mark.stress
    @pytest.mark.parametrize("leading_buffer", [0.0, 0.1, 0.2, 0.5, 1.0])
    def test_various_leading_buffers(self, leading_buffer: float) -> None:
        """Various leading_buffer values work correctly."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.1,
            leading_buffer=leading_buffer,
        )
        segmenter = Segmenter(config=config)

        silence = make_silence(0.5)
        speech = make_audio(0.3)
        more_silence = make_silence(0.2)

        list(segmenter.process_audio(silence, is_speech=False))
        list(segmenter.process_audio(speech, is_speech=True))
        if segments := list(
            segmenter.process_audio(more_silence, is_speech=False)
        ):
            seg = segments[0]
            assert seg.duration > 0


class TestStateTransitions:
    """Test specific state transition scenarios."""

    @pytest.mark.stress
    def test_idle_to_speech_to_idle(self) -> None:
        """IDLE -> SPEECH -> IDLE transition."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.1,
        )
        segmenter = Segmenter(config=config)

        assert segmenter.state == SegmenterState.IDLE

        speech = make_audio(0.2)
        list(segmenter.process_audio(speech, is_speech=True))
        assert segmenter.state == SegmenterState.SPEECH

        silence = make_silence(0.2)
        list(segmenter.process_audio(silence, is_speech=False))
        assert segmenter.state == SegmenterState.IDLE

    @pytest.mark.stress
    def test_trailing_back_to_speech(self) -> None:
        """TRAILING -> SPEECH transition when speech resumes."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.5,
        )
        segmenter = Segmenter(config=config)

        speech = make_audio(0.2)
        list(segmenter.process_audio(speech, is_speech=True))

        short_silence = make_silence(0.1)
        list(segmenter.process_audio(short_silence, is_speech=False))
        assert segmenter.state == SegmenterState.TRAILING

        more_speech = make_audio(0.2)
        list(segmenter.process_audio(more_speech, is_speech=True))
        assert segmenter.state == SegmenterState.SPEECH

    @pytest.mark.stress
    def test_flush_from_speech_state(self) -> None:
        """Flush from SPEECH state emits segment."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
        )
        segmenter = Segmenter(config=config)

        speech = make_audio(0.3)
        list(segmenter.process_audio(speech, is_speech=True))
        assert segmenter.state == SegmenterState.SPEECH

        segment = segmenter.flush()
        assert segment is not None
        assert segment.duration > 0
        assert segmenter.state == SegmenterState.IDLE

    @pytest.mark.stress
    def test_flush_from_trailing_state(self) -> None:
        """Flush from TRAILING state emits segment."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=1.0,
        )
        segmenter = Segmenter(config=config)

        speech = make_audio(0.3)
        list(segmenter.process_audio(speech, is_speech=True))

        silence = make_silence(0.1)
        list(segmenter.process_audio(silence, is_speech=False))
        assert segmenter.state == SegmenterState.TRAILING

        segment = segmenter.flush()
        assert segment is not None
        assert segment.duration > 0
        assert segmenter.state == SegmenterState.IDLE

    @pytest.mark.stress
    def test_flush_from_idle_returns_none(self) -> None:
        """Flush from IDLE state returns None."""
        config = SegmenterConfig(sample_rate=16000)
        segmenter = Segmenter(config=config)

        assert segmenter.state == SegmenterState.IDLE
        segment = segmenter.flush()
        assert segment is None


class TestFuzzRandomPatterns:
    """Fuzz testing with random VAD patterns."""

    @pytest.mark.stress
    @pytest.mark.slow
    def test_random_vad_patterns_1000_iterations(self) -> None:
        """Run 1000 random VAD pattern iterations."""
        for seed in range(1000):
            random.seed(seed)
            np.random.seed(seed)

            config = SegmenterConfig(
                sample_rate=16000,
                min_speech_duration=random.uniform(0, 0.5),
                max_segment_duration=random.uniform(1, 10),
                trailing_silence=random.uniform(0.05, 0.5),
                leading_buffer=random.uniform(0, 0.3),
            )
            segmenter = Segmenter(config=config)

            segments: list[AudioSegment] = []

            for _ in range(random.randint(10, 100)):
                duration = random.uniform(0.01, 0.5)
                audio = make_audio(duration)
                is_speech = random.random() > 0.4
                segments.extend(segmenter.process_audio(audio, is_speech))

            if final := segmenter.flush():
                segments.append(final)

            for seg in segments:
                assert seg.duration > 0, f"Seed {seed}: duration must be positive"
                assert seg.end_time > seg.start_time, f"Seed {seed}: end > start"
                assert len(seg.audio) > 0, f"Seed {seed}: audio must exist"

    @pytest.mark.stress
    def test_deterministic_with_same_seed(self) -> None:
        """Same random seed produces same segments."""

        def run_with_seed(seed: int) -> list[tuple[float, float]]:
            random.seed(seed)
            np.random.seed(seed)

            config = SegmenterConfig(
                sample_rate=16000,
                min_speech_duration=0.0,
                trailing_silence=0.1,
            )
            segmenter = Segmenter(config=config)
            segments: list[AudioSegment] = []

            for _ in range(50):
                duration = random.uniform(0.05, 0.2)
                audio = make_audio(duration)
                is_speech = random.random() > 0.5
                segments.extend(segmenter.process_audio(audio, is_speech))

            if final := segmenter.flush():
                segments.append(final)

            return [(s.start_time, s.end_time) for s in segments]

        result1 = run_with_seed(999)
        result2 = run_with_seed(999)

        assert result1 == result2


class TestResetBehavior:
    """Test reset functionality."""

    @pytest.mark.stress
    def test_reset_clears_all_state(self) -> None:
        """Reset clears all internal state."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.5,
        )
        segmenter = Segmenter(config=config)

        speech = make_audio(0.5)
        list(segmenter.process_audio(speech, is_speech=True))

        silence = make_silence(0.1)
        list(segmenter.process_audio(silence, is_speech=False))

        segmenter.reset()

        assert segmenter.state == SegmenterState.IDLE

    @pytest.mark.stress
    def test_reset_allows_fresh_processing(self) -> None:
        """After reset, segmenter works from fresh state."""
        config = SegmenterConfig(
            sample_rate=16000,
            min_speech_duration=0.0,
            trailing_silence=0.1,
        )
        segmenter = Segmenter(config=config)

        speech1 = make_audio(0.3)
        list(segmenter.process_audio(speech1, is_speech=True))
        silence1 = make_silence(0.2)
        segments1 = list(segmenter.process_audio(silence1, is_speech=False))

        segmenter.reset()

        speech2 = make_audio(0.3)
        list(segmenter.process_audio(speech2, is_speech=True))
        silence2 = make_silence(0.2)
        segments2 = list(segmenter.process_audio(silence2, is_speech=False))

        assert len(segments1) == len(segments2) == 1
        assert segments2[0].start_time == 0.0
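Taken together, these tests pin down the Segmenter lifecycle: IDLE until speech arrives, SPEECH while the VAD reports speech, TRAILING once silence starts, and emission when the trailing window expires. A minimal driving sketch of that cycle, using only names exercised in this file and assuming default values for the remaining SegmenterConfig fields:

# Sketch: one IDLE -> SPEECH -> TRAILING -> emit cycle.
import numpy as np

from noteflow.infrastructure.asr.segmenter import Segmenter, SegmenterConfig

config = SegmenterConfig(sample_rate=16000, min_speech_duration=0.0, trailing_silence=0.1)
segmenter = Segmenter(config=config)

speech = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)  # 1.0 s of "speech"
silence = np.zeros(3200, dtype=np.float32)  # 0.2 s, longer than trailing_silence

emitted = list(segmenter.process_audio(speech, is_speech=True))  # enters SPEECH
emitted += list(segmenter.process_audio(silence, is_speech=False))  # trailing window expires, emits

for seg in emitted:
    print(f"{seg.start_time:.2f}s-{seg.end_time:.2f}s ({len(seg.audio)} samples)")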
376
tests/stress/test_transaction_boundaries.py
Normal file
@@ -0,0 +1,376 @@
"""Tests for SqlAlchemyUnitOfWork transaction boundaries and rollback behavior.

Verifies rollback works correctly when operations fail mid-transaction.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from noteflow.domain.entities.meeting import Meeting
from noteflow.domain.entities.segment import Segment
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

pytestmark = [pytest.mark.integration, pytest.mark.stress]


class TestExceptionRollback:
    """Test automatic rollback on exception."""

    @pytest.mark.asyncio
    async def test_exception_during_context_rolls_back(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Exception in context manager triggers automatic rollback."""
        meeting = Meeting.create(title="Rollback Test")

        with pytest.raises(RuntimeError):
            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
                await uow.meetings.create(meeting)
                raise RuntimeError("Simulated failure")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.meetings.get(meeting.id)
            assert result is None

    @pytest.mark.asyncio
    async def test_rollback_after_multiple_operations(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Rollback after multiple operations reverts all changes."""
        meeting = Meeting.create(title="Multi-op Rollback")

        with pytest.raises(ValueError):
            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
                await uow.meetings.create(meeting)

                segment = Segment(
                    segment_id=0,
                    text="Test segment",
                    start_time=0.0,
                    end_time=1.0,
                    meeting_id=meeting.id,
                )
                await uow.segments.add(meeting.id, segment)

                raise ValueError("Simulated batch failure")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.meetings.get(meeting.id)
            assert result is None

    @pytest.mark.asyncio
    async def test_exception_type_does_not_matter(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Any exception type triggers rollback."""
        meeting = Meeting.create(title="Exception Type Test")

        class CustomError(Exception):
            pass

        with pytest.raises(CustomError):
            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
                await uow.meetings.create(meeting)
                raise CustomError("Custom error")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.meetings.get(meeting.id)
            assert result is None


class TestExplicitRollback:
    """Test explicit rollback behavior."""

    @pytest.mark.asyncio
    async def test_explicit_rollback_reverts_changes(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Explicit rollback() call reverts uncommitted changes."""
        meeting = Meeting.create(title="Explicit Rollback")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.meetings.create(meeting)
            await uow.rollback()

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.meetings.get(meeting.id)
            assert result is None

    @pytest.mark.asyncio
    async def test_commit_after_rollback_is_no_op(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Commit after rollback doesn't resurrect rolled-back data."""
        meeting = Meeting.create(title="Commit After Rollback")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.meetings.create(meeting)
            await uow.rollback()
            await uow.commit()

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.meetings.get(meeting.id)
            assert result is None


class TestCommitPersistence:
    """Test that committed data persists."""

    @pytest.mark.asyncio
    async def test_committed_data_visible_in_new_uow(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Data committed in one UoW is visible in subsequent UoW."""
        meeting = Meeting.create(title="Visibility Test")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.meetings.create(meeting)
            await uow.commit()

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.meetings.get(meeting.id)
            assert result is not None
            assert result.title == "Visibility Test"

    @pytest.mark.asyncio
    async def test_committed_meeting_and_segment(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Committed meeting and segment both persist."""
        meeting = Meeting.create(title="Meeting With Segment")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.meetings.create(meeting)
            await uow.commit()

        segment = Segment(
            segment_id=0,
            text="Test segment text",
            start_time=0.0,
            end_time=1.5,
            meeting_id=meeting.id,
        )

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.segments.add(meeting.id, segment)
            await uow.commit()

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            segments = await uow.segments.get_by_meeting(meeting.id)
            assert len(segments) == 1
            assert segments[0].text == "Test segment text"


class TestBatchOperationRollback:
    """Test rollback behavior with batch operations."""

    @pytest.mark.asyncio
    async def test_batch_segment_add_rollback(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Batch segment operations are fully rolled back on failure."""
        meeting = Meeting.create(title="Batch Rollback Test")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.meetings.create(meeting)
            await uow.commit()

        segments = [
            Segment(
                segment_id=i,
                text=f"Segment {i}",
                start_time=float(i),
                end_time=float(i + 1),
                meeting_id=meeting.id,
            )
            for i in range(10)
        ]

        with pytest.raises(RuntimeError):
            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
                await uow.segments.add_batch(meeting.id, segments)
                raise RuntimeError("Batch failure")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.segments.get_by_meeting(meeting.id)
            assert len(result) == 0

    @pytest.mark.asyncio
    async def test_partial_batch_no_partial_persist(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Failure mid-batch doesn't leave partial data."""
        meeting = Meeting.create(title="Partial Batch Test")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.meetings.create(meeting)
            await uow.commit()

        with pytest.raises(ValueError):
            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
                for i in range(5):
                    segment = Segment(
                        segment_id=i,
                        text=f"Segment {i}",
                        start_time=float(i),
                        end_time=float(i + 1),
                        meeting_id=meeting.id,
                    )
                    await uow.segments.add(meeting.id, segment)

                raise ValueError("Mid-batch failure")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.segments.get_by_meeting(meeting.id)
            assert len(result) == 0


class TestIsolation:
    """Test transaction isolation between UoW instances."""

    @pytest.mark.asyncio
    async def test_uncommitted_data_not_visible_externally(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Uncommitted data in one UoW not visible in another."""
        meeting = Meeting.create(title="Isolation Test")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow1:
            await uow1.meetings.create(meeting)

            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow2:
                result = await uow2.meetings.get(meeting.id)
                assert result is None

    @pytest.mark.asyncio
    async def test_independent_uow_transactions(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Two UoW instances have independent transactions."""
        meeting1 = Meeting.create(title="Meeting 1")
        meeting2 = Meeting.create(title="Meeting 2")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow1:
            await uow1.meetings.create(meeting1)
            await uow1.commit()

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow2:
            await uow2.meetings.create(meeting2)
            await uow2.rollback()

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result1 = await uow.meetings.get(meeting1.id)
            result2 = await uow.meetings.get(meeting2.id)

            assert result1 is not None
            assert result2 is None


class TestMeetingStateRollback:
    """Test rollback on meeting state changes."""

    @pytest.mark.asyncio
    async def test_meeting_state_change_rollback(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Meeting state changes are rolled back on failure."""
        meeting = Meeting.create(title="State Rollback")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            await uow.meetings.create(meeting)
            await uow.commit()

        original_state = meeting.state

        with pytest.raises(ValueError):
            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
                m = await uow.meetings.get(meeting.id)
                assert m is not None
                m.start_recording()
                await uow.meetings.update(m)
                raise ValueError("Business logic failure")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            result = await uow.meetings.get(meeting.id)
            assert result is not None
            assert result.state == original_state


class TestRepositoryContextRequirement:
    """Test that repositories require UoW context."""

    @pytest.mark.asyncio
    async def test_repo_access_outside_context_raises(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Accessing repository outside context raises RuntimeError."""
        uow = SqlAlchemyUnitOfWork(postgres_session_factory)

        with pytest.raises(RuntimeError, match="UnitOfWork not in context"):
            _ = uow.meetings

    @pytest.mark.asyncio
    async def test_commit_outside_context_raises(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Calling commit outside context raises RuntimeError."""
        uow = SqlAlchemyUnitOfWork(postgres_session_factory)

        with pytest.raises(RuntimeError, match="UnitOfWork not in context"):
            await uow.commit()

    @pytest.mark.asyncio
    async def test_rollback_outside_context_raises(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Calling rollback outside context raises RuntimeError."""
        uow = SqlAlchemyUnitOfWork(postgres_session_factory)

        with pytest.raises(RuntimeError, match="UnitOfWork not in context"):
            await uow.rollback()


class TestMultipleMeetingOperations:
    """Test transactions spanning multiple meetings."""

    @pytest.mark.asyncio
    async def test_multiple_meetings_atomic(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Multiple meeting creates are atomic."""
        meetings = [Meeting.create(title=f"Meeting {i}") for i in range(5)]

        with pytest.raises(RuntimeError):
            async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
                for meeting in meetings:
                    await uow.meetings.create(meeting)
                raise RuntimeError("All or nothing")

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            for meeting in meetings:
                result = await uow.meetings.get(meeting.id)
                assert result is None

    @pytest.mark.asyncio
    async def test_multiple_meetings_commit_all(
        self, postgres_session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Multiple meetings commit together."""
        meetings = [Meeting.create(title=f"Meeting {i}") for i in range(5)]

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            for meeting in meetings:
                await uow.meetings.create(meeting)
            await uow.commit()

        async with SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
            for meeting in meetings:
                result = await uow.meetings.get(meeting.id)
                assert result is not None
                assert meeting.title in result.title
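The contract this file pins down for SqlAlchemyUnitOfWork: nothing persists without an explicit commit(), any exception escaping the context rolls the whole transaction back, and repositories, commit(), and rollback() are only usable inside the context. A short sketch of the calling pattern the suite implies (save_meeting is a hypothetical helper, not part of this commit):

# Sketch: the commit-or-discard pattern verified above.
from noteflow.domain.entities.meeting import Meeting
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork


async def save_meeting(session_factory, title: str) -> Meeting:
    meeting = Meeting.create(title=title)
    async with SqlAlchemyUnitOfWork(session_factory) as uow:
        await uow.meetings.create(meeting)
        await uow.commit()  # omit this and the create is discarded on exit
    return meeting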
80
uv.lock
generated
@@ -1890,7 +1890,6 @@ dependencies = [
|
||||
{ name = "alembic" },
|
||||
{ name = "asyncpg" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "diart" },
|
||||
{ name = "faster-whisper" },
|
||||
{ name = "flet" },
|
||||
{ name = "grpcio" },
|
||||
@@ -1936,6 +1935,7 @@ triggers = [
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "ruff" },
|
||||
{ name = "watchfiles" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -1945,8 +1945,7 @@ requires-dist = [
|
||||
{ name = "asyncpg", specifier = ">=0.29" },
|
||||
{ name = "basedpyright", marker = "extra == 'dev'", specifier = ">=1.18" },
|
||||
{ name = "cryptography", specifier = ">=42.0" },
|
||||
{ name = "diart", specifier = ">=0.9.2" },
|
||||
{ name = "diart", marker = "extra == 'diarization'", specifier = ">=0.9" },
|
||||
{ name = "diart", marker = "extra == 'diarization'", specifier = ">=0.9.2" },
|
||||
{ name = "faster-whisper", specifier = ">=1.0" },
|
||||
{ name = "flet", specifier = ">=0.21" },
|
||||
{ name = "grpcio", specifier = ">=1.60" },
|
||||
@@ -1978,7 +1977,10 @@ requires-dist = [
|
||||
provides-extras = ["dev", "triggers", "summarization", "diarization"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "ruff", specifier = ">=0.14.9" }]
|
||||
dev = [
|
||||
{ name = "ruff", specifier = ">=0.14.9" },
|
||||
{ name = "watchfiles", specifier = ">=1.1.1" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
@@ -6488,6 +6490,76 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl", hash = "sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd", size = 131182, upload-time = "2025-12-11T15:56:38.584Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "watchfiles"
|
||||
version = "1.1.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c2/c9/8869df9b2a2d6c59d79220a4db37679e74f807c559ffe5265e08b227a210/watchfiles-1.1.1.tar.gz", hash = "sha256:a173cb5c16c4f40ab19cecf48a534c409f7ea983ab8fed0741304a1c0a31b3f2", size = 94440, upload-time = "2025-10-14T15:06:21.08Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/74/d5/f039e7e3c639d9b1d09b07ea412a6806d38123f0508e5f9b48a87b0a76cc/watchfiles-1.1.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:8c89f9f2f740a6b7dcc753140dd5e1ab9215966f7a3530d0c0705c83b401bd7d", size = 404745, upload-time = "2025-10-14T15:04:46.731Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/96/a881a13aa1349827490dab2d363c8039527060cfcc2c92cc6d13d1b1049e/watchfiles-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd404be08018c37350f0d6e34676bd1e2889990117a2b90070b3007f172d0610", size = 391769, upload-time = "2025-10-14T15:04:48.003Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/5b/d3b460364aeb8da471c1989238ea0e56bec24b6042a68046adf3d9ddb01c/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8526e8f916bb5b9a0a777c8317c23ce65de259422bba5b31325a6fa6029d33af", size = 449374, upload-time = "2025-10-14T15:04:49.179Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/44/5769cb62d4ed055cb17417c0a109a92f007114a4e07f30812a73a4efdb11/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2edc3553362b1c38d9f06242416a5d8e9fe235c204a4072e988ce2e5bb1f69f6", size = 459485, upload-time = "2025-10-14T15:04:50.155Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/0c/286b6301ded2eccd4ffd0041a1b726afda999926cf720aab63adb68a1e36/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30f7da3fb3f2844259cba4720c3fc7138eb0f7b659c38f3bfa65084c7fc7abce", size = 488813, upload-time = "2025-10-14T15:04:51.059Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/2b/8530ed41112dd4a22f4dcfdb5ccf6a1baad1ff6eed8dc5a5f09e7e8c41c7/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8979280bdafff686ba5e4d8f97840f929a87ed9cdf133cbbd42f7766774d2aa", size = 594816, upload-time = "2025-10-14T15:04:52.031Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/d2/f5f9fb49489f184f18470d4f99f4e862a4b3e9ac2865688eb2099e3d837a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcc5c24523771db3a294c77d94771abcfcb82a0e0ee8efd910c37c59ec1b31bb", size = 475186, upload-time = "2025-10-14T15:04:53.064Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cf/68/5707da262a119fb06fbe214d82dd1fe4a6f4af32d2d14de368d0349eb52a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db5d7ae38ff20153d542460752ff397fcf5c96090c1230803713cf3147a6803", size = 456812, upload-time = "2025-10-14T15:04:55.174Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/ab/3cbb8756323e8f9b6f9acb9ef4ec26d42b2109bce830cc1f3468df20511d/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:28475ddbde92df1874b6c5c8aaeb24ad5be47a11f87cde5a28ef3835932e3e94", size = 630196, upload-time = "2025-10-14T15:04:56.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/46/7152ec29b8335f80167928944a94955015a345440f524d2dfe63fc2f437b/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:36193ed342f5b9842edd3532729a2ad55c4160ffcfa3700e0d54be496b70dd43", size = 622657, upload-time = "2025-10-14T15:04:57.521Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/bf/95895e78dd75efe9a7f31733607f384b42eb5feb54bd2eb6ed57cc2e94f4/watchfiles-1.1.1-cp312-cp312-win32.whl", hash = "sha256:859e43a1951717cc8de7f4c77674a6d389b106361585951d9e69572823f311d9", size = 272042, upload-time = "2025-10-14T15:04:59.046Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/87/0a/90eb755f568de2688cb220171c4191df932232c20946966c27a59c400850/watchfiles-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:91d4c9a823a8c987cce8fa2690923b069966dabb196dd8d137ea2cede885fde9", size = 288410, upload-time = "2025-10-14T15:05:00.081Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/76/f322701530586922fbd6723c4f91ace21364924822a8772c549483abed13/watchfiles-1.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:a625815d4a2bdca61953dbba5a39d60164451ef34c88d751f6c368c3ea73d404", size = 278209, upload-time = "2025-10-14T15:05:01.168Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/f4/f750b29225fe77139f7ae5de89d4949f5a99f934c65a1f1c0b248f26f747/watchfiles-1.1.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:130e4876309e8686a5e37dba7d5e9bc77e6ed908266996ca26572437a5271e18", size = 404321, upload-time = "2025-10-14T15:05:02.063Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/f9/f07a295cde762644aa4c4bb0f88921d2d141af45e735b965fb2e87858328/watchfiles-1.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5f3bde70f157f84ece3765b42b4a52c6ac1a50334903c6eaf765362f6ccca88a", size = 391783, upload-time = "2025-10-14T15:05:03.052Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/11/fc2502457e0bea39a5c958d86d2cb69e407a4d00b85735ca724bfa6e0d1a/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14e0b1fe858430fc0251737ef3824c54027bedb8c37c38114488b8e131cf8219", size = 449279, upload-time = "2025-10-14T15:05:04.004Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/1f/d66bc15ea0b728df3ed96a539c777acfcad0eb78555ad9efcaa1274688f0/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f27db948078f3823a6bb3b465180db8ebecf26dd5dae6f6180bd87383b6b4428", size = 459405, upload-time = "2025-10-14T15:05:04.942Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/90/9f4a65c0aec3ccf032703e6db02d89a157462fbb2cf20dd415128251cac0/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:059098c3a429f62fc98e8ec62b982230ef2c8df68c79e826e37b895bc359a9c0", size = 488976, upload-time = "2025-10-14T15:05:05.905Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/57/ee347af605d867f712be7029bb94c8c071732a4b44792e3176fa3c612d39/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfb5862016acc9b869bb57284e6cb35fdf8e22fe59f7548858e2f971d045f150", size = 595506, upload-time = "2025-10-14T15:05:06.906Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/78/cc5ab0b86c122047f75e8fc471c67a04dee395daf847d3e59381996c8707/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:319b27255aacd9923b8a276bb14d21a5f7ff82564c744235fc5eae58d95422ae", size = 474936, upload-time = "2025-10-14T15:05:07.906Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/62/da/def65b170a3815af7bd40a3e7010bf6ab53089ef1b75d05dd5385b87cf08/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c755367e51db90e75b19454b680903631d41f9e3607fbd941d296a020c2d752d", size = 456147, upload-time = "2025-10-14T15:05:09.138Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/57/99/da6573ba71166e82d288d4df0839128004c67d2778d3b566c138695f5c0b/watchfiles-1.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c22c776292a23bfc7237a98f791b9ad3144b02116ff10d820829ce62dff46d0b", size = 630007, upload-time = "2025-10-14T15:05:10.117Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/51/7439c4dd39511368849eb1e53279cd3454b4a4dbace80bab88feeb83c6b5/watchfiles-1.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:3a476189be23c3686bc2f4321dd501cb329c0a0469e77b7b534ee10129ae6374", size = 622280, upload-time = "2025-10-14T15:05:11.146Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/9c/8ed97d4bba5db6fdcdb2b298d3898f2dd5c20f6b73aee04eabe56c59677e/watchfiles-1.1.1-cp313-cp313-win32.whl", hash = "sha256:bf0a91bfb5574a2f7fc223cf95eeea79abfefa404bf1ea5e339c0c1560ae99a0", size = 272056, upload-time = "2025-10-14T15:05:12.156Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/f3/c14e28429f744a260d8ceae18bf58c1d5fa56b50d006a7a9f80e1882cb0d/watchfiles-1.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:52e06553899e11e8074503c8e716d574adeeb7e68913115c4b3653c53f9bae42", size = 288162, upload-time = "2025-10-14T15:05:13.208Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/61/fe0e56c40d5cd29523e398d31153218718c5786b5e636d9ae8ae79453d27/watchfiles-1.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac3cc5759570cd02662b15fbcd9d917f7ecd47efe0d6b40474eafd246f91ea18", size = 277909, upload-time = "2025-10-14T15:05:14.49Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/42/e0a7d749626f1e28c7108a99fb9bf524b501bbbeb9b261ceecde644d5a07/watchfiles-1.1.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:563b116874a9a7ce6f96f87cd0b94f7faf92d08d0021e837796f0a14318ef8da", size = 403389, upload-time = "2025-10-14T15:05:15.777Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/49/08732f90ce0fbbc13913f9f215c689cfc9ced345fb1bcd8829a50007cc8d/watchfiles-1.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ad9fe1dae4ab4212d8c91e80b832425e24f421703b5a42ef2e4a1e215aff051", size = 389964, upload-time = "2025-10-14T15:05:16.85Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/0d/7c315d4bd5f2538910491a0393c56bf70d333d51bc5b34bee8e68e8cea19/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce70f96a46b894b36eba678f153f052967a0d06d5b5a19b336ab0dbbd029f73e", size = 448114, upload-time = "2025-10-14T15:05:17.876Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/24/9e096de47a4d11bc4df41e9d1e61776393eac4cb6eb11b3e23315b78b2cc/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cb467c999c2eff23a6417e58d75e5828716f42ed8289fe6b77a7e5a91036ca70", size = 460264, upload-time = "2025-10-14T15:05:18.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/0f/e8dea6375f1d3ba5fcb0b3583e2b493e77379834c74fd5a22d66d85d6540/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:836398932192dae4146c8f6f737d74baeac8b70ce14831a239bdb1ca882fc261", size = 487877, upload-time = "2025-10-14T15:05:20.094Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/5b/df24cfc6424a12deb41503b64d42fbea6b8cb357ec62ca84a5a3476f654a/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:743185e7372b7bc7c389e1badcc606931a827112fbbd37f14c537320fca08620", size = 595176, upload-time = "2025-10-14T15:05:21.134Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/b5/853b6757f7347de4e9b37e8cc3289283fb983cba1ab4d2d7144694871d9c/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afaeff7696e0ad9f02cbb8f56365ff4686ab205fcf9c4c5b6fdfaaa16549dd04", size = 473577, upload-time = "2025-10-14T15:05:22.306Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/f7/0a4467be0a56e80447c8529c9fce5b38eab4f513cb3d9bf82e7392a5696b/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f7eb7da0eb23aa2ba036d4f616d46906013a68caf61b7fdbe42fc8b25132e77", size = 455425, upload-time = "2025-10-14T15:05:23.348Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/e0/82583485ea00137ddf69bc84a2db88bd92ab4a6e3c405e5fb878ead8d0e7/watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:831a62658609f0e5c64178211c942ace999517f5770fe9436be4c2faeba0c0ef", size = 628826, upload-time = "2025-10-14T15:05:24.398Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/9a/a785356fccf9fae84c0cc90570f11702ae9571036fb25932f1242c82191c/watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:f9a2ae5c91cecc9edd47e041a930490c31c3afb1f5e6d71de3dc671bfaca02bf", size = 622208, upload-time = "2025-10-14T15:05:25.45Z" },
{ url = "https://files.pythonhosted.org/packages/c3/f4/0872229324ef69b2c3edec35e84bd57a1289e7d3fe74588048ed8947a323/watchfiles-1.1.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:d1715143123baeeaeadec0528bb7441103979a1d5f6fd0e1f915383fea7ea6d5", size = 404315, upload-time = "2025-10-14T15:05:26.501Z" },
{ url = "https://files.pythonhosted.org/packages/7b/22/16d5331eaed1cb107b873f6ae1b69e9ced582fcf0c59a50cd84f403b1c32/watchfiles-1.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:39574d6370c4579d7f5d0ad940ce5b20db0e4117444e39b6d8f99db5676c52fd", size = 390869, upload-time = "2025-10-14T15:05:27.649Z" },
{ url = "https://files.pythonhosted.org/packages/b2/7e/5643bfff5acb6539b18483128fdc0ef2cccc94a5b8fbda130c823e8ed636/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7365b92c2e69ee952902e8f70f3ba6360d0d596d9299d55d7d386df84b6941fb", size = 449919, upload-time = "2025-10-14T15:05:28.701Z" },
{ url = "https://files.pythonhosted.org/packages/51/2e/c410993ba5025a9f9357c376f48976ef0e1b1aefb73b97a5ae01a5972755/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bfff9740c69c0e4ed32416f013f3c45e2ae42ccedd1167ef2d805c000b6c71a5", size = 460845, upload-time = "2025-10-14T15:05:30.064Z" },
{ url = "https://files.pythonhosted.org/packages/8e/a4/2df3b404469122e8680f0fcd06079317e48db58a2da2950fb45020947734/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b27cf2eb1dda37b2089e3907d8ea92922b673c0c427886d4edc6b94d8dfe5db3", size = 489027, upload-time = "2025-10-14T15:05:31.064Z" },
{ url = "https://files.pythonhosted.org/packages/ea/84/4587ba5b1f267167ee715b7f66e6382cca6938e0a4b870adad93e44747e6/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:526e86aced14a65a5b0ec50827c745597c782ff46b571dbfe46192ab9e0b3c33", size = 595615, upload-time = "2025-10-14T15:05:32.074Z" },
{ url = "https://files.pythonhosted.org/packages/6a/0f/c6988c91d06e93cd0bb3d4a808bcf32375ca1904609835c3031799e3ecae/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04e78dd0b6352db95507fd8cb46f39d185cf8c74e4cf1e4fbad1d3df96faf510", size = 474836, upload-time = "2025-10-14T15:05:33.209Z" },
{ url = "https://files.pythonhosted.org/packages/b4/36/ded8aebea91919485b7bbabbd14f5f359326cb5ec218cd67074d1e426d74/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c85794a4cfa094714fb9c08d4a218375b2b95b8ed1666e8677c349906246c05", size = 455099, upload-time = "2025-10-14T15:05:34.189Z" },
{ url = "https://files.pythonhosted.org/packages/98/e0/8c9bdba88af756a2fce230dd365fab2baf927ba42cd47521ee7498fd5211/watchfiles-1.1.1-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:74d5012b7630714b66be7b7b7a78855ef7ad58e8650c73afc4c076a1f480a8d6", size = 630626, upload-time = "2025-10-14T15:05:35.216Z" },
{ url = "https://files.pythonhosted.org/packages/2a/84/a95db05354bf2d19e438520d92a8ca475e578c647f78f53197f5a2f17aaf/watchfiles-1.1.1-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:8fbe85cb3201c7d380d3d0b90e63d520f15d6afe217165d7f98c9c649654db81", size = 622519, upload-time = "2025-10-14T15:05:36.259Z" },
{ url = "https://files.pythonhosted.org/packages/1d/ce/d8acdc8de545de995c339be67711e474c77d643555a9bb74a9334252bd55/watchfiles-1.1.1-cp314-cp314-win32.whl", hash = "sha256:3fa0b59c92278b5a7800d3ee7733da9d096d4aabcfabb9a928918bd276ef9b9b", size = 272078, upload-time = "2025-10-14T15:05:37.63Z" },
{ url = "https://files.pythonhosted.org/packages/c4/c9/a74487f72d0451524be827e8edec251da0cc1fcf111646a511ae752e1a3d/watchfiles-1.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:c2047d0b6cea13b3316bdbafbfa0c4228ae593d995030fda39089d36e64fc03a", size = 287664, upload-time = "2025-10-14T15:05:38.95Z" },
{ url = "https://files.pythonhosted.org/packages/df/b8/8ac000702cdd496cdce998c6f4ee0ca1f15977bba51bdf07d872ebdfc34c/watchfiles-1.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:842178b126593addc05acf6fce960d28bc5fae7afbaa2c6c1b3a7b9460e5be02", size = 277154, upload-time = "2025-10-14T15:05:39.954Z" },
{ url = "https://files.pythonhosted.org/packages/47/a8/e3af2184707c29f0f14b1963c0aace6529f9d1b8582d5b99f31bbf42f59e/watchfiles-1.1.1-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:88863fbbc1a7312972f1c511f202eb30866370ebb8493aef2812b9ff28156a21", size = 403820, upload-time = "2025-10-14T15:05:40.932Z" },
{ url = "https://files.pythonhosted.org/packages/c0/ec/e47e307c2f4bd75f9f9e8afbe3876679b18e1bcec449beca132a1c5ffb2d/watchfiles-1.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:55c7475190662e202c08c6c0f4d9e345a29367438cf8e8037f3155e10a88d5a5", size = 390510, upload-time = "2025-10-14T15:05:41.945Z" },
{ url = "https://files.pythonhosted.org/packages/d5/a0/ad235642118090f66e7b2f18fd5c42082418404a79205cdfca50b6309c13/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f53fa183d53a1d7a8852277c92b967ae99c2d4dcee2bfacff8868e6e30b15f7", size = 448408, upload-time = "2025-10-14T15:05:43.385Z" },
{ url = "https://files.pythonhosted.org/packages/df/85/97fa10fd5ff3332ae17e7e40e20784e419e28521549780869f1413742e9d/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6aae418a8b323732fa89721d86f39ec8f092fc2af67f4217a2b07fd3e93c6101", size = 458968, upload-time = "2025-10-14T15:05:44.404Z" },
{ url = "https://files.pythonhosted.org/packages/47/c2/9059c2e8966ea5ce678166617a7f75ecba6164375f3b288e50a40dc6d489/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f096076119da54a6080e8920cbdaac3dbee667eb91dcc5e5b78840b87415bd44", size = 488096, upload-time = "2025-10-14T15:05:45.398Z" },
{ url = "https://files.pythonhosted.org/packages/94/44/d90a9ec8ac309bc26db808a13e7bfc0e4e78b6fc051078a554e132e80160/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00485f441d183717038ed2e887a7c868154f216877653121068107b227a2f64c", size = 596040, upload-time = "2025-10-14T15:05:46.502Z" },
{ url = "https://files.pythonhosted.org/packages/95/68/4e3479b20ca305cfc561db3ed207a8a1c745ee32bf24f2026a129d0ddb6e/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a55f3e9e493158d7bfdb60a1165035f1cf7d320914e7b7ea83fe22c6023b58fc", size = 473847, upload-time = "2025-10-14T15:05:47.484Z" },
{ url = "https://files.pythonhosted.org/packages/4f/55/2af26693fd15165c4ff7857e38330e1b61ab8c37d15dc79118cdba115b7a/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c91ed27800188c2ae96d16e3149f199d62f86c7af5f5f4d2c61a3ed8cd3666c", size = 455072, upload-time = "2025-10-14T15:05:48.928Z" },
{ url = "https://files.pythonhosted.org/packages/66/1d/d0d200b10c9311ec25d2273f8aad8c3ef7cc7ea11808022501811208a750/watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:311ff15a0bae3714ffb603e6ba6dbfba4065ab60865d15a6ec544133bdb21099", size = 629104, upload-time = "2025-10-14T15:05:49.908Z" },
{ url = "https://files.pythonhosted.org/packages/e3/bd/fa9bb053192491b3867ba07d2343d9f2252e00811567d30ae8d0f78136fe/watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:a916a2932da8f8ab582f242c065f5c81bed3462849ca79ee357dd9551b0e9b01", size = 622112, upload-time = "2025-10-14T15:05:50.941Z" },
]
[[package]]
name = "websocket-client"
version = "1.9.0"