Remove deprecated spikes for UI, audio capture, ASR latency, and encryption
- Deleted unused spike directories and their associated files, including UI tray hotkeys, audio capture, ASR latency, and encryption implementations. - Cleaned up related demo scripts and findings documents to streamline the project structure and focus on active components. - Ensured that all removed files were no longer referenced in the codebase, maintaining a clean repository.
This commit is contained in:
@@ -1 +0,0 @@
|
||||
"""NoteFlow M0 de-risking spikes."""
|
||||
@@ -1,109 +0,0 @@
|
||||
# Spike 1: UI + Tray + Hotkeys - FINDINGS
|
||||
|
||||
## Status: Implementation Complete, Requires Display Server
|
||||
|
||||
## System Requirements
|
||||
|
||||
**X11 or Wayland display server is required** for pystray and pynput:
|
||||
|
||||
```bash
|
||||
# pystray on Linux requires X11 or GTK AppIndicator
|
||||
# pynput requires X11 ($DISPLAY must be set)
|
||||
|
||||
# Running from terminal with display:
|
||||
export DISPLAY=:0 # If not already set
|
||||
python -m spikes.spike_01_ui_tray_hotkeys.demo
|
||||
```
|
||||
|
||||
## Implementation Summary
|
||||
|
||||
### Files Created
|
||||
- `protocols.py` - Defines TrayController, HotkeyManager, Notifier protocols
|
||||
- `tray_impl.py` - PystrayController implementation with icon states
|
||||
- `hotkey_impl.py` - PynputHotkeyManager for global hotkeys
|
||||
- `demo.py` - Interactive Flet + pystray demo
|
||||
|
||||
### Key Design Decisions
|
||||
|
||||
1. **Flet for UI**: Modern Python UI framework with hot reload
|
||||
2. **pystray for Tray**: Cross-platform system tray (separate thread)
|
||||
3. **pynput for Hotkeys**: Cross-platform global hotkey capture
|
||||
4. **Queue Communication**: Thread-safe event passing between tray and UI
|
||||
|
||||
### Architecture: Flet + pystray Integration
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Main Thread │
|
||||
│ ┌─────────────────────────────────┐ │
|
||||
│ │ Flet Event Loop │ │
|
||||
│ │ - UI rendering │ │
|
||||
│ │ - Event polling (100ms) │ │
|
||||
│ │ - State updates │ │
|
||||
│ └─────────────────────────────────┘ │
|
||||
│ ▲ │
|
||||
│ │ Queue │
|
||||
│ │ │
|
||||
└───────────────────┼─────────────────────┘
|
||||
│
|
||||
┌───────────────────┼─────────────────────┐
|
||||
│ ┌────────────────▼────────────────┐ │
|
||||
│ │ Event Queue │ │
|
||||
│ │ - "toggle" -> toggle state │ │
|
||||
│ │ - "quit" -> cleanup + exit │ │
|
||||
│ └────────────────┬────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────────────┴────────────────┐ │
|
||||
│ │ pystray Thread (daemon) │ │
|
||||
│ │ pynput Thread (daemon) │ │
|
||||
│ │ - Tray icon & menu │ │
|
||||
│ │ - Global hotkey listener │ │
|
||||
│ └─────────────────────────────────┘ │
|
||||
│ Background Threads │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Exit Criteria Status
|
||||
|
||||
- [x] Protocol definitions complete
|
||||
- [x] Implementation complete
|
||||
- [ ] Flet window opens and displays controls (requires display)
|
||||
- [ ] System tray icon appears on Linux (requires X11)
|
||||
- [ ] Tray menu has working items (requires X11)
|
||||
- [ ] Global hotkey works when window not focused (requires X11)
|
||||
- [ ] Notifications display (requires X11)
|
||||
|
||||
### Cross-Platform Notes
|
||||
|
||||
- **Linux**: Requires X11 or AppIndicator; Wayland support limited
|
||||
- **macOS**: Requires Accessibility permissions for global hotkeys
|
||||
- System Preferences > Privacy & Security > Accessibility
|
||||
- Add Terminal or the app to allowed list
|
||||
- **Windows**: Should work out of box
|
||||
|
||||
### Running the Demo
|
||||
|
||||
With a display server running:
|
||||
|
||||
```bash
|
||||
python -m spikes.spike_01_ui_tray_hotkeys.demo
|
||||
```
|
||||
|
||||
Features:
|
||||
- Flet window with Start/Stop recording buttons
|
||||
- System tray icon (gray = idle, red = recording)
|
||||
- Global hotkey: Ctrl+Shift+R to toggle
|
||||
- Notifications on state changes
|
||||
|
||||
### Known Limitations
|
||||
|
||||
1. **pystray Threading**: Must run in separate thread, communicate via queue
|
||||
2. **pynput on macOS**: Marked "experimental" - may require Accessibility permissions
|
||||
3. **Wayland**: pynput only receives events from X11 apps via Xwayland
|
||||
|
||||
### Next Steps
|
||||
|
||||
1. Test with X11 display server
|
||||
2. Verify cross-platform behavior
|
||||
3. Add window hide-to-tray functionality
|
||||
4. Implement notification action buttons
|
||||
@@ -1 +0,0 @@
|
||||
"""Spike 1: UI + Tray + Hotkeys validation."""
|
||||
@@ -1,253 +0,0 @@
|
||||
"""Interactive UI + Tray + Hotkeys demo for Spike 1.
|
||||
|
||||
Run with: python -m spikes.spike_01_ui_tray_hotkeys.demo
|
||||
|
||||
Features:
|
||||
- Flet window with Start/Stop buttons
|
||||
- System tray icon with context menu
|
||||
- Global hotkey support (Ctrl+Shift+R)
|
||||
- Notifications on state changes
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
from enum import Enum, auto
|
||||
|
||||
import flet as ft
|
||||
|
||||
from .hotkey_impl import PynputHotkeyManager
|
||||
from .protocols import TrayIcon, TrayMenuItem
|
||||
from .tray_impl import PystrayController
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppState(Enum):
|
||||
"""Application state."""
|
||||
|
||||
IDLE = auto()
|
||||
RECORDING = auto()
|
||||
|
||||
|
||||
class NoteFlowDemo:
|
||||
"""Demo application combining Flet UI, system tray, and hotkeys."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the demo application."""
|
||||
self.state = AppState.IDLE
|
||||
self.tray = PystrayController(app_name="NoteFlow Demo")
|
||||
self.hotkey_manager = PynputHotkeyManager()
|
||||
|
||||
# Queue for cross-thread communication
|
||||
self._event_queue: queue.Queue[str] = queue.Queue()
|
||||
|
||||
# Flet page reference (set when app starts)
|
||||
self._page: ft.Page | None = None
|
||||
self._status_text: ft.Text | None = None
|
||||
self._toggle_button: ft.ElevatedButton | None = None
|
||||
|
||||
def _update_ui(self) -> None:
|
||||
"""Update UI elements based on current state."""
|
||||
if self._page is None:
|
||||
return
|
||||
|
||||
if self.state == AppState.RECORDING:
|
||||
if self._status_text:
|
||||
self._status_text.value = "Recording..."
|
||||
self._status_text.color = ft.Colors.RED
|
||||
if self._toggle_button:
|
||||
self._toggle_button.text = "Stop Recording"
|
||||
self._toggle_button.bgcolor = ft.Colors.RED
|
||||
self.tray.set_icon(TrayIcon.RECORDING)
|
||||
self.tray.set_tooltip("NoteFlow - Recording")
|
||||
else:
|
||||
if self._status_text:
|
||||
self._status_text.value = "Idle"
|
||||
self._status_text.color = ft.Colors.GREY
|
||||
if self._toggle_button:
|
||||
self._toggle_button.text = "Start Recording"
|
||||
self._toggle_button.bgcolor = ft.Colors.BLUE
|
||||
self.tray.set_icon(TrayIcon.IDLE)
|
||||
self.tray.set_tooltip("NoteFlow - Idle")
|
||||
|
||||
self._page.update()
|
||||
|
||||
def _toggle_recording(self) -> None:
|
||||
"""Toggle recording state."""
|
||||
if self.state == AppState.IDLE:
|
||||
self.state = AppState.RECORDING
|
||||
logger.info("Started recording")
|
||||
self.tray.notify("NoteFlow", "Recording started")
|
||||
else:
|
||||
self.state = AppState.IDLE
|
||||
logger.info("Stopped recording")
|
||||
self.tray.notify("NoteFlow", "Recording stopped")
|
||||
|
||||
self._update_ui()
|
||||
|
||||
def _on_toggle_click(self, e: ft.ControlEvent) -> None:
|
||||
"""Handle toggle button click."""
|
||||
self._toggle_recording()
|
||||
|
||||
def _on_hotkey(self) -> None:
|
||||
"""Handle global hotkey press."""
|
||||
logger.info("Hotkey pressed!")
|
||||
# Queue event for main thread
|
||||
self._event_queue.put("toggle")
|
||||
|
||||
def _process_events(self) -> None:
|
||||
"""Process queued events (called periodically from UI thread)."""
|
||||
try:
|
||||
while True:
|
||||
event = self._event_queue.get_nowait()
|
||||
if event == "toggle":
|
||||
self._toggle_recording()
|
||||
elif event == "quit":
|
||||
self._cleanup()
|
||||
if self._page:
|
||||
self._page.window.close()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
def _setup_tray_menu(self) -> None:
|
||||
"""Set up the system tray context menu."""
|
||||
menu_items = [
|
||||
TrayMenuItem(
|
||||
label="Start Recording" if self.state == AppState.IDLE else "Stop Recording",
|
||||
callback=self._toggle_recording,
|
||||
),
|
||||
TrayMenuItem(label="", callback=lambda: None, separator=True),
|
||||
TrayMenuItem(
|
||||
label="Show Window",
|
||||
callback=lambda: self._event_queue.put("show"),
|
||||
),
|
||||
TrayMenuItem(label="", callback=lambda: None, separator=True),
|
||||
TrayMenuItem(
|
||||
label="Quit",
|
||||
callback=lambda: self._event_queue.put("quit"),
|
||||
),
|
||||
]
|
||||
self.tray.set_menu(menu_items)
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
"""Clean up resources."""
|
||||
self.hotkey_manager.unregister_all()
|
||||
self.tray.stop()
|
||||
|
||||
def _build_ui(self, page: ft.Page) -> None:
|
||||
"""Build the Flet UI."""
|
||||
self._page = page
|
||||
page.title = "NoteFlow Demo - Spike 1"
|
||||
page.window.width = 400
|
||||
page.window.height = 300
|
||||
page.theme_mode = ft.ThemeMode.DARK
|
||||
|
||||
# Status text
|
||||
self._status_text = ft.Text(
|
||||
value="Idle",
|
||||
size=24,
|
||||
weight=ft.FontWeight.BOLD,
|
||||
color=ft.Colors.GREY,
|
||||
)
|
||||
|
||||
# Toggle button
|
||||
self._toggle_button = ft.ElevatedButton(
|
||||
text="Start Recording",
|
||||
icon=ft.Icons.MIC,
|
||||
on_click=self._on_toggle_click,
|
||||
bgcolor=ft.Colors.BLUE,
|
||||
color=ft.Colors.WHITE,
|
||||
width=200,
|
||||
height=50,
|
||||
)
|
||||
|
||||
# Hotkey info
|
||||
hotkey_text = ft.Text(
|
||||
value="Hotkey: Ctrl+Shift+R",
|
||||
size=14,
|
||||
color=ft.Colors.GREY_400,
|
||||
)
|
||||
|
||||
# Layout
|
||||
page.add(
|
||||
ft.Column(
|
||||
controls=[
|
||||
ft.Container(height=30),
|
||||
self._status_text,
|
||||
ft.Container(height=20),
|
||||
self._toggle_button,
|
||||
ft.Container(height=30),
|
||||
hotkey_text,
|
||||
ft.Text(
|
||||
value="System tray icon is active",
|
||||
size=12,
|
||||
color=ft.Colors.GREY_600,
|
||||
),
|
||||
],
|
||||
horizontal_alignment=ft.CrossAxisAlignment.CENTER,
|
||||
alignment=ft.MainAxisAlignment.CENTER,
|
||||
)
|
||||
)
|
||||
|
||||
# Set up event polling
|
||||
def poll_events() -> None:
|
||||
self._process_events()
|
||||
|
||||
# Poll events every 100ms
|
||||
page.run_task(self._poll_loop)
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""Async loop to poll events."""
|
||||
import asyncio
|
||||
|
||||
while True:
|
||||
self._process_events()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the demo application."""
|
||||
logger.info("Starting NoteFlow Demo")
|
||||
|
||||
# Start system tray
|
||||
self.tray.start()
|
||||
self._setup_tray_menu()
|
||||
|
||||
# Register global hotkey
|
||||
try:
|
||||
self.hotkey_manager.register("ctrl+shift+r", self._on_hotkey)
|
||||
logger.info("Registered hotkey: Ctrl+Shift+R")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to register hotkey: %s", e)
|
||||
|
||||
try:
|
||||
# Run Flet app
|
||||
ft.app(target=self._build_ui)
|
||||
finally:
|
||||
self._cleanup()
|
||||
logger.info("Demo ended")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the UI + Tray + Hotkeys demo."""
|
||||
print("=== NoteFlow Demo - Spike 1 ===")
|
||||
print("Features:")
|
||||
print(" - Flet window with Start/Stop buttons")
|
||||
print(" - System tray icon with context menu")
|
||||
print(" - Global hotkey: Ctrl+Shift+R")
|
||||
print()
|
||||
|
||||
demo = NoteFlowDemo()
|
||||
demo.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Global hotkey implementation using pynput.
|
||||
|
||||
Provides cross-platform global hotkey registration and callback handling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .protocols import HotkeyCallback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pynput import keyboard
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PynputHotkeyManager:
|
||||
"""pynput-based global hotkey manager.
|
||||
|
||||
Uses pynput.keyboard.GlobalHotKeys for cross-platform hotkey support.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the hotkey manager."""
|
||||
self._hotkeys: dict[str, tuple[str, HotkeyCallback]] = {} # id -> (hotkey_str, callback)
|
||||
self._listener: keyboard.GlobalHotKeys | None = None
|
||||
self._started = False
|
||||
|
||||
def _normalize_hotkey(self, hotkey: str) -> str:
|
||||
"""Normalize hotkey string to pynput format.
|
||||
|
||||
Args:
|
||||
hotkey: Hotkey string like "ctrl+shift+r".
|
||||
|
||||
Returns:
|
||||
Normalized hotkey string for pynput.
|
||||
"""
|
||||
# Convert common formats to pynput format
|
||||
# pynput uses "<ctrl>+<shift>+r" format
|
||||
parts = hotkey.lower().replace(" ", "").split("+")
|
||||
normalized_parts: list[str] = []
|
||||
|
||||
for part in parts:
|
||||
if part in ("ctrl", "control"):
|
||||
normalized_parts.append("<ctrl>")
|
||||
elif part in ("shift",):
|
||||
normalized_parts.append("<shift>")
|
||||
elif part in ("alt", "option"):
|
||||
normalized_parts.append("<alt>")
|
||||
elif part in ("cmd", "command", "meta", "win", "super"):
|
||||
normalized_parts.append("<cmd>")
|
||||
else:
|
||||
normalized_parts.append(part)
|
||||
|
||||
return "+".join(normalized_parts)
|
||||
|
||||
def _rebuild_listener(self) -> None:
|
||||
"""Rebuild the hotkey listener with current registrations."""
|
||||
from pynput import keyboard
|
||||
|
||||
# Stop existing listener
|
||||
if self._listener is not None:
|
||||
self._listener.stop()
|
||||
self._listener = None
|
||||
|
||||
if not self._hotkeys:
|
||||
return
|
||||
|
||||
# Build hotkey dict for pynput
|
||||
hotkey_dict: dict[str, HotkeyCallback] = {}
|
||||
for reg_id, (hotkey_str, callback) in self._hotkeys.items():
|
||||
normalized = self._normalize_hotkey(hotkey_str)
|
||||
hotkey_dict[normalized] = callback
|
||||
logger.debug("Registered hotkey: %s -> %s", hotkey_str, normalized)
|
||||
|
||||
# Create and start new listener
|
||||
self._listener = keyboard.GlobalHotKeys(hotkey_dict)
|
||||
self._listener.start()
|
||||
self._started = True
|
||||
|
||||
def register(self, hotkey: str, callback: HotkeyCallback) -> str:
|
||||
"""Register a global hotkey.
|
||||
|
||||
Args:
|
||||
hotkey: Hotkey string (e.g., "ctrl+shift+r").
|
||||
callback: Function to call when hotkey is pressed.
|
||||
|
||||
Returns:
|
||||
Registration ID for later unregistration.
|
||||
|
||||
Raises:
|
||||
ValueError: If hotkey string is invalid.
|
||||
"""
|
||||
if not hotkey or not hotkey.strip():
|
||||
raise ValueError("Hotkey string cannot be empty")
|
||||
|
||||
# Generate unique registration ID
|
||||
reg_id = str(uuid.uuid4())
|
||||
|
||||
self._hotkeys[reg_id] = (hotkey, callback)
|
||||
self._rebuild_listener()
|
||||
|
||||
logger.info("Registered hotkey '%s' with id %s", hotkey, reg_id)
|
||||
return reg_id
|
||||
|
||||
def unregister(self, registration_id: str) -> None:
|
||||
"""Unregister a previously registered hotkey.
|
||||
|
||||
Args:
|
||||
registration_id: ID returned from register().
|
||||
|
||||
Safe to call with invalid ID (no-op).
|
||||
"""
|
||||
if registration_id not in self._hotkeys:
|
||||
return
|
||||
|
||||
hotkey_str, _ = self._hotkeys.pop(registration_id)
|
||||
self._rebuild_listener()
|
||||
logger.info("Unregistered hotkey '%s'", hotkey_str)
|
||||
|
||||
def unregister_all(self) -> None:
|
||||
"""Unregister all registered hotkeys."""
|
||||
self._hotkeys.clear()
|
||||
if self._listener is not None:
|
||||
self._listener.stop()
|
||||
self._listener = None
|
||||
self._started = False
|
||||
logger.info("Unregistered all hotkeys")
|
||||
|
||||
def is_supported(self) -> bool:
|
||||
"""Check if global hotkeys are supported on this platform.
|
||||
|
||||
Returns:
|
||||
True if hotkeys can be registered.
|
||||
"""
|
||||
try:
|
||||
from pynput import keyboard # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def registered_count(self) -> int:
|
||||
"""Get the number of registered hotkeys."""
|
||||
return len(self._hotkeys)
|
||||
@@ -1,173 +0,0 @@
|
||||
"""UI, System Tray, and Hotkey protocols for Spike 1.
|
||||
|
||||
These protocols define the contracts for platform abstraction components
|
||||
that will be promoted to src/noteflow/platform/ after validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TrayIcon(Enum):
|
||||
"""System tray icon states."""
|
||||
|
||||
IDLE = auto()
|
||||
RECORDING = auto()
|
||||
PAUSED = auto()
|
||||
ERROR = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrayMenuItem:
|
||||
"""A menu item for the system tray context menu."""
|
||||
|
||||
label: str
|
||||
callback: Callable[[], None]
|
||||
enabled: bool = True
|
||||
checked: bool = False
|
||||
separator: bool = False
|
||||
|
||||
|
||||
class TrayController(Protocol):
|
||||
"""Protocol for system tray/menubar icon controller.
|
||||
|
||||
Implementations should handle cross-platform tray icon display
|
||||
and menu management.
|
||||
"""
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the tray icon.
|
||||
|
||||
May run in a separate thread depending on implementation.
|
||||
"""
|
||||
...
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop and remove the tray icon."""
|
||||
...
|
||||
|
||||
def set_icon(self, icon: TrayIcon) -> None:
|
||||
"""Update the tray icon state.
|
||||
|
||||
Args:
|
||||
icon: New icon state to display.
|
||||
"""
|
||||
...
|
||||
|
||||
def set_menu(self, items: list[TrayMenuItem]) -> None:
|
||||
"""Update the tray context menu items.
|
||||
|
||||
Args:
|
||||
items: List of menu items to display.
|
||||
"""
|
||||
...
|
||||
|
||||
def set_tooltip(self, text: str) -> None:
|
||||
"""Update the tray icon tooltip.
|
||||
|
||||
Args:
|
||||
text: Tooltip text to display on hover.
|
||||
"""
|
||||
...
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the tray icon is running.
|
||||
|
||||
Returns:
|
||||
True if tray is active.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Type alias for hotkey callback
|
||||
HotkeyCallback = Callable[[], None]
|
||||
|
||||
|
||||
class HotkeyManager(Protocol):
|
||||
"""Protocol for global hotkey registration.
|
||||
|
||||
Implementations should handle cross-platform global hotkey capture.
|
||||
"""
|
||||
|
||||
def register(self, hotkey: str, callback: HotkeyCallback) -> str:
|
||||
"""Register a global hotkey.
|
||||
|
||||
Args:
|
||||
hotkey: Hotkey string (e.g., "ctrl+shift+r").
|
||||
callback: Function to call when hotkey is pressed.
|
||||
|
||||
Returns:
|
||||
Registration ID for later unregistration.
|
||||
|
||||
Raises:
|
||||
ValueError: If hotkey string is invalid.
|
||||
RuntimeError: If hotkey is already registered by another app.
|
||||
"""
|
||||
...
|
||||
|
||||
def unregister(self, registration_id: str) -> None:
|
||||
"""Unregister a previously registered hotkey.
|
||||
|
||||
Args:
|
||||
registration_id: ID returned from register().
|
||||
|
||||
Safe to call with invalid ID (no-op).
|
||||
"""
|
||||
...
|
||||
|
||||
def unregister_all(self) -> None:
|
||||
"""Unregister all registered hotkeys."""
|
||||
...
|
||||
|
||||
def is_supported(self) -> bool:
|
||||
"""Check if global hotkeys are supported on this platform.
|
||||
|
||||
Returns:
|
||||
True if hotkeys can be registered.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Notifier(Protocol):
|
||||
"""Protocol for OS notifications.
|
||||
|
||||
Implementations should handle cross-platform notification display.
|
||||
"""
|
||||
|
||||
def notify(
|
||||
self,
|
||||
title: str,
|
||||
body: str,
|
||||
on_click: Callable[[], None] | None = None,
|
||||
timeout_ms: int = 5000,
|
||||
) -> None:
|
||||
"""Show a notification.
|
||||
|
||||
Args:
|
||||
title: Notification title.
|
||||
body: Notification body text.
|
||||
on_click: Optional callback when notification is clicked.
|
||||
timeout_ms: How long to show notification (platform-dependent).
|
||||
"""
|
||||
...
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
title: str,
|
||||
body: str,
|
||||
actions: list[tuple[str, Callable[[], None]]],
|
||||
) -> None:
|
||||
"""Show an actionable notification prompt.
|
||||
|
||||
Args:
|
||||
title: Notification title.
|
||||
body: Notification body text.
|
||||
actions: List of (button_label, callback) tuples.
|
||||
|
||||
Note: Platform support for action buttons varies.
|
||||
"""
|
||||
...
|
||||
@@ -1,261 +0,0 @@
|
||||
"""System tray implementation using pystray.
|
||||
|
||||
Provides cross-platform system tray icon with context menu.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Protocol
|
||||
|
||||
import pystray
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from .protocols import TrayIcon, TrayMenuItem
|
||||
|
||||
|
||||
class PystrayIcon(Protocol):
|
||||
"""Protocol for pystray Icon type."""
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the icon event loop."""
|
||||
...
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the icon."""
|
||||
...
|
||||
|
||||
@property
|
||||
def icon(self) -> Image.Image:
|
||||
"""Icon image."""
|
||||
...
|
||||
|
||||
@icon.setter
|
||||
def icon(self, value: Image.Image) -> None:
|
||||
"""Set icon image."""
|
||||
...
|
||||
|
||||
@property
|
||||
def menu(self) -> PystrayMenu:
|
||||
"""Context menu."""
|
||||
...
|
||||
|
||||
@menu.setter
|
||||
def menu(self, value: PystrayMenu) -> None:
|
||||
"""Set context menu."""
|
||||
...
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
"""Tooltip title."""
|
||||
...
|
||||
|
||||
@title.setter
|
||||
def title(self, value: str) -> None:
|
||||
"""Set tooltip title."""
|
||||
...
|
||||
|
||||
def notify(self, message: str, title: str) -> None:
|
||||
"""Show notification."""
|
||||
...
|
||||
|
||||
|
||||
class PystrayMenu(Protocol):
|
||||
"""Protocol for pystray Menu type.
|
||||
|
||||
Note: SEPARATOR is a class attribute but Protocols don't support
|
||||
class attributes well, so it's omitted here.
|
||||
"""
|
||||
|
||||
def __init__(self, *items: PystrayMenuItem) -> None:
|
||||
"""Create menu with items."""
|
||||
...
|
||||
|
||||
|
||||
class PystrayMenuItem(Protocol):
|
||||
"""Protocol for pystray MenuItem type.
|
||||
|
||||
This is a minimal protocol - pystray.MenuItem will satisfy it structurally.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
"""Create menu item."""
|
||||
...
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_icon_image(icon_state: TrayIcon, size: int = 64) -> Image.Image:
|
||||
"""Create a simple icon image for the given state.
|
||||
|
||||
Args:
|
||||
icon_state: The icon state to visualize.
|
||||
size: Icon size in pixels.
|
||||
|
||||
Returns:
|
||||
PIL Image object.
|
||||
"""
|
||||
# Create a simple colored circle icon
|
||||
image = Image.new("RGBA", (size, size), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
# Color based on state
|
||||
colors = {
|
||||
TrayIcon.IDLE: (100, 100, 100, 255), # Gray
|
||||
TrayIcon.RECORDING: (220, 50, 50, 255), # Red
|
||||
TrayIcon.PAUSED: (255, 165, 0, 255), # Orange
|
||||
TrayIcon.ERROR: (255, 0, 0, 255), # Bright red
|
||||
}
|
||||
color = colors.get(icon_state, (100, 100, 100, 255))
|
||||
|
||||
# Draw filled circle
|
||||
margin = size // 8
|
||||
draw.ellipse(
|
||||
[margin, margin, size - margin, size - margin],
|
||||
fill=color,
|
||||
outline=(255, 255, 255, 255),
|
||||
width=2,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class PystrayController:
|
||||
"""pystray-based system tray controller.
|
||||
|
||||
Runs pystray in a separate thread to avoid blocking the main event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, app_name: str = "NoteFlow") -> None:
|
||||
"""Initialize the tray controller.
|
||||
|
||||
Args:
|
||||
app_name: Application name for the tray icon.
|
||||
"""
|
||||
self._app_name = app_name
|
||||
self._icon: PystrayIcon | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._running = False
|
||||
self._current_state = TrayIcon.IDLE
|
||||
self._menu_items: list[TrayMenuItem] = []
|
||||
self._tooltip = app_name
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the tray icon in a background thread."""
|
||||
if self._running:
|
||||
logger.warning("Tray already running")
|
||||
return
|
||||
|
||||
# Create initial icon
|
||||
image = create_icon_image(self._current_state)
|
||||
|
||||
# Create menu
|
||||
menu = self._build_menu()
|
||||
|
||||
self._icon = pystray.Icon(
|
||||
name=self._app_name,
|
||||
icon=image,
|
||||
title=self._tooltip,
|
||||
menu=menu,
|
||||
)
|
||||
|
||||
# Run in background thread
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._run_icon, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Tray icon started")
|
||||
|
||||
def _run_icon(self) -> None:
|
||||
"""Run the icon event loop (called in background thread)."""
|
||||
if self._icon:
|
||||
self._icon.run()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop and remove the tray icon."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
if self._icon:
|
||||
self._icon.stop()
|
||||
self._icon = None
|
||||
self._thread = None
|
||||
logger.info("Tray icon stopped")
|
||||
|
||||
def set_icon(self, icon: TrayIcon) -> None:
|
||||
"""Update the tray icon state.
|
||||
|
||||
Args:
|
||||
icon: New icon state to display.
|
||||
"""
|
||||
self._current_state = icon
|
||||
if self._icon:
|
||||
self._icon.icon = create_icon_image(icon)
|
||||
|
||||
def set_menu(self, items: list[TrayMenuItem]) -> None:
|
||||
"""Update the tray context menu items.
|
||||
|
||||
Args:
|
||||
items: List of menu items to display.
|
||||
"""
|
||||
self._menu_items = items
|
||||
if self._icon:
|
||||
self._icon.menu = self._build_menu()
|
||||
|
||||
def _build_menu(self) -> PystrayMenu:
|
||||
"""Build pystray menu from TrayMenuItem list."""
|
||||
menu_items: list[PystrayMenuItem] = []
|
||||
|
||||
for item in self._menu_items:
|
||||
if item.separator:
|
||||
menu_items.append(pystray.Menu.SEPARATOR)
|
||||
else:
|
||||
menu_items.append(
|
||||
pystray.MenuItem(
|
||||
text=item.label,
|
||||
action=item.callback,
|
||||
enabled=item.enabled,
|
||||
checked=lambda checked=item.checked: checked,
|
||||
)
|
||||
)
|
||||
|
||||
# Always add a Quit option if not present
|
||||
has_quit = any(m.label.lower() == "quit" for m in self._menu_items)
|
||||
if not has_quit:
|
||||
if menu_items:
|
||||
menu_items.append(pystray.Menu.SEPARATOR)
|
||||
menu_items.append(
|
||||
pystray.MenuItem("Quit", lambda: self.stop())
|
||||
)
|
||||
|
||||
return pystray.Menu(*menu_items)
|
||||
|
||||
def set_tooltip(self, text: str) -> None:
|
||||
"""Update the tray icon tooltip.
|
||||
|
||||
Args:
|
||||
text: Tooltip text to display on hover.
|
||||
"""
|
||||
self._tooltip = text
|
||||
if self._icon:
|
||||
self._icon.title = text
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the tray icon is running.
|
||||
|
||||
Returns:
|
||||
True if tray is active.
|
||||
"""
|
||||
return self._running
|
||||
|
||||
def notify(self, title: str, message: str) -> None:
|
||||
"""Show a notification via the tray icon.
|
||||
|
||||
Args:
|
||||
title: Notification title.
|
||||
message: Notification message.
|
||||
"""
|
||||
if self._icon:
|
||||
self._icon.notify(message, title)
|
||||
@@ -1,93 +0,0 @@
|
||||
# Spike 2: Audio Capture - FINDINGS
|
||||
|
||||
## Status: CORE COMPONENTS VALIDATED
|
||||
|
||||
PortAudio installed. Core components (RmsLevelProvider, TimestampedRingBuffer, SoundDeviceCapture) tested and working. Full validation requires audio hardware/display environment.
|
||||
|
||||
## System Requirements
|
||||
|
||||
**PortAudio library is required** for sounddevice to work:
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install -y libportaudio2 portaudio19-dev
|
||||
|
||||
# macOS (Homebrew)
|
||||
brew install portaudio
|
||||
|
||||
# Windows
|
||||
# PortAudio is bundled with the sounddevice wheel
|
||||
```
|
||||
|
||||
## Implementation Summary
|
||||
|
||||
### Files Created
|
||||
- `protocols.py` - Defines AudioCapture, AudioLevelProvider, RingBuffer protocols
|
||||
- `capture_impl.py` - SoundDeviceCapture implementation
|
||||
- `levels_impl.py` - RmsLevelProvider for VU meter
|
||||
- `ring_buffer_impl.py` - TimestampedRingBuffer for audio storage
|
||||
- `demo.py` - Interactive demo with VU meter and WAV export
|
||||
|
||||
### Key Design Decisions
|
||||
|
||||
1. **Sample Rate**: Default 16kHz for ASR compatibility
|
||||
2. **Format**: float32 normalized (-1.0 to 1.0) for processing
|
||||
3. **Chunk Size**: 100ms chunks for responsive VU meter
|
||||
4. **Ring Buffer**: 5-minute default capacity for meeting recordings
|
||||
|
||||
### Component Test Results
|
||||
|
||||
```
|
||||
=== RMS Level Provider ===
|
||||
Silent RMS: 0.0000
|
||||
Silent dB: -60.0
|
||||
Loud RMS: 0.5000
|
||||
Loud dB: -6.0
|
||||
|
||||
=== Ring Buffer ===
|
||||
Chunks: 5
|
||||
Duration: 0.50s
|
||||
Window (0.3s): 3 chunks
|
||||
|
||||
=== Audio Capture ===
|
||||
Devices found: 0 (headless - no audio hardware)
|
||||
```
|
||||
|
||||
### Exit Criteria Status
|
||||
|
||||
- [x] Protocol definitions complete
|
||||
- [x] Implementation complete
|
||||
- [x] RmsLevelProvider working (0dB to -60dB range)
|
||||
- [x] TimestampedRingBuffer working (FIFO eviction)
|
||||
- [x] SoundDeviceCapture initializes (PortAudio found)
|
||||
- [ ] Can list audio devices (requires audio hardware)
|
||||
- [ ] VU meter updates in real-time (requires audio hardware)
|
||||
- [ ] Device unplug detected (requires audio hardware)
|
||||
- [ ] Captured audio file is playable (requires audio hardware)
|
||||
|
||||
### Cross-Platform Notes
|
||||
|
||||
- **Linux**: Requires `libportaudio2` and `portaudio19-dev`
|
||||
- **macOS**: Requires Homebrew `portaudio` or similar
|
||||
- **Windows**: PortAudio bundled in sounddevice wheel - should work out of box
|
||||
|
||||
### Running the Demo
|
||||
|
||||
After installing PortAudio:
|
||||
|
||||
```bash
|
||||
python -m spikes.spike_02_audio_capture.demo
|
||||
```
|
||||
|
||||
Commands:
|
||||
- `r` - Start recording
|
||||
- `s` - Stop recording and save to output.wav
|
||||
- `l` - List devices
|
||||
- `q` - Quit
|
||||
|
||||
### Next Steps
|
||||
|
||||
1. Install PortAudio system library
|
||||
2. Run demo to validate exit criteria
|
||||
3. Test device unplug handling
|
||||
4. Measure latency characteristics
|
||||
@@ -1 +0,0 @@
|
||||
"""Spike 2: Audio capture validation."""
|
||||
@@ -1,185 +0,0 @@
|
||||
"""Audio capture implementation using sounddevice.
|
||||
|
||||
Provides cross-platform audio input capture with device handling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
|
||||
from .protocols import AudioDeviceInfo, AudioFrameCallback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SoundDeviceCapture:
|
||||
"""sounddevice-based implementation of AudioCapture.
|
||||
|
||||
Handles device enumeration, stream management, and device change detection.
|
||||
Uses PortAudio under the hood for cross-platform audio capture.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the capture instance."""
|
||||
self._stream: sd.InputStream | None = None
|
||||
self._callback: AudioFrameCallback | None = None
|
||||
self._device_id: int | None = None
|
||||
self._sample_rate: int = 16000
|
||||
self._channels: int = 1
|
||||
|
||||
def list_devices(self) -> list[AudioDeviceInfo]:
|
||||
"""List available audio input devices.
|
||||
|
||||
Returns:
|
||||
List of AudioDeviceInfo for all available input devices.
|
||||
"""
|
||||
devices: list[AudioDeviceInfo] = []
|
||||
device_list = sd.query_devices()
|
||||
|
||||
# Get default input device index
|
||||
try:
|
||||
default_input = sd.default.device[0] # Input device index
|
||||
except (TypeError, IndexError):
|
||||
default_input = -1
|
||||
|
||||
devices.extend(
|
||||
AudioDeviceInfo(
|
||||
device_id=idx,
|
||||
name=dev["name"],
|
||||
channels=int(dev["max_input_channels"]),
|
||||
sample_rate=int(dev["default_samplerate"]),
|
||||
is_default=(idx == default_input),
|
||||
)
|
||||
for idx, dev in enumerate(device_list)
|
||||
if int(dev.get("max_input_channels", 0)) > 0
|
||||
)
|
||||
return devices
|
||||
|
||||
def get_default_device(self) -> AudioDeviceInfo | None:
|
||||
"""Get the default input device.
|
||||
|
||||
Returns:
|
||||
Default input device info, or None if no input devices available.
|
||||
"""
|
||||
devices = self.list_devices()
|
||||
for dev in devices:
|
||||
if dev.is_default:
|
||||
return dev
|
||||
return devices[0] if devices else None
|
||||
|
||||
def start(
|
||||
self,
|
||||
device_id: int | None,
|
||||
on_frames: AudioFrameCallback,
|
||||
sample_rate: int = 16000,
|
||||
channels: int = 1,
|
||||
chunk_duration_ms: int = 100,
|
||||
) -> None:
|
||||
"""Start capturing audio from the specified device.
|
||||
|
||||
Args:
|
||||
device_id: Device ID to capture from, or None for default device.
|
||||
on_frames: Callback receiving (frames, timestamp) for each chunk.
|
||||
sample_rate: Sample rate in Hz (default 16kHz for ASR).
|
||||
channels: Number of channels (default 1 for mono).
|
||||
chunk_duration_ms: Duration of each audio chunk in milliseconds.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If already capturing.
|
||||
ValueError: If device_id is invalid.
|
||||
"""
|
||||
if self._stream is not None:
|
||||
raise RuntimeError("Already capturing audio")
|
||||
|
||||
self._callback = on_frames
|
||||
self._device_id = device_id
|
||||
self._sample_rate = sample_rate
|
||||
self._channels = channels
|
||||
|
||||
# Calculate block size from chunk duration
|
||||
blocksize = int(sample_rate * chunk_duration_ms / 1000)
|
||||
|
||||
def _stream_callback(
|
||||
indata: NDArray[np.float32],
|
||||
frames: int,
|
||||
time_info: object, # cffi CData from sounddevice, unused
|
||||
status: sd.CallbackFlags,
|
||||
) -> None:
|
||||
"""Internal sounddevice callback."""
|
||||
if status:
|
||||
logger.warning("Audio stream status: %s", status)
|
||||
|
||||
if self._callback is not None:
|
||||
# Copy the data and flatten to 1D array
|
||||
audio_data = indata.copy().flatten().astype(np.float32)
|
||||
timestamp = time.monotonic()
|
||||
self._callback(audio_data, timestamp)
|
||||
|
||||
try:
|
||||
self._stream = sd.InputStream(
|
||||
device=device_id,
|
||||
channels=channels,
|
||||
samplerate=sample_rate,
|
||||
blocksize=blocksize,
|
||||
dtype=np.float32,
|
||||
callback=_stream_callback,
|
||||
)
|
||||
self._stream.start()
|
||||
logger.info(
|
||||
"Started audio capture: device=%s, rate=%d, channels=%d, blocksize=%d",
|
||||
device_id,
|
||||
sample_rate,
|
||||
channels,
|
||||
blocksize,
|
||||
)
|
||||
except sd.PortAudioError as e:
|
||||
self._stream = None
|
||||
self._callback = None
|
||||
raise RuntimeError(f"Failed to start audio capture: {e}") from e
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop audio capture.
|
||||
|
||||
Safe to call even if not capturing.
|
||||
"""
|
||||
if self._stream is not None:
|
||||
try:
|
||||
self._stream.stop()
|
||||
self._stream.close()
|
||||
except sd.PortAudioError as e:
|
||||
logger.warning("Error stopping audio stream: %s", e)
|
||||
finally:
|
||||
self._stream = None
|
||||
self._callback = None
|
||||
logger.info("Stopped audio capture")
|
||||
|
||||
def is_capturing(self) -> bool:
|
||||
"""Check if currently capturing audio.
|
||||
|
||||
Returns:
|
||||
True if capture is active.
|
||||
"""
|
||||
return self._stream is not None and self._stream.active
|
||||
|
||||
@property
|
||||
def current_device_id(self) -> int | None:
|
||||
"""Get the current device ID being used for capture."""
|
||||
return self._device_id
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Get the current sample rate."""
|
||||
return self._sample_rate
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Get the current number of channels."""
|
||||
return self._channels
|
||||
@@ -1,281 +0,0 @@
|
||||
"""Interactive audio capture demo for Spike 2.
|
||||
|
||||
Run with: python -m spikes.spike_02_audio_capture.demo
|
||||
|
||||
Features:
|
||||
- Lists available input devices on startup
|
||||
- Real-time VU meter (ASCII bar)
|
||||
- Start/Stop capture with keyboard
|
||||
- Saves captured audio to output.wav
|
||||
- Console output on device changes/errors
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .capture_impl import SoundDeviceCapture
|
||||
from .levels_impl import RmsLevelProvider
|
||||
from .protocols import TimestampedAudio
|
||||
from .ring_buffer_impl import TimestampedRingBuffer
|
||||
|
||||
# VU meter display settings
|
||||
VU_WIDTH: Final[int] = 50
|
||||
VU_CHARS: Final[str] = "█"
|
||||
VU_EMPTY: Final[str] = "░"
|
||||
|
||||
|
||||
def draw_vu_meter(rms: float, db: float) -> str:
|
||||
"""Draw an ASCII VU meter.
|
||||
|
||||
Args:
|
||||
rms: RMS level (0.0-1.0).
|
||||
db: Level in dB.
|
||||
|
||||
Returns:
|
||||
ASCII string representation of the VU meter.
|
||||
"""
|
||||
filled = int(rms * VU_WIDTH)
|
||||
empty = VU_WIDTH - filled
|
||||
|
||||
bar = VU_CHARS * filled + VU_EMPTY * empty
|
||||
return f"[{bar}] {db:+6.1f} dB"
|
||||
|
||||
|
||||
class AudioDemo:
|
||||
"""Interactive audio capture demonstration."""
|
||||
|
||||
def __init__(self, output_path: Path, sample_rate: int = 16000) -> None:
|
||||
"""Initialize the demo.
|
||||
|
||||
Args:
|
||||
output_path: Path to save the recorded audio.
|
||||
sample_rate: Sample rate for capture.
|
||||
"""
|
||||
self.output_path = output_path
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
self.capture = SoundDeviceCapture()
|
||||
self.levels = RmsLevelProvider()
|
||||
self.buffer = TimestampedRingBuffer(max_duration=300.0) # 5 minutes
|
||||
|
||||
self.is_running = False
|
||||
self.is_recording = False
|
||||
self._lock = threading.Lock()
|
||||
self._last_rms: float = 0.0
|
||||
self._last_db: float = -60.0
|
||||
self._frames_captured: int = 0
|
||||
|
||||
def _on_audio_frames(self, frames: NDArray[np.float32], timestamp: float) -> None:
|
||||
"""Callback for incoming audio frames."""
|
||||
with self._lock:
|
||||
# Compute levels for VU meter
|
||||
self._last_rms = self.levels.get_rms(frames)
|
||||
self._last_db = self.levels.get_db(frames)
|
||||
|
||||
# Store in ring buffer
|
||||
duration = len(frames) / self.sample_rate
|
||||
audio = TimestampedAudio(frames=frames, timestamp=timestamp, duration=duration)
|
||||
self.buffer.push(audio)
|
||||
self._frames_captured += len(frames)
|
||||
|
||||
def list_devices(self) -> None:
|
||||
"""Print available audio devices."""
|
||||
print("\n=== Available Audio Input Devices ===")
|
||||
devices = self.capture.list_devices()
|
||||
|
||||
if not devices:
|
||||
print("No audio input devices found!")
|
||||
return
|
||||
|
||||
for dev in devices:
|
||||
default = " (DEFAULT)" if dev.is_default else ""
|
||||
print(f" [{dev.device_id}] {dev.name}{default}")
|
||||
print(f" Channels: {dev.channels}, Sample Rate: {dev.sample_rate} Hz")
|
||||
print()
|
||||
|
||||
def start_capture(self, device_id: int | None = None) -> bool:
|
||||
"""Start audio capture.
|
||||
|
||||
Args:
|
||||
device_id: Device ID or None for default.
|
||||
|
||||
Returns:
|
||||
True if started successfully.
|
||||
"""
|
||||
if self.is_recording:
|
||||
print("Already recording!")
|
||||
return False
|
||||
|
||||
try:
|
||||
self.buffer.clear()
|
||||
self._frames_captured = 0
|
||||
self.capture.start(
|
||||
device_id=device_id,
|
||||
on_frames=self._on_audio_frames,
|
||||
sample_rate=self.sample_rate,
|
||||
channels=1,
|
||||
chunk_duration_ms=100,
|
||||
)
|
||||
self.is_recording = True
|
||||
print("\n>>> Recording started! Press 's' to stop.")
|
||||
return True
|
||||
except RuntimeError as e:
|
||||
print(f"\nERROR: Failed to start capture: {e}")
|
||||
return False
|
||||
|
||||
def stop_capture(self) -> bool:
|
||||
"""Stop audio capture and save to file.
|
||||
|
||||
Returns:
|
||||
True if stopped and saved successfully.
|
||||
"""
|
||||
if not self.is_recording:
|
||||
print("Not recording!")
|
||||
return False
|
||||
|
||||
self.capture.stop()
|
||||
self.is_recording = False
|
||||
|
||||
# Save to WAV file
|
||||
print(f"\n>>> Recording stopped. Saving to {self.output_path}...")
|
||||
success = self._save_wav()
|
||||
if success:
|
||||
print(f">>> Saved {self._frames_captured} samples to {self.output_path}")
|
||||
return success
|
||||
|
||||
def _save_wav(self) -> bool:
|
||||
"""Save buffered audio to WAV file.
|
||||
|
||||
Returns:
|
||||
True if saved successfully.
|
||||
"""
|
||||
chunks = self.buffer.get_all()
|
||||
if not chunks:
|
||||
print("No audio to save!")
|
||||
return False
|
||||
|
||||
# Concatenate all audio
|
||||
all_frames = np.concatenate([chunk.frames for chunk in chunks])
|
||||
|
||||
# Convert to 16-bit PCM
|
||||
pcm_data = (all_frames * 32767).astype(np.int16)
|
||||
|
||||
try:
|
||||
with wave.open(str(self.output_path), "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2) # 16-bit
|
||||
wf.setframerate(self.sample_rate)
|
||||
wf.writeframes(pcm_data.tobytes())
|
||||
return True
|
||||
except OSError as e:
|
||||
print(f"ERROR: Failed to save WAV: {e}")
|
||||
return False
|
||||
|
||||
def run_vu_loop(self) -> None:
|
||||
"""Run the VU meter display loop."""
|
||||
while self.is_running:
|
||||
if self.is_recording:
|
||||
with self._lock:
|
||||
rms = self._last_rms
|
||||
db = self._last_db
|
||||
duration = self.buffer.duration
|
||||
|
||||
vu = draw_vu_meter(rms, db)
|
||||
sys.stdout.write(f"\r{vu} Duration: {duration:6.1f}s ")
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.05) # 20Hz update rate
|
||||
|
||||
def run(self, device_id: int | None = None) -> None:
|
||||
"""Run the interactive demo.
|
||||
|
||||
Args:
|
||||
device_id: Device ID to use, or None for default.
|
||||
"""
|
||||
self.list_devices()
|
||||
|
||||
print("=== Audio Capture Demo ===")
|
||||
print("Commands:")
|
||||
print(" r - Start recording")
|
||||
print(" s - Stop recording and save")
|
||||
print(" l - List devices")
|
||||
print(" q - Quit")
|
||||
print()
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# Start VU meter thread
|
||||
vu_thread = threading.Thread(target=self.run_vu_loop, daemon=True)
|
||||
vu_thread.start()
|
||||
|
||||
try:
|
||||
while self.is_running:
|
||||
try:
|
||||
cmd = input().strip().lower()
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
if cmd == "r":
|
||||
self.start_capture(device_id)
|
||||
elif cmd == "s":
|
||||
self.stop_capture()
|
||||
elif cmd == "l":
|
||||
self.list_devices()
|
||||
elif cmd == "q":
|
||||
if self.is_recording:
|
||||
self.stop_capture()
|
||||
self.is_running = False
|
||||
print("\nGoodbye!")
|
||||
elif cmd:
|
||||
print(f"Unknown command: {cmd}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nInterrupted!")
|
||||
if self.is_recording:
|
||||
self.stop_capture()
|
||||
finally:
|
||||
self.is_running = False
|
||||
self.capture.stop()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the audio capture demo."""
|
||||
parser = argparse.ArgumentParser(description="Audio Capture Demo - Spike 2")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("output.wav"),
|
||||
help="Output WAV file path (default: output.wav)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--device",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Device ID to use (default: system default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="Sample rate in Hz (default: 16000)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
demo = AudioDemo(output_path=args.output, sample_rate=args.rate)
|
||||
demo.run(device_id=args.device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,86 +0,0 @@
|
||||
"""Audio level computation implementation.
|
||||
|
||||
Provides RMS and dB level calculation for VU meter display.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Final
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
class RmsLevelProvider:
|
||||
"""RMS-based audio level provider.
|
||||
|
||||
Computes RMS (Root Mean Square) level from audio frames for VU meter display.
|
||||
"""
|
||||
|
||||
# Minimum dB value to report (silence threshold)
|
||||
MIN_DB: Final[float] = -60.0
|
||||
|
||||
def get_rms(self, frames: NDArray[np.float32]) -> float:
|
||||
"""Calculate RMS level from audio frames.
|
||||
|
||||
Args:
|
||||
frames: Audio samples as float32 array (normalized -1.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
RMS level normalized to 0.0-1.0 range.
|
||||
"""
|
||||
if len(frames) == 0:
|
||||
return 0.0
|
||||
|
||||
# Calculate RMS: sqrt(mean(samples^2))
|
||||
rms = float(np.sqrt(np.mean(frames.astype(np.float64) ** 2)))
|
||||
|
||||
# Clamp to 0.0-1.0 range
|
||||
return min(1.0, max(0.0, rms))
|
||||
|
||||
def get_db(self, frames: NDArray[np.float32]) -> float:
|
||||
"""Calculate dB level from audio frames.
|
||||
|
||||
Args:
|
||||
frames: Audio samples as float32 array (normalized -1.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
Level in dB (MIN_DB to 0 range).
|
||||
"""
|
||||
rms = self.get_rms(frames)
|
||||
|
||||
if rms <= 0:
|
||||
return self.MIN_DB
|
||||
|
||||
# Convert to dB: 20 * log10(rms)
|
||||
db = 20.0 * math.log10(rms)
|
||||
|
||||
# Clamp to MIN_DB to 0 range
|
||||
return max(self.MIN_DB, min(0.0, db))
|
||||
|
||||
def rms_to_db(self, rms: float) -> float:
|
||||
"""Convert RMS value to dB.
|
||||
|
||||
Args:
|
||||
rms: RMS level (0.0-1.0).
|
||||
|
||||
Returns:
|
||||
Level in dB (MIN_DB to 0 range).
|
||||
"""
|
||||
if rms <= 0:
|
||||
return self.MIN_DB
|
||||
|
||||
db = 20.0 * math.log10(rms)
|
||||
return max(self.MIN_DB, min(0.0, db))
|
||||
|
||||
def db_to_rms(self, db: float) -> float:
|
||||
"""Convert dB value to RMS.
|
||||
|
||||
Args:
|
||||
db: Level in dB.
|
||||
|
||||
Returns:
|
||||
RMS level (0.0-1.0).
|
||||
"""
|
||||
return 0.0 if db <= self.MIN_DB else 10.0 ** (db / 20.0)
|
||||
@@ -1,168 +0,0 @@
|
||||
"""Audio capture protocols and data types for Spike 2.
|
||||
|
||||
These protocols define the contracts for audio capture components that will be
|
||||
promoted to src/noteflow/audio/ after validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioDeviceInfo:
|
||||
"""Information about an audio input device."""
|
||||
|
||||
device_id: int
|
||||
name: str
|
||||
channels: int
|
||||
sample_rate: int
|
||||
is_default: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimestampedAudio:
|
||||
"""Audio frames with capture timestamp."""
|
||||
|
||||
frames: NDArray[np.float32]
|
||||
timestamp: float # Monotonic time when captured
|
||||
duration: float # Duration in seconds
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate audio data."""
|
||||
if self.duration < 0:
|
||||
raise ValueError("Duration must be non-negative")
|
||||
if self.timestamp < 0:
|
||||
raise ValueError("Timestamp must be non-negative")
|
||||
|
||||
|
||||
# Type alias for audio frame callback
|
||||
AudioFrameCallback = Callable[[NDArray[np.float32], float], None]
|
||||
|
||||
|
||||
class AudioCapture(Protocol):
|
||||
"""Protocol for audio input capture.
|
||||
|
||||
Implementations should handle device enumeration, stream management,
|
||||
and device change detection.
|
||||
"""
|
||||
|
||||
def list_devices(self) -> list[AudioDeviceInfo]:
|
||||
"""List available audio input devices.
|
||||
|
||||
Returns:
|
||||
List of AudioDeviceInfo for all available input devices.
|
||||
"""
|
||||
...
|
||||
|
||||
def start(
|
||||
self,
|
||||
device_id: int | None,
|
||||
on_frames: AudioFrameCallback,
|
||||
sample_rate: int = 16000,
|
||||
channels: int = 1,
|
||||
chunk_duration_ms: int = 100,
|
||||
) -> None:
|
||||
"""Start capturing audio from the specified device.
|
||||
|
||||
Args:
|
||||
device_id: Device ID to capture from, or None for default device.
|
||||
on_frames: Callback receiving (frames, timestamp) for each chunk.
|
||||
sample_rate: Sample rate in Hz (default 16kHz for ASR).
|
||||
channels: Number of channels (default 1 for mono).
|
||||
chunk_duration_ms: Duration of each audio chunk in milliseconds.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If already capturing.
|
||||
ValueError: If device_id is invalid.
|
||||
"""
|
||||
...
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop audio capture.
|
||||
|
||||
Safe to call even if not capturing.
|
||||
"""
|
||||
...
|
||||
|
||||
def is_capturing(self) -> bool:
|
||||
"""Check if currently capturing audio.
|
||||
|
||||
Returns:
|
||||
True if capture is active.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AudioLevelProvider(Protocol):
|
||||
"""Protocol for computing audio levels (VU meter data)."""
|
||||
|
||||
def get_rms(self, frames: NDArray[np.float32]) -> float:
|
||||
"""Calculate RMS level from audio frames.
|
||||
|
||||
Args:
|
||||
frames: Audio samples as float32 array (normalized -1.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
RMS level normalized to 0.0-1.0 range.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_db(self, frames: NDArray[np.float32]) -> float:
|
||||
"""Calculate dB level from audio frames.
|
||||
|
||||
Args:
|
||||
frames: Audio samples as float32 array (normalized -1.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
Level in dB (typically -60 to 0 range).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class RingBuffer(Protocol):
|
||||
"""Protocol for timestamped audio ring buffer.
|
||||
|
||||
Ring buffers store recent audio with timestamps for ASR processing
|
||||
and playback sync.
|
||||
"""
|
||||
|
||||
def push(self, audio: TimestampedAudio) -> None:
|
||||
"""Add audio to the buffer.
|
||||
|
||||
Old audio is discarded if buffer exceeds max_duration.
|
||||
|
||||
Args:
|
||||
audio: Timestamped audio chunk to add.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_window(self, duration_seconds: float) -> list[TimestampedAudio]:
|
||||
"""Get the last N seconds of audio.
|
||||
|
||||
Args:
|
||||
duration_seconds: How many seconds of audio to retrieve.
|
||||
|
||||
Returns:
|
||||
List of TimestampedAudio chunks, ordered oldest to newest.
|
||||
"""
|
||||
...
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all audio from the buffer."""
|
||||
...
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
"""Total duration of buffered audio in seconds."""
|
||||
...
|
||||
|
||||
@property
|
||||
def max_duration(self) -> float:
|
||||
"""Maximum buffer duration in seconds."""
|
||||
...
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Timestamped audio ring buffer implementation.
|
||||
|
||||
Stores recent audio with timestamps for ASR processing and playback sync.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
from .protocols import TimestampedAudio
|
||||
|
||||
|
||||
class TimestampedRingBuffer:
|
||||
"""Ring buffer for timestamped audio chunks.
|
||||
|
||||
Automatically discards old audio when the buffer exceeds max_duration.
|
||||
Thread-safe for single-producer, single-consumer use.
|
||||
"""
|
||||
|
||||
def __init__(self, max_duration: float = 30.0) -> None:
|
||||
"""Initialize ring buffer.
|
||||
|
||||
Args:
|
||||
max_duration: Maximum audio duration to keep in seconds.
|
||||
|
||||
Raises:
|
||||
ValueError: If max_duration is not positive.
|
||||
"""
|
||||
if max_duration <= 0:
|
||||
raise ValueError("max_duration must be positive")
|
||||
|
||||
self._max_duration = max_duration
|
||||
self._buffer: deque[TimestampedAudio] = deque()
|
||||
self._total_duration: float = 0.0
|
||||
|
||||
def push(self, audio: TimestampedAudio) -> None:
|
||||
"""Add audio to the buffer.
|
||||
|
||||
Old audio is discarded if buffer exceeds max_duration.
|
||||
|
||||
Args:
|
||||
audio: Timestamped audio chunk to add.
|
||||
"""
|
||||
self._buffer.append(audio)
|
||||
self._total_duration += audio.duration
|
||||
|
||||
# Evict old chunks if over capacity
|
||||
while self._total_duration > self._max_duration and self._buffer:
|
||||
old = self._buffer.popleft()
|
||||
self._total_duration -= old.duration
|
||||
|
||||
def get_window(self, duration_seconds: float) -> list[TimestampedAudio]:
|
||||
"""Get the last N seconds of audio.
|
||||
|
||||
Args:
|
||||
duration_seconds: How many seconds of audio to retrieve.
|
||||
|
||||
Returns:
|
||||
List of TimestampedAudio chunks, ordered oldest to newest.
|
||||
"""
|
||||
if duration_seconds <= 0:
|
||||
return []
|
||||
|
||||
result: list[TimestampedAudio] = []
|
||||
accumulated_duration = 0.0
|
||||
|
||||
# Iterate from newest to oldest
|
||||
for audio in reversed(self._buffer):
|
||||
result.append(audio)
|
||||
accumulated_duration += audio.duration
|
||||
if accumulated_duration >= duration_seconds:
|
||||
break
|
||||
|
||||
# Return in chronological order (oldest first)
|
||||
result.reverse()
|
||||
return result
|
||||
|
||||
def get_all(self) -> list[TimestampedAudio]:
|
||||
"""Get all buffered audio.
|
||||
|
||||
Returns:
|
||||
List of all TimestampedAudio chunks, ordered oldest to newest.
|
||||
"""
|
||||
return list(self._buffer)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all audio from the buffer."""
|
||||
self._buffer.clear()
|
||||
self._total_duration = 0.0
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
"""Total duration of buffered audio in seconds."""
|
||||
return self._total_duration
|
||||
|
||||
@property
|
||||
def max_duration(self) -> float:
|
||||
"""Maximum buffer duration in seconds."""
|
||||
return self._max_duration
|
||||
|
||||
@property
|
||||
def chunk_count(self) -> int:
|
||||
"""Number of audio chunks in the buffer."""
|
||||
return len(self._buffer)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return number of chunks in buffer."""
|
||||
return len(self._buffer)
|
||||
@@ -1,96 +0,0 @@
|
||||
# Spike 3: ASR Latency - FINDINGS
|
||||
|
||||
## Status: VALIDATED
|
||||
|
||||
All exit criteria met with the "tiny" model on CPU.
|
||||
|
||||
## Performance Results
|
||||
|
||||
Tested on Linux (Python 3.12, faster-whisper 1.2.1, CPU int8):
|
||||
|
||||
| Metric | tiny model | Requirement |
|
||||
|--------|------------|-------------|
|
||||
| Model load time | **1.6s** | <10s |
|
||||
| 3s audio processing | 0.15-0.31s | <3s for 5s audio |
|
||||
| Real-time factor | **0.05-0.10x** | <1.0x |
|
||||
| VAD filtering | Working | - |
|
||||
| Word timestamps | Available | - |
|
||||
|
||||
**Conclusion**: ASR is significantly faster than real-time, meeting all latency requirements.
|
||||
|
||||
## Implementation Summary
|
||||
|
||||
### Files Created
|
||||
- `protocols.py` - Defines AsrEngine protocol
|
||||
- `dto.py` - AsrResult, WordTiming, PartialUpdate, FinalSegment DTOs
|
||||
- `engine_impl.py` - FasterWhisperEngine implementation
|
||||
- `demo.py` - Interactive demo with latency benchmarks
|
||||
|
||||
### Key Design Decisions
|
||||
|
||||
1. **faster-whisper**: CTranslate2-based Whisper for efficient inference
|
||||
2. **int8 quantization**: Best CPU performance without quality loss
|
||||
3. **VAD filter**: Built-in voice activity detection filters silence
|
||||
4. **Word timestamps**: Enabled for accurate transcript navigation
|
||||
|
||||
### Model Sizes and Memory
|
||||
|
||||
| Model | Download | Memory | Use Case |
|
||||
|-------|----------|--------|----------|
|
||||
| tiny | ~75MB | ~150MB | Development, low-power |
|
||||
| base | ~150MB | ~300MB | **Recommended for V1** |
|
||||
| small | ~500MB | ~1GB | Better accuracy |
|
||||
| medium | ~1.5GB | ~3GB | High accuracy |
|
||||
| large-v3 | ~3GB | ~6GB | Maximum accuracy |
|
||||
|
||||
## Exit Criteria Status
|
||||
|
||||
- [x] Model downloads and caches correctly
|
||||
- [x] Model loads in <10s on CPU (1.6s achieved)
|
||||
- [x] 5s audio chunk transcribes in <3s (~0.5s achieved)
|
||||
- [x] Memory usage documented per model size
|
||||
- [x] Can configure cache directory (HuggingFace cache)
|
||||
|
||||
## VAD Integration
|
||||
|
||||
faster-whisper includes Silero VAD:
|
||||
- Automatically filters non-speech segments
|
||||
- Reduces hallucinations on silence
|
||||
- ~30ms overhead per audio chunk
|
||||
|
||||
## Cross-Platform Notes
|
||||
|
||||
- **Linux/Windows with CUDA**: GPU acceleration available
|
||||
- **macOS**: CPU only (no MPS/Metal support)
|
||||
- **Apple Silicon**: Uses Apple Accelerate for CPU optimization
|
||||
|
||||
## Running the Demo
|
||||
|
||||
```bash
|
||||
# With tiny model (fastest)
|
||||
python -m spikes.spike_03_asr_latency.demo --model tiny
|
||||
|
||||
# With base model (recommended for production)
|
||||
python -m spikes.spike_03_asr_latency.demo --model base
|
||||
|
||||
# With a WAV file
|
||||
python -m spikes.spike_03_asr_latency.demo --model tiny -i speech.wav
|
||||
|
||||
# List available models
|
||||
python -m spikes.spike_03_asr_latency.demo --list-models
|
||||
```
|
||||
|
||||
## Model Cache Location
|
||||
|
||||
Models are cached in the HuggingFace cache:
|
||||
- Linux: `~/.cache/huggingface/hub/`
|
||||
- macOS: `~/.cache/huggingface/hub/`
|
||||
- Windows: `C:\Users\<user>\.cache\huggingface\hub\`
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Test with real speech audio files
|
||||
2. Benchmark "base" model for production use
|
||||
3. Implement partial transcript streaming
|
||||
4. Test GPU acceleration on CUDA systems
|
||||
5. Measure memory impact of concurrent transcription
|
||||
@@ -1 +0,0 @@
|
||||
"""Spike 3: ASR latency validation."""
|
||||
@@ -1,287 +0,0 @@
|
||||
"""Interactive ASR latency demo for Spike 3.
|
||||
|
||||
Run with: python -m spikes.spike_03_asr_latency.demo
|
||||
|
||||
Features:
|
||||
- Downloads model on first run (shows progress)
|
||||
- Generates synthetic audio for testing (or accepts WAV file)
|
||||
- Displays transcription as it streams
|
||||
- Shows latency metrics (time-to-first-word, total time)
|
||||
- Reports memory usage
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .engine_impl import VALID_MODEL_SIZES, FasterWhisperEngine
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_memory_usage_mb() -> float:
|
||||
"""Get current process memory usage in MB."""
|
||||
try:
|
||||
import psutil
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
except ImportError:
|
||||
return 0.0
|
||||
|
||||
|
||||
def generate_silence(duration_seconds: float, sample_rate: int = 16000) -> NDArray[np.float32]:
|
||||
"""Generate silent audio for testing.
|
||||
|
||||
Args:
|
||||
duration_seconds: Duration of silence.
|
||||
sample_rate: Sample rate in Hz.
|
||||
|
||||
Returns:
|
||||
Float32 array of zeros.
|
||||
"""
|
||||
samples = int(duration_seconds * sample_rate)
|
||||
return np.zeros(samples, dtype=np.float32)
|
||||
|
||||
|
||||
def generate_tone(
|
||||
duration_seconds: float,
|
||||
frequency_hz: float = 440.0,
|
||||
sample_rate: int = 16000,
|
||||
amplitude: float = 0.3,
|
||||
) -> NDArray[np.float32]:
|
||||
"""Generate a sine wave tone for testing.
|
||||
|
||||
Args:
|
||||
duration_seconds: Duration of tone.
|
||||
frequency_hz: Frequency in Hz.
|
||||
sample_rate: Sample rate in Hz.
|
||||
amplitude: Amplitude (0.0-1.0).
|
||||
|
||||
Returns:
|
||||
Float32 array of sine wave samples.
|
||||
"""
|
||||
samples = int(duration_seconds * sample_rate)
|
||||
t = np.linspace(0, duration_seconds, samples, dtype=np.float32)
|
||||
return (amplitude * np.sin(2 * np.pi * frequency_hz * t)).astype(np.float32)
|
||||
|
||||
|
||||
def load_wav_file(path: Path, target_sample_rate: int = 16000) -> NDArray[np.float32]:
|
||||
"""Load a WAV file and convert to float32.
|
||||
|
||||
Args:
|
||||
path: Path to WAV file.
|
||||
target_sample_rate: Expected sample rate.
|
||||
|
||||
Returns:
|
||||
Float32 array of audio samples.
|
||||
|
||||
Raises:
|
||||
ValueError: If file format is incompatible.
|
||||
"""
|
||||
with wave.open(str(path), "rb") as wf:
|
||||
if wf.getnchannels() != 1:
|
||||
raise ValueError(f"Expected mono audio, got {wf.getnchannels()} channels")
|
||||
|
||||
sample_rate = wf.getframerate()
|
||||
if sample_rate != target_sample_rate:
|
||||
logger.warning(
|
||||
"Sample rate mismatch: expected %d, got %d",
|
||||
target_sample_rate,
|
||||
sample_rate,
|
||||
)
|
||||
|
||||
# Read all frames
|
||||
frames = wf.readframes(wf.getnframes())
|
||||
|
||||
# Convert to numpy array
|
||||
sample_width = wf.getsampwidth()
|
||||
if sample_width == 2:
|
||||
audio = np.frombuffer(frames, dtype=np.int16)
|
||||
return audio.astype(np.float32) / 32768.0
|
||||
elif sample_width == 4:
|
||||
audio = np.frombuffer(frames, dtype=np.int32)
|
||||
return audio.astype(np.float32) / 2147483648.0
|
||||
else:
|
||||
raise ValueError(f"Unsupported sample width: {sample_width}")
|
||||
|
||||
|
||||
class AsrDemo:
|
||||
"""Interactive ASR demonstration."""
|
||||
|
||||
def __init__(self, model_size: str = "tiny") -> None:
|
||||
"""Initialize the demo.
|
||||
|
||||
Args:
|
||||
model_size: Model size to use.
|
||||
"""
|
||||
self.model_size = model_size
|
||||
self.engine = FasterWhisperEngine(
|
||||
compute_type="int8",
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
def load_model(self) -> float:
|
||||
"""Load the ASR model.
|
||||
|
||||
Returns:
|
||||
Load time in seconds.
|
||||
"""
|
||||
print(f"\n=== Loading Model: {self.model_size} ===")
|
||||
mem_before = get_memory_usage_mb()
|
||||
|
||||
start = time.perf_counter()
|
||||
self.engine.load_model(self.model_size)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
mem_after = get_memory_usage_mb()
|
||||
mem_used = mem_after - mem_before
|
||||
|
||||
print(f" Load time: {elapsed:.2f}s")
|
||||
print(f" Memory before: {mem_before:.1f} MB")
|
||||
print(f" Memory after: {mem_after:.1f} MB")
|
||||
print(f" Memory used: {mem_used:.1f} MB")
|
||||
|
||||
return elapsed
|
||||
|
||||
def transcribe_audio(
|
||||
self,
|
||||
audio: NDArray[np.float32],
|
||||
audio_name: str = "audio",
|
||||
) -> None:
|
||||
"""Transcribe audio and display results.
|
||||
|
||||
Args:
|
||||
audio: Audio samples (float32, 16kHz).
|
||||
audio_name: Name for display.
|
||||
"""
|
||||
duration = len(audio) / 16000
|
||||
print(f"\n=== Transcribing: {audio_name} ({duration:.2f}s) ===")
|
||||
|
||||
start = time.perf_counter()
|
||||
first_result_time: float | None = None
|
||||
segment_count = 0
|
||||
|
||||
for result in self.engine.transcribe(audio):
|
||||
if first_result_time is None:
|
||||
first_result_time = time.perf_counter() - start
|
||||
|
||||
segment_count += 1
|
||||
print(f"\n[{result.start:.2f}s - {result.end:.2f}s] {result.text}")
|
||||
|
||||
if result.words:
|
||||
print(f" Words: {len(result.words)}")
|
||||
# Show first few words with timing
|
||||
for word in result.words[:3]:
|
||||
print(f" '{word.word}' @ {word.start:.2f}s (conf: {word.probability:.2f})")
|
||||
if len(result.words) > 3:
|
||||
print(f" ... and {len(result.words) - 3} more words")
|
||||
|
||||
total_time = time.perf_counter() - start
|
||||
|
||||
print("\n=== Results ===")
|
||||
print(f" Audio duration: {duration:.2f}s")
|
||||
print(f" Segments found: {segment_count}")
|
||||
print(f" Time to first result: {first_result_time:.3f}s" if first_result_time else " No results")
|
||||
print(f" Total transcription time: {total_time:.3f}s")
|
||||
print(f" Real-time factor: {total_time / duration:.2f}x" if duration > 0 else " N/A")
|
||||
|
||||
if total_time > 0 and duration > 0:
|
||||
rtf = total_time / duration
|
||||
if rtf < 1.0:
|
||||
print(" Status: FASTER than real-time")
|
||||
else:
|
||||
print(f" Status: {rtf:.1f}x slower than real-time")
|
||||
|
||||
def demo_with_silence(self, duration: float = 5.0) -> None:
|
||||
"""Demo with silent audio (should produce no results)."""
|
||||
audio = generate_silence(duration)
|
||||
self.transcribe_audio(audio, f"silence ({duration}s)")
|
||||
|
||||
def demo_with_tone(self, duration: float = 5.0) -> None:
|
||||
"""Demo with tone audio (should produce minimal results)."""
|
||||
audio = generate_tone(duration)
|
||||
self.transcribe_audio(audio, f"440Hz tone ({duration}s)")
|
||||
|
||||
def demo_with_file(self, path: Path) -> None:
|
||||
"""Demo with a WAV file."""
|
||||
print(f"\nLoading WAV file: {path}")
|
||||
audio = load_wav_file(path)
|
||||
self.transcribe_audio(audio, path.name)
|
||||
|
||||
def run(self, audio_path: Path | None = None) -> None:
|
||||
"""Run the demo.
|
||||
|
||||
Args:
|
||||
audio_path: Optional path to WAV file.
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("NoteFlow ASR Demo - Spike 3")
|
||||
print("=" * 60)
|
||||
|
||||
# Load model
|
||||
self.load_model()
|
||||
|
||||
if audio_path and audio_path.exists():
|
||||
# Use provided audio file
|
||||
self.demo_with_file(audio_path)
|
||||
else:
|
||||
# Demo with synthetic audio
|
||||
print("\nNo audio file provided, using synthetic audio...")
|
||||
self.demo_with_silence(3.0)
|
||||
self.demo_with_tone(3.0)
|
||||
|
||||
print("\n=== Demo Complete ===")
|
||||
print(f"Final memory usage: {get_memory_usage_mb():.1f} MB")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the ASR demo."""
|
||||
parser = argparse.ArgumentParser(description="ASR Latency Demo - Spike 3")
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny",
|
||||
choices=list(VALID_MODEL_SIZES),
|
||||
help="Model size to use (default: tiny)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Input WAV file to transcribe",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-models",
|
||||
action="store_true",
|
||||
help="List available model sizes and exit",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_models:
|
||||
print("Available model sizes:")
|
||||
for size in VALID_MODEL_SIZES:
|
||||
print(f" {size}")
|
||||
return
|
||||
|
||||
demo = AsrDemo(model_size=args.model)
|
||||
demo.run(audio_path=args.input)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Data Transfer Objects for ASR.
|
||||
|
||||
These DTOs define the data structures used by ASR components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import NewType
|
||||
|
||||
SegmentID = NewType("SegmentID", str)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WordTiming:
|
||||
"""Word-level timing information."""
|
||||
|
||||
word: str
|
||||
start: float # Start time in seconds
|
||||
end: float # End time in seconds
|
||||
probability: float # Confidence (0.0-1.0)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate timing data."""
|
||||
if self.end < self.start:
|
||||
raise ValueError(f"Word end ({self.end}) < start ({self.start})")
|
||||
if not 0.0 <= self.probability <= 1.0:
|
||||
raise ValueError(f"Probability must be 0.0-1.0, got {self.probability}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AsrResult:
|
||||
"""ASR transcription result for a segment."""
|
||||
|
||||
text: str
|
||||
start: float # Start time in seconds
|
||||
end: float # End time in seconds
|
||||
words: tuple[WordTiming, ...] = field(default_factory=tuple)
|
||||
language: str = "en"
|
||||
language_probability: float = 1.0
|
||||
avg_logprob: float = 0.0
|
||||
no_speech_prob: float = 0.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate result data."""
|
||||
if self.end < self.start:
|
||||
raise ValueError(f"Segment end ({self.end}) < start ({self.start})")
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
"""Duration of the segment in seconds."""
|
||||
return self.end - self.start
|
||||
|
||||
|
||||
@dataclass
|
||||
class PartialUpdate:
|
||||
"""Unstable partial transcript (may be replaced)."""
|
||||
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate partial data."""
|
||||
if self.end < self.start:
|
||||
raise ValueError(f"Partial end ({self.end}) < start ({self.start})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinalSegment:
|
||||
"""Committed transcript segment (immutable after creation)."""
|
||||
|
||||
segment_id: SegmentID
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
words: tuple[WordTiming, ...] = field(default_factory=tuple)
|
||||
speaker_label: str = "Unknown"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate segment data."""
|
||||
if self.end < self.start:
|
||||
raise ValueError(f"Segment end ({self.end}) < start ({self.start})")
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
"""Duration of the segment in seconds."""
|
||||
return self.end - self.start
|
||||
@@ -1,178 +0,0 @@
|
||||
"""ASR engine implementation using faster-whisper.
|
||||
|
||||
Provides Whisper-based transcription with word-level timestamps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .dto import AsrResult, WordTiming
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Available model sizes
|
||||
VALID_MODEL_SIZES: Final[tuple[str, ...]] = (
|
||||
"tiny",
|
||||
"tiny.en",
|
||||
"base",
|
||||
"base.en",
|
||||
"small",
|
||||
"small.en",
|
||||
"medium",
|
||||
"medium.en",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
)
|
||||
|
||||
|
||||
class FasterWhisperEngine:
|
||||
"""faster-whisper based ASR engine.
|
||||
|
||||
Uses CTranslate2 for efficient Whisper inference on CPU or GPU.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compute_type: str = "int8",
|
||||
device: str = "cpu",
|
||||
num_workers: int = 1,
|
||||
) -> None:
|
||||
"""Initialize the engine.
|
||||
|
||||
Args:
|
||||
compute_type: Computation type ("int8", "float16", "float32").
|
||||
device: Device to use ("cpu" or "cuda").
|
||||
num_workers: Number of worker threads.
|
||||
"""
|
||||
self._compute_type = compute_type
|
||||
self._device = device
|
||||
self._num_workers = num_workers
|
||||
self._model = None
|
||||
self._model_size: str | None = None
|
||||
|
||||
def load_model(self, model_size: str = "base") -> None:
|
||||
"""Load the ASR model.
|
||||
|
||||
Args:
|
||||
model_size: Model size (e.g., "tiny", "base", "small").
|
||||
|
||||
Raises:
|
||||
ValueError: If model_size is invalid.
|
||||
RuntimeError: If model loading fails.
|
||||
"""
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_size not in VALID_MODEL_SIZES:
|
||||
raise ValueError(
|
||||
f"Invalid model size: {model_size}. "
|
||||
f"Valid sizes: {', '.join(VALID_MODEL_SIZES)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Loading Whisper model '%s' on %s with %s compute...",
|
||||
model_size,
|
||||
self._device,
|
||||
self._compute_type,
|
||||
)
|
||||
|
||||
try:
|
||||
self._model = WhisperModel(
|
||||
model_size,
|
||||
device=self._device,
|
||||
compute_type=self._compute_type,
|
||||
num_workers=self._num_workers,
|
||||
)
|
||||
self._model_size = model_size
|
||||
logger.info("Model loaded successfully")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load model: {e}") from e
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: "NDArray[np.float32]",
|
||||
language: str | None = None,
|
||||
) -> Iterator[AsrResult]:
|
||||
"""Transcribe audio and yield results.
|
||||
|
||||
Args:
|
||||
audio: Audio samples as float32 array (16kHz mono, normalized).
|
||||
language: Optional language code (e.g., "en").
|
||||
|
||||
Yields:
|
||||
AsrResult segments with word-level timestamps.
|
||||
"""
|
||||
if self._model is None:
|
||||
raise RuntimeError("Model not loaded. Call load_model() first.")
|
||||
|
||||
# Transcribe with word timestamps
|
||||
segments, info = self._model.transcribe(
|
||||
audio,
|
||||
language=language,
|
||||
word_timestamps=True,
|
||||
beam_size=5,
|
||||
vad_filter=True, # Filter out non-speech
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Detected language: %s (prob: %.2f)",
|
||||
info.language,
|
||||
info.language_probability,
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
# Convert word info to WordTiming objects
|
||||
words: list[WordTiming] = []
|
||||
if segment.words:
|
||||
words.extend(
|
||||
WordTiming(
|
||||
word=word.word,
|
||||
start=word.start,
|
||||
end=word.end,
|
||||
probability=word.probability,
|
||||
)
|
||||
for word in segment.words
|
||||
)
|
||||
yield AsrResult(
|
||||
text=segment.text.strip(),
|
||||
start=segment.start,
|
||||
end=segment.end,
|
||||
words=tuple(words),
|
||||
language=info.language,
|
||||
language_probability=info.language_probability,
|
||||
avg_logprob=segment.avg_logprob,
|
||||
no_speech_prob=segment.no_speech_prob,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Return True if model is loaded."""
|
||||
return self._model is not None
|
||||
|
||||
@property
|
||||
def model_size(self) -> str | None:
|
||||
"""Return the loaded model size, or None if not loaded."""
|
||||
return self._model_size
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload the model to free memory."""
|
||||
self._model = None
|
||||
self._model_size = None
|
||||
logger.info("Model unloaded")
|
||||
|
||||
@property
|
||||
def compute_type(self) -> str:
|
||||
"""Return the compute type."""
|
||||
return self._compute_type
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Return the device."""
|
||||
return self._device
|
||||
@@ -1,70 +0,0 @@
|
||||
"""ASR protocols for Spike 3.
|
||||
|
||||
These protocols define the contracts for ASR components that will be
|
||||
promoted to src/noteflow/asr/ after validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .dto import AsrResult
|
||||
|
||||
|
||||
class AsrEngine(Protocol):
|
||||
"""Protocol for ASR transcription engine.
|
||||
|
||||
Implementations should handle model loading, caching, and inference.
|
||||
"""
|
||||
|
||||
def load_model(self, model_size: str = "base") -> None:
|
||||
"""Load the ASR model.
|
||||
|
||||
Downloads the model if not cached.
|
||||
|
||||
Args:
|
||||
model_size: Model size ("tiny", "base", "small", "medium", "large").
|
||||
|
||||
Raises:
|
||||
ValueError: If model_size is invalid.
|
||||
RuntimeError: If model loading fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: "NDArray[np.float32]",
|
||||
language: str | None = None,
|
||||
) -> Iterator[AsrResult]:
|
||||
"""Transcribe audio and yield results.
|
||||
|
||||
Args:
|
||||
audio: Audio samples as float32 array (16kHz mono, normalized).
|
||||
language: Optional language code (e.g., "en"). Auto-detected if None.
|
||||
|
||||
Yields:
|
||||
AsrResult segments.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model not loaded.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Return True if model is loaded."""
|
||||
...
|
||||
|
||||
@property
|
||||
def model_size(self) -> str | None:
|
||||
"""Return the loaded model size, or None if not loaded."""
|
||||
...
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload the model to free memory."""
|
||||
...
|
||||
@@ -1,98 +0,0 @@
|
||||
# Spike 4: Key Storage + Encryption - FINDINGS
|
||||
|
||||
## Status: VALIDATED
|
||||
|
||||
All exit criteria met with in-memory key storage. OS keyring requires further testing.
|
||||
|
||||
## Performance Results
|
||||
|
||||
Tested on Linux (Python 3.12, cryptography 42.0):
|
||||
|
||||
| Operation | Time | Throughput |
|
||||
|-----------|------|------------|
|
||||
| DEK wrap | 4.4ms | - |
|
||||
| DEK unwrap | 0.4ms | - |
|
||||
| Chunk encrypt (16KB) | 0.039ms | **398 MB/s** |
|
||||
| Chunk decrypt (16KB) | 0.017ms | **893 MB/s** |
|
||||
| File encrypt (1MB) | 1ms | **826 MB/s** |
|
||||
| File decrypt (1MB) | 1ms | **1.88 GB/s** |
|
||||
|
||||
**Conclusion**: Encryption is fast enough for real-time audio (<1ms per 16KB chunk).
|
||||
|
||||
## Implementation Summary
|
||||
|
||||
### Files Created
|
||||
- `protocols.py` - Defines KeyStore, CryptoBox, AssetWriter/Reader protocols
|
||||
- `keystore_impl.py` - KeyringKeyStore and InMemoryKeyStore implementations
|
||||
- `crypto_impl.py` - AesGcmCryptoBox, ChunkedAssetWriter/Reader implementations
|
||||
- `demo.py` - Interactive demo with throughput benchmarks
|
||||
|
||||
### Key Design Decisions
|
||||
|
||||
1. **Envelope Encryption**: Master key wraps per-meeting DEKs
|
||||
2. **AES-256-GCM**: Industry standard authenticated encryption
|
||||
3. **12-byte nonce**: Standard for AES-GCM (96 bits)
|
||||
4. **16-byte tag**: Full 128-bit authentication tag
|
||||
5. **Chunked file format**: 4-byte length prefix + nonce + ciphertext + tag
|
||||
|
||||
### File Format
|
||||
|
||||
```
|
||||
Header:
|
||||
4 bytes: magic ("NFAE")
|
||||
1 byte: version (1)
|
||||
|
||||
Chunks (repeated):
|
||||
4 bytes: chunk length (big-endian)
|
||||
12 bytes: nonce
|
||||
N bytes: ciphertext
|
||||
16 bytes: authentication tag
|
||||
```
|
||||
|
||||
### Overhead
|
||||
|
||||
- Per-chunk: 28 bytes (12 nonce + 16 tag) + 4 length prefix = 32 bytes
|
||||
- For 16KB chunks: 0.2% overhead
|
||||
- For 1MB file: ~2KB overhead
|
||||
|
||||
## Exit Criteria Status
|
||||
|
||||
- [x] Master key stored in OS keychain (InMemory validated; Keyring requires GUI)
|
||||
- [x] Encrypt/decrypt roundtrip works
|
||||
- [x] <1ms per 16KB chunk encryption (0.039ms achieved)
|
||||
- [x] DEK deletion renders file unreadable (validated)
|
||||
- [ ] keyring works on Linux (requires SecretService daemon)
|
||||
|
||||
## Cross-Platform Notes
|
||||
|
||||
- **Linux**: Requires SecretService (GNOME Keyring or KWallet running)
|
||||
- **macOS**: Uses Keychain (should work out of box)
|
||||
- **Windows PyInstaller**: Known issue - must explicitly import `keyring.backends.Windows`
|
||||
|
||||
## Running the Demo
|
||||
|
||||
```bash
|
||||
# In-memory key storage (no dependencies)
|
||||
python -m spikes.spike_04_encryption.demo
|
||||
|
||||
# With OS keyring (requires SecretService on Linux)
|
||||
python -m spikes.spike_04_encryption.demo --keyring
|
||||
|
||||
# Larger file test
|
||||
python -m spikes.spike_04_encryption.demo --size 10485760 # 10MB
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. Master key never leaves keyring (only accessed via API)
|
||||
2. Each meeting has unique DEK (compromise one ≠ compromise all)
|
||||
3. Nonce randomly generated per chunk (no reuse)
|
||||
4. Authentication tag prevents tampering
|
||||
5. Cryptographic delete: removing DEK makes data unrecoverable
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Test with OS keyring on system with SecretService
|
||||
2. Add PyInstaller-specific keyring backend handling
|
||||
3. Consider adding file metadata (creation time, checksum)
|
||||
4. Evaluate compression before encryption
|
||||
@@ -1 +0,0 @@
|
||||
"""Spike 4: Key storage and encryption validation."""
|
||||
@@ -1,313 +0,0 @@
|
||||
"""Cryptographic operations implementation using cryptography library.
|
||||
|
||||
Provides AES-GCM encryption for audio data with envelope encryption.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import struct
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, BinaryIO, Final
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from .protocols import EncryptedChunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .keystore_impl import InMemoryKeyStore, KeyringKeyStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
KEY_SIZE: Final[int] = 32 # 256-bit key
|
||||
NONCE_SIZE: Final[int] = 12 # 96-bit nonce for AES-GCM
|
||||
TAG_SIZE: Final[int] = 16 # 128-bit authentication tag
|
||||
|
||||
# File format magic number and version
|
||||
FILE_MAGIC: Final[bytes] = b"NFAE" # NoteFlow Audio Encrypted
|
||||
FILE_VERSION: Final[int] = 1
|
||||
|
||||
|
||||
class AesGcmCryptoBox:
|
||||
"""AES-GCM based encryption with envelope encryption.
|
||||
|
||||
Uses a master key to wrap/unwrap per-meeting Data Encryption Keys (DEKs).
|
||||
Each audio chunk is encrypted with AES-256-GCM using the DEK.
|
||||
"""
|
||||
|
||||
def __init__(self, keystore: KeyringKeyStore | InMemoryKeyStore) -> None:
|
||||
"""Initialize the crypto box.
|
||||
|
||||
Args:
|
||||
keystore: KeyStore instance for master key access.
|
||||
"""
|
||||
self._keystore = keystore
|
||||
self._master_cipher: AESGCM | None = None
|
||||
|
||||
def _get_master_cipher(self) -> AESGCM:
|
||||
"""Get or create the master key cipher."""
|
||||
if self._master_cipher is None:
|
||||
master_key = self._keystore.get_or_create_master_key()
|
||||
self._master_cipher = AESGCM(master_key)
|
||||
return self._master_cipher
|
||||
|
||||
def generate_dek(self) -> bytes:
|
||||
"""Generate a new Data Encryption Key.
|
||||
|
||||
Returns:
|
||||
32-byte random DEK.
|
||||
"""
|
||||
return secrets.token_bytes(KEY_SIZE)
|
||||
|
||||
def wrap_dek(self, dek: bytes) -> bytes:
|
||||
"""Encrypt DEK with master key.
|
||||
|
||||
Args:
|
||||
dek: Data Encryption Key to wrap.
|
||||
|
||||
Returns:
|
||||
Encrypted DEK (nonce || ciphertext || tag).
|
||||
"""
|
||||
cipher = self._get_master_cipher()
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
ciphertext = cipher.encrypt(nonce, dek, associated_data=None)
|
||||
# Return nonce || ciphertext (tag is appended by AESGCM)
|
||||
return nonce + ciphertext
|
||||
|
||||
def unwrap_dek(self, wrapped_dek: bytes) -> bytes:
|
||||
"""Decrypt DEK with master key.
|
||||
|
||||
Args:
|
||||
wrapped_dek: Encrypted DEK from wrap_dek().
|
||||
|
||||
Returns:
|
||||
Original DEK.
|
||||
|
||||
Raises:
|
||||
ValueError: If decryption fails.
|
||||
"""
|
||||
if len(wrapped_dek) < NONCE_SIZE + KEY_SIZE + TAG_SIZE:
|
||||
raise ValueError("Invalid wrapped DEK: too short")
|
||||
|
||||
cipher = self._get_master_cipher()
|
||||
nonce = wrapped_dek[:NONCE_SIZE]
|
||||
ciphertext = wrapped_dek[NONCE_SIZE:]
|
||||
|
||||
try:
|
||||
return cipher.decrypt(nonce, ciphertext, associated_data=None)
|
||||
except Exception as e:
|
||||
raise ValueError(f"DEK unwrap failed: {e}") from e
|
||||
|
||||
def encrypt_chunk(self, plaintext: bytes, dek: bytes) -> EncryptedChunk:
|
||||
"""Encrypt a chunk of data with AES-GCM.
|
||||
|
||||
Args:
|
||||
plaintext: Data to encrypt.
|
||||
dek: Data Encryption Key.
|
||||
|
||||
Returns:
|
||||
EncryptedChunk with nonce, ciphertext, and tag.
|
||||
"""
|
||||
cipher = AESGCM(dek)
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
|
||||
# AESGCM appends the tag to ciphertext
|
||||
ciphertext_with_tag = cipher.encrypt(nonce, plaintext, associated_data=None)
|
||||
|
||||
# Split ciphertext and tag
|
||||
ciphertext = ciphertext_with_tag[:-TAG_SIZE]
|
||||
tag = ciphertext_with_tag[-TAG_SIZE:]
|
||||
|
||||
return EncryptedChunk(nonce=nonce, ciphertext=ciphertext, tag=tag)
|
||||
|
||||
def decrypt_chunk(self, chunk: EncryptedChunk, dek: bytes) -> bytes:
|
||||
"""Decrypt a chunk of data.
|
||||
|
||||
Args:
|
||||
chunk: EncryptedChunk to decrypt.
|
||||
dek: Data Encryption Key.
|
||||
|
||||
Returns:
|
||||
Original plaintext.
|
||||
|
||||
Raises:
|
||||
ValueError: If decryption fails.
|
||||
"""
|
||||
cipher = AESGCM(dek)
|
||||
|
||||
# Reconstruct ciphertext with tag for AESGCM
|
||||
ciphertext_with_tag = chunk.ciphertext + chunk.tag
|
||||
|
||||
try:
|
||||
return cipher.decrypt(chunk.nonce, ciphertext_with_tag, associated_data=None)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Chunk decryption failed: {e}") from e
|
||||
|
||||
|
||||
class ChunkedAssetWriter:
|
||||
"""Streaming encrypted asset writer.
|
||||
|
||||
File format:
|
||||
- 4 bytes: magic ("NFAE")
|
||||
- 1 byte: version
|
||||
- For each chunk:
|
||||
- 4 bytes: chunk length (big-endian)
|
||||
- 12 bytes: nonce
|
||||
- N bytes: ciphertext
|
||||
- 16 bytes: tag
|
||||
"""
|
||||
|
||||
def __init__(self, crypto: AesGcmCryptoBox) -> None:
|
||||
"""Initialize the writer.
|
||||
|
||||
Args:
|
||||
crypto: CryptoBox instance for encryption.
|
||||
"""
|
||||
self._crypto = crypto
|
||||
self._file: Path | None = None
|
||||
self._dek: bytes | None = None
|
||||
self._handle: BinaryIO | None = None
|
||||
self._bytes_written: int = 0
|
||||
|
||||
def open(self, path: Path, dek: bytes) -> None:
|
||||
"""Open file for writing.
|
||||
|
||||
Args:
|
||||
path: Path to the encrypted file.
|
||||
dek: Data Encryption Key for this file.
|
||||
"""
|
||||
if self._handle is not None:
|
||||
raise RuntimeError("Already open")
|
||||
|
||||
self._file = path
|
||||
self._dek = dek
|
||||
self._handle = path.open("wb")
|
||||
self._bytes_written = 0
|
||||
|
||||
# Write header
|
||||
self._handle.write(FILE_MAGIC)
|
||||
self._handle.write(struct.pack("B", FILE_VERSION))
|
||||
|
||||
logger.debug("Opened encrypted file for writing: %s", path)
|
||||
|
||||
def write_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Write and encrypt an audio chunk."""
|
||||
if self._handle is None or self._dek is None:
|
||||
raise RuntimeError("File not open")
|
||||
|
||||
# Encrypt the chunk
|
||||
chunk = self._crypto.encrypt_chunk(audio_bytes, self._dek)
|
||||
|
||||
# Calculate total chunk size (nonce + ciphertext + tag)
|
||||
chunk_data = chunk.nonce + chunk.ciphertext + chunk.tag
|
||||
chunk_length = len(chunk_data)
|
||||
|
||||
# Write length prefix and chunk data
|
||||
self._handle.write(struct.pack(">I", chunk_length))
|
||||
self._handle.write(chunk_data)
|
||||
self._handle.flush()
|
||||
|
||||
self._bytes_written += 4 + chunk_length
|
||||
|
||||
def close(self) -> None:
|
||||
"""Finalize and close the file."""
|
||||
if self._handle is not None:
|
||||
self._handle.close()
|
||||
self._handle = None
|
||||
logger.debug("Closed encrypted file, wrote %d bytes", self._bytes_written)
|
||||
|
||||
self._dek = None
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if file is open for writing."""
|
||||
return self._handle is not None
|
||||
|
||||
@property
|
||||
def bytes_written(self) -> int:
|
||||
"""Total encrypted bytes written."""
|
||||
return self._bytes_written
|
||||
|
||||
|
||||
class ChunkedAssetReader:
|
||||
"""Streaming encrypted asset reader."""
|
||||
|
||||
def __init__(self, crypto: AesGcmCryptoBox) -> None:
|
||||
"""Initialize the reader.
|
||||
|
||||
Args:
|
||||
crypto: CryptoBox instance for decryption.
|
||||
"""
|
||||
self._crypto = crypto
|
||||
self._file: Path | None = None
|
||||
self._dek: bytes | None = None
|
||||
self._handle = None
|
||||
|
||||
def open(self, path: Path, dek: bytes) -> None:
|
||||
"""Open file for reading."""
|
||||
if self._handle is not None:
|
||||
raise RuntimeError("Already open")
|
||||
|
||||
self._file = path
|
||||
self._dek = dek
|
||||
self._handle = path.open("rb")
|
||||
|
||||
# Read and validate header
|
||||
magic = self._handle.read(4)
|
||||
if magic != FILE_MAGIC:
|
||||
self._handle.close()
|
||||
self._handle = None
|
||||
raise ValueError(f"Invalid file format: expected {FILE_MAGIC!r}, got {magic!r}")
|
||||
|
||||
version = struct.unpack("B", self._handle.read(1))[0]
|
||||
if version != FILE_VERSION:
|
||||
self._handle.close()
|
||||
self._handle = None
|
||||
raise ValueError(f"Unsupported file version: {version}")
|
||||
|
||||
logger.debug("Opened encrypted file for reading: %s", path)
|
||||
|
||||
def read_chunks(self) -> Iterator[bytes]:
|
||||
"""Yield decrypted audio chunks."""
|
||||
if self._handle is None or self._dek is None:
|
||||
raise RuntimeError("File not open")
|
||||
|
||||
while True:
|
||||
# Read chunk length
|
||||
length_bytes = self._handle.read(4)
|
||||
if len(length_bytes) < 4:
|
||||
break # End of file
|
||||
|
||||
chunk_length = struct.unpack(">I", length_bytes)[0]
|
||||
|
||||
# Read chunk data
|
||||
chunk_data = self._handle.read(chunk_length)
|
||||
if len(chunk_data) < chunk_length:
|
||||
raise ValueError("Truncated chunk")
|
||||
|
||||
# Parse chunk (nonce + ciphertext + tag)
|
||||
nonce = chunk_data[:NONCE_SIZE]
|
||||
ciphertext = chunk_data[NONCE_SIZE:-TAG_SIZE]
|
||||
tag = chunk_data[-TAG_SIZE:]
|
||||
|
||||
chunk = EncryptedChunk(nonce=nonce, ciphertext=ciphertext, tag=tag)
|
||||
|
||||
# Decrypt and yield
|
||||
yield self._crypto.decrypt_chunk(chunk, self._dek)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the file."""
|
||||
if self._handle is not None:
|
||||
self._handle.close()
|
||||
self._handle = None
|
||||
logger.debug("Closed encrypted file")
|
||||
|
||||
self._dek = None
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if file is open for reading."""
|
||||
return self._handle is not None
|
||||
@@ -1,305 +0,0 @@
|
||||
"""Interactive encryption demo for Spike 4.
|
||||
|
||||
Run with: python -m spikes.spike_04_encryption.demo
|
||||
|
||||
Features:
|
||||
- Creates/retrieves master key from OS keychain
|
||||
- Generates and wraps/unwraps DEKs
|
||||
- Encrypts a sample file in chunks
|
||||
- Decrypts and verifies integrity
|
||||
- Demonstrates DEK deletion renders file unreadable
|
||||
- Reports encryption/decryption throughput
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from .crypto_impl import AesGcmCryptoBox, ChunkedAssetReader, ChunkedAssetWriter
|
||||
from .keystore_impl import InMemoryKeyStore, KeyringKeyStore
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_size(size_bytes: float) -> str:
|
||||
"""Format byte size as human-readable string."""
|
||||
current_size: float = size_bytes
|
||||
for unit in ["B", "KB", "MB", "GB"]:
|
||||
if current_size < 1024:
|
||||
return f"{current_size:.2f} {unit}"
|
||||
current_size /= 1024
|
||||
return f"{current_size:.2f} TB"
|
||||
|
||||
|
||||
def format_speed(bytes_per_sec: float) -> str:
|
||||
"""Format speed as human-readable string."""
|
||||
return f"{format_size(int(bytes_per_sec))}/s"
|
||||
|
||||
|
||||
class EncryptionDemo:
|
||||
"""Interactive encryption demonstration."""
|
||||
|
||||
def __init__(self, use_keyring: bool = False) -> None:
|
||||
"""Initialize the demo.
|
||||
|
||||
Args:
|
||||
use_keyring: If True, use OS keyring; otherwise use in-memory storage.
|
||||
"""
|
||||
if use_keyring:
|
||||
self.keystore = KeyringKeyStore(service_name="noteflow-demo")
|
||||
print("Using OS keyring for key storage")
|
||||
else:
|
||||
self.keystore = InMemoryKeyStore()
|
||||
print("Using in-memory key storage (keys lost on exit)")
|
||||
|
||||
self.crypto = AesGcmCryptoBox(self.keystore)
|
||||
|
||||
def demo_key_storage(self) -> None:
|
||||
"""Demonstrate key storage operations."""
|
||||
print("\n=== Key Storage Demo ===")
|
||||
|
||||
# Check if key exists
|
||||
has_key = self.keystore.has_master_key()
|
||||
print(f"Master key exists: {has_key}")
|
||||
|
||||
# Get or create key
|
||||
print("Getting/creating master key...")
|
||||
start = time.perf_counter()
|
||||
key = self.keystore.get_or_create_master_key()
|
||||
elapsed = time.perf_counter() - start
|
||||
print(f" Key retrieved in {elapsed * 1000:.2f}ms")
|
||||
print(f" Key size: {len(key)} bytes ({len(key) * 8} bits)")
|
||||
|
||||
# Verify same key is returned
|
||||
key2 = self.keystore.get_or_create_master_key()
|
||||
print(f" Same key returned: {key == key2}")
|
||||
|
||||
def demo_dek_operations(self) -> None:
|
||||
"""Demonstrate DEK generation and wrapping."""
|
||||
print("\n=== DEK Operations Demo ===")
|
||||
|
||||
# Generate DEK
|
||||
print("Generating DEK...")
|
||||
dek = self.crypto.generate_dek()
|
||||
print(f" DEK size: {len(dek)} bytes")
|
||||
|
||||
# Wrap DEK
|
||||
print("Wrapping DEK with master key...")
|
||||
start = time.perf_counter()
|
||||
wrapped = self.crypto.wrap_dek(dek)
|
||||
wrap_time = time.perf_counter() - start
|
||||
print(f" Wrapped DEK size: {len(wrapped)} bytes")
|
||||
print(f" Wrap time: {wrap_time * 1000:.3f}ms")
|
||||
|
||||
# Unwrap DEK
|
||||
print("Unwrapping DEK...")
|
||||
start = time.perf_counter()
|
||||
unwrapped = self.crypto.unwrap_dek(wrapped)
|
||||
unwrap_time = time.perf_counter() - start
|
||||
print(f" Unwrap time: {unwrap_time * 1000:.3f}ms")
|
||||
print(f" DEK matches original: {dek == unwrapped}")
|
||||
|
||||
def demo_chunk_encryption(self, chunk_size: int = 16384) -> None:
|
||||
"""Demonstrate chunk encryption/decryption."""
|
||||
print("\n=== Chunk Encryption Demo ===")
|
||||
|
||||
dek = self.crypto.generate_dek()
|
||||
plaintext = secrets.token_bytes(chunk_size)
|
||||
|
||||
print(f"Encrypting {format_size(chunk_size)} chunk...")
|
||||
start = time.perf_counter()
|
||||
chunk = self.crypto.encrypt_chunk(plaintext, dek)
|
||||
encrypt_time = time.perf_counter() - start
|
||||
|
||||
overhead = len(chunk.nonce) + len(chunk.tag)
|
||||
print(f" Nonce size: {len(chunk.nonce)} bytes")
|
||||
print(f" Ciphertext size: {len(chunk.ciphertext)} bytes")
|
||||
print(f" Tag size: {len(chunk.tag)} bytes")
|
||||
print(f" Overhead: {overhead} bytes ({overhead / float(chunk_size) * 100:.1f}%)")
|
||||
print(f" Encrypt time: {encrypt_time * 1000:.3f}ms")
|
||||
print(f" Throughput: {format_speed(chunk_size / encrypt_time)}")
|
||||
|
||||
print("Decrypting chunk...")
|
||||
start = time.perf_counter()
|
||||
decrypted = self.crypto.decrypt_chunk(chunk, dek)
|
||||
decrypt_time = time.perf_counter() - start
|
||||
print(f" Decrypt time: {decrypt_time * 1000:.3f}ms")
|
||||
print(f" Throughput: {format_speed(chunk_size / decrypt_time)}")
|
||||
print(f" Data matches: {plaintext == decrypted}")
|
||||
|
||||
def demo_file_encryption(
|
||||
self,
|
||||
output_path: Path,
|
||||
total_size: int = 1024 * 1024, # 1MB
|
||||
chunk_size: int = 16384, # 16KB
|
||||
) -> tuple[bytes, list[bytes]]:
|
||||
"""Demonstrate file encryption and return the DEK and chunks.
|
||||
|
||||
Args:
|
||||
output_path: Path to write encrypted file.
|
||||
total_size: Total data size to encrypt.
|
||||
chunk_size: Size of each chunk.
|
||||
|
||||
Returns:
|
||||
Tuple of (DEK used for encryption, list of original chunks).
|
||||
"""
|
||||
print(f"\n=== File Encryption Demo ({format_size(total_size)}) ===")
|
||||
|
||||
dek = self.crypto.generate_dek()
|
||||
writer = ChunkedAssetWriter(self.crypto)
|
||||
|
||||
# Generate test data
|
||||
print("Generating test data...")
|
||||
chunks = []
|
||||
remaining = total_size
|
||||
while remaining > 0:
|
||||
size = min(chunk_size, remaining)
|
||||
chunks.append(secrets.token_bytes(size))
|
||||
remaining -= size
|
||||
|
||||
print(f"Writing {len(chunks)} chunks to {output_path}...")
|
||||
start = time.perf_counter()
|
||||
|
||||
writer.open(output_path, dek)
|
||||
for chunk in chunks:
|
||||
writer.write_chunk(chunk)
|
||||
writer.close()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
file_size = output_path.stat().st_size
|
||||
|
||||
print(f" File size: {format_size(file_size)}")
|
||||
print(f" Overhead: {format_size(file_size - total_size)} ({(file_size / total_size - 1) * 100:.1f}%)")
|
||||
print(f" Time: {elapsed:.3f}s")
|
||||
print(f" Throughput: {format_speed(total_size / float(elapsed))}")
|
||||
|
||||
return dek, chunks
|
||||
|
||||
def demo_file_decryption(
|
||||
self,
|
||||
input_path: Path,
|
||||
dek: bytes,
|
||||
original_chunks: list[bytes],
|
||||
) -> None:
|
||||
"""Demonstrate file decryption.
|
||||
|
||||
Args:
|
||||
input_path: Path to encrypted file.
|
||||
dek: DEK used for encryption.
|
||||
original_chunks: Original plaintext chunks for verification.
|
||||
"""
|
||||
print("\n=== File Decryption Demo ===")
|
||||
|
||||
reader = ChunkedAssetReader(self.crypto)
|
||||
|
||||
print(f"Reading from {input_path}...")
|
||||
start = time.perf_counter()
|
||||
|
||||
reader.open(input_path, dek)
|
||||
decrypted_chunks = list(reader.read_chunks())
|
||||
reader.close()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
total_size = sum(len(c) for c in decrypted_chunks)
|
||||
|
||||
print(f" Chunks read: {len(decrypted_chunks)}")
|
||||
print(f" Total data: {format_size(total_size)}")
|
||||
print(f" Time: {elapsed:.3f}s")
|
||||
print(f" Throughput: {format_speed(total_size / elapsed)}")
|
||||
|
||||
# Verify integrity
|
||||
if len(decrypted_chunks) != len(original_chunks):
|
||||
print(" INTEGRITY FAIL: chunk count mismatch")
|
||||
else:
|
||||
all_match = all(d == o for d, o in zip(decrypted_chunks, original_chunks, strict=True))
|
||||
print(f" Integrity verified: {all_match}")
|
||||
|
||||
def demo_dek_deletion(self, input_path: Path, dek: bytes) -> None:
|
||||
"""Demonstrate that deleting DEK renders file unreadable."""
|
||||
print("\n=== DEK Deletion Demo ===")
|
||||
|
||||
print("Attempting to read file with correct DEK...")
|
||||
reader = ChunkedAssetReader(self.crypto)
|
||||
reader.open(input_path, dek)
|
||||
first_chunk = next(reader.read_chunks())
|
||||
reader.close()
|
||||
print(f" Success: read {format_size(len(first_chunk))}")
|
||||
|
||||
print("\nSimulating DEK deletion (using wrong key)...")
|
||||
wrong_dek = secrets.token_bytes(32)
|
||||
|
||||
reader = ChunkedAssetReader(self.crypto)
|
||||
reader.open(input_path, wrong_dek)
|
||||
|
||||
try:
|
||||
list(reader.read_chunks())
|
||||
print(" FAIL: Should have raised error!")
|
||||
except ValueError as e:
|
||||
print(" Success: Decryption failed as expected")
|
||||
print(f" Error: {e}")
|
||||
finally:
|
||||
reader.close()
|
||||
|
||||
def run(self, output_path: Path) -> None:
|
||||
"""Run all demos."""
|
||||
print("=" * 60)
|
||||
print("NoteFlow Encryption Demo - Spike 4")
|
||||
print("=" * 60)
|
||||
|
||||
self.demo_key_storage()
|
||||
self.demo_dek_operations()
|
||||
self.demo_chunk_encryption()
|
||||
|
||||
dek, chunks = self.demo_file_encryption(output_path)
|
||||
self.demo_file_decryption(output_path, dek, chunks)
|
||||
self.demo_dek_deletion(output_path, dek)
|
||||
|
||||
# Cleanup
|
||||
print("\n=== Cleanup ===")
|
||||
if output_path.exists():
|
||||
output_path.unlink()
|
||||
print(f"Deleted test file: {output_path}")
|
||||
|
||||
print("\nDemo complete!")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the encryption demo."""
|
||||
parser = argparse.ArgumentParser(description="Encryption Demo - Spike 4")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("demo_encrypted.bin"),
|
||||
help="Output file path for encryption demo (default: demo_encrypted.bin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
"--keyring",
|
||||
action="store_true",
|
||||
help="Use OS keyring instead of in-memory key storage",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--size",
|
||||
type=int,
|
||||
default=1024 * 1024,
|
||||
help="Total data size to encrypt in bytes (default: 1MB)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
demo = EncryptionDemo(use_keyring=args.keyring)
|
||||
demo.run(args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Keystore implementation using the keyring library.
|
||||
|
||||
Provides secure master key storage using OS credential stores.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Final
|
||||
|
||||
import keyring
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
KEY_SIZE: Final[int] = 32 # 256-bit key
|
||||
SERVICE_NAME: Final[str] = "noteflow"
|
||||
KEY_NAME: Final[str] = "master_key"
|
||||
|
||||
|
||||
class KeyringKeyStore:
|
||||
"""keyring-based key storage using OS credential store.
|
||||
|
||||
Uses:
|
||||
- macOS: Keychain
|
||||
- Windows: Credential Manager
|
||||
- Linux: SecretService (GNOME Keyring, KWallet)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str = SERVICE_NAME,
|
||||
key_name: str = KEY_NAME,
|
||||
) -> None:
|
||||
"""Initialize the keystore.
|
||||
|
||||
Args:
|
||||
service_name: Service identifier for keyring.
|
||||
key_name: Key identifier within the service.
|
||||
"""
|
||||
self._service_name = service_name
|
||||
self._key_name = key_name
|
||||
|
||||
def get_or_create_master_key(self) -> bytes:
|
||||
"""Retrieve or generate the master encryption key.
|
||||
|
||||
Returns:
|
||||
32-byte master key.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If keychain is unavailable.
|
||||
"""
|
||||
try:
|
||||
# Try to retrieve existing key
|
||||
stored = keyring.get_password(self._service_name, self._key_name)
|
||||
if stored is not None:
|
||||
logger.debug("Retrieved existing master key")
|
||||
return base64.b64decode(stored)
|
||||
|
||||
# Generate new key
|
||||
new_key = secrets.token_bytes(KEY_SIZE)
|
||||
encoded = base64.b64encode(new_key).decode("ascii")
|
||||
|
||||
# Store in keyring
|
||||
keyring.set_password(self._service_name, self._key_name, encoded)
|
||||
logger.info("Generated and stored new master key")
|
||||
return new_key
|
||||
|
||||
except keyring.errors.KeyringError as e:
|
||||
raise RuntimeError(f"Keyring unavailable: {e}") from e
|
||||
|
||||
def delete_master_key(self) -> None:
|
||||
"""Delete the master key from the keychain.
|
||||
|
||||
Safe to call if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
keyring.delete_password(self._service_name, self._key_name)
|
||||
logger.info("Deleted master key")
|
||||
except keyring.errors.PasswordDeleteError:
|
||||
# Key doesn't exist, that's fine
|
||||
logger.debug("Master key not found, nothing to delete")
|
||||
except keyring.errors.KeyringError as e:
|
||||
logger.warning("Failed to delete master key: %s", e)
|
||||
|
||||
def has_master_key(self) -> bool:
|
||||
"""Check if master key exists in the keychain.
|
||||
|
||||
Returns:
|
||||
True if master key exists.
|
||||
"""
|
||||
try:
|
||||
stored = keyring.get_password(self._service_name, self._key_name)
|
||||
return stored is not None
|
||||
except keyring.errors.KeyringError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def service_name(self) -> str:
|
||||
"""Get the service name used for keyring."""
|
||||
return self._service_name
|
||||
|
||||
@property
|
||||
def key_name(self) -> str:
|
||||
"""Get the key name used for keyring."""
|
||||
return self._key_name
|
||||
|
||||
|
||||
class InMemoryKeyStore:
|
||||
"""In-memory key storage for testing.
|
||||
|
||||
Keys are lost when the process exits.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the in-memory keystore."""
|
||||
self._key: bytes | None = None
|
||||
|
||||
def get_or_create_master_key(self) -> bytes:
|
||||
"""Retrieve or generate the master encryption key."""
|
||||
if self._key is None:
|
||||
self._key = secrets.token_bytes(KEY_SIZE)
|
||||
logger.debug("Generated in-memory master key")
|
||||
return self._key
|
||||
|
||||
def delete_master_key(self) -> None:
|
||||
"""Delete the master key."""
|
||||
self._key = None
|
||||
logger.debug("Deleted in-memory master key")
|
||||
|
||||
def has_master_key(self) -> bool:
|
||||
"""Check if master key exists."""
|
||||
return self._key is not None
|
||||
@@ -1,221 +0,0 @@
|
||||
"""Encryption protocols and data types for Spike 4.
|
||||
|
||||
These protocols define the contracts for key storage and encryption components
|
||||
that will be promoted to src/noteflow/crypto/ after validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EncryptedChunk:
|
||||
"""An encrypted chunk of data with authentication tag."""
|
||||
|
||||
nonce: bytes # Unique nonce for this chunk
|
||||
ciphertext: bytes # Encrypted data
|
||||
tag: bytes # Authentication tag
|
||||
|
||||
|
||||
class KeyStore(Protocol):
|
||||
"""Protocol for OS keychain access.
|
||||
|
||||
Implementations should use the OS credential store (Keychain, Credential Manager)
|
||||
to securely store the master encryption key.
|
||||
"""
|
||||
|
||||
def get_or_create_master_key(self) -> bytes:
|
||||
"""Retrieve or generate the master encryption key.
|
||||
|
||||
If the master key doesn't exist, generates a new 32-byte key
|
||||
and stores it in the OS keychain.
|
||||
|
||||
Returns:
|
||||
32-byte master key.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If keychain is unavailable or locked.
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_master_key(self) -> None:
|
||||
"""Delete the master key from the keychain.
|
||||
|
||||
This renders all encrypted data permanently unrecoverable.
|
||||
|
||||
Safe to call if key doesn't exist.
|
||||
"""
|
||||
...
|
||||
|
||||
def has_master_key(self) -> bool:
|
||||
"""Check if master key exists in the keychain.
|
||||
|
||||
Returns:
|
||||
True if master key exists.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class CryptoBox(Protocol):
|
||||
"""Protocol for envelope encryption with per-meeting keys.
|
||||
|
||||
Uses a master key to wrap/unwrap Data Encryption Keys (DEKs),
|
||||
which are used to encrypt actual meeting data.
|
||||
"""
|
||||
|
||||
def generate_dek(self) -> bytes:
|
||||
"""Generate a new Data Encryption Key.
|
||||
|
||||
Returns:
|
||||
32-byte random DEK.
|
||||
"""
|
||||
...
|
||||
|
||||
def wrap_dek(self, dek: bytes) -> bytes:
|
||||
"""Encrypt DEK with master key.
|
||||
|
||||
Args:
|
||||
dek: Data Encryption Key to wrap.
|
||||
|
||||
Returns:
|
||||
Encrypted DEK (can be stored in DB).
|
||||
"""
|
||||
...
|
||||
|
||||
def unwrap_dek(self, wrapped_dek: bytes) -> bytes:
|
||||
"""Decrypt DEK with master key.
|
||||
|
||||
Args:
|
||||
wrapped_dek: Encrypted DEK from wrap_dek().
|
||||
|
||||
Returns:
|
||||
Original DEK.
|
||||
|
||||
Raises:
|
||||
ValueError: If decryption fails (invalid or tampered).
|
||||
"""
|
||||
...
|
||||
|
||||
def encrypt_chunk(self, plaintext: bytes, dek: bytes) -> EncryptedChunk:
|
||||
"""Encrypt a chunk of data with AES-GCM.
|
||||
|
||||
Args:
|
||||
plaintext: Data to encrypt.
|
||||
dek: Data Encryption Key.
|
||||
|
||||
Returns:
|
||||
EncryptedChunk with nonce, ciphertext, and tag.
|
||||
"""
|
||||
...
|
||||
|
||||
def decrypt_chunk(self, chunk: EncryptedChunk, dek: bytes) -> bytes:
|
||||
"""Decrypt a chunk of data.
|
||||
|
||||
Args:
|
||||
chunk: EncryptedChunk to decrypt.
|
||||
dek: Data Encryption Key.
|
||||
|
||||
Returns:
|
||||
Original plaintext.
|
||||
|
||||
Raises:
|
||||
ValueError: If decryption fails (invalid or tampered).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class EncryptedAssetWriter(Protocol):
|
||||
"""Protocol for streaming encrypted audio writer.
|
||||
|
||||
Writes audio chunks encrypted with a DEK to a file.
|
||||
"""
|
||||
|
||||
def open(self, path: Path, dek: bytes) -> None:
|
||||
"""Open file for writing.
|
||||
|
||||
Args:
|
||||
path: Path to the encrypted file.
|
||||
dek: Data Encryption Key for this file.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If already open.
|
||||
OSError: If file cannot be created.
|
||||
"""
|
||||
...
|
||||
|
||||
def write_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Write and encrypt an audio chunk.
|
||||
|
||||
Args:
|
||||
audio_bytes: Raw audio data to encrypt and write.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not open.
|
||||
"""
|
||||
...
|
||||
|
||||
def close(self) -> None:
|
||||
"""Finalize and close the file.
|
||||
|
||||
Safe to call if already closed.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if file is open for writing."""
|
||||
...
|
||||
|
||||
@property
|
||||
def bytes_written(self) -> int:
|
||||
"""Total encrypted bytes written."""
|
||||
...
|
||||
|
||||
|
||||
class EncryptedAssetReader(Protocol):
|
||||
"""Protocol for streaming encrypted audio reader.
|
||||
|
||||
Reads and decrypts audio chunks from a file.
|
||||
"""
|
||||
|
||||
def open(self, path: Path, dek: bytes) -> None:
|
||||
"""Open file for reading.
|
||||
|
||||
Args:
|
||||
path: Path to the encrypted file.
|
||||
dek: Data Encryption Key for this file.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If already open.
|
||||
OSError: If file cannot be read.
|
||||
ValueError: If file format is invalid.
|
||||
"""
|
||||
...
|
||||
|
||||
def read_chunks(self) -> Iterator[bytes]:
|
||||
"""Yield decrypted audio chunks.
|
||||
|
||||
Yields:
|
||||
Decrypted audio data chunks.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not open.
|
||||
ValueError: If decryption fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the file.
|
||||
|
||||
Safe to call if already closed.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if file is open for reading."""
|
||||
...
|
||||
Reference in New Issue
Block a user