Add summarization and trigger services
- Introduced `SummarizationService` and `TriggerService` to orchestrate summarization and trigger detection.
- Added new summarization modules, including citation verification and cloud-based summarization providers.
- Implemented trigger detection based on audio activity and foreground application status.
- Updated project configuration with new dependencies for the summarization and trigger features.
- Added tests for the summarization and trigger services to ensure functionality and reliability.
CLAUDE.md (new file, 103 lines)
@@ -0,0 +1,103 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

NoteFlow is an intelligent meeting notetaker: local-first audio capture + navigable recall + evidence-linked summaries. Client-server architecture using gRPC for bidirectional audio streaming and transcription.

## Build and Development Commands

```bash
# Install (editable with dev dependencies)
python -m pip install -e ".[dev]"

# Run gRPC server
python -m noteflow.grpc.server --help

# Run Flet client UI
python -m noteflow.client.app --help

# Tests
pytest                        # Full suite
pytest -m "not integration"   # Skip external-service tests
pytest tests/domain/          # Run specific test directory
pytest -k "test_segment"      # Run by pattern

# Linting and type checking
ruff check .                  # Lint
ruff check --fix .            # Autofix
mypy src/noteflow             # Strict type checks
basedpyright                  # Additional type checks
```

## Architecture

```
src/noteflow/
├── domain/           # Entities (meeting, segment, annotation, summary) + ports (repository interfaces)
├── application/      # Use-cases/services (MeetingService, RecoveryService, ExportService)
├── infrastructure/   # Implementations
│   ├── audio/        # sounddevice capture, ring buffer, VU levels, playback
│   ├── asr/          # faster-whisper engine, VAD segmenter, streaming
│   ├── persistence/  # SQLAlchemy + asyncpg + pgvector, Alembic migrations
│   ├── security/     # keyring keystore, AES-GCM encryption
│   ├── export/       # Markdown/HTML export
│   └── converters/   # ORM ↔ domain entity converters
├── grpc/             # Proto definitions, server, client, meeting store
├── client/           # Flet UI app + components (transcript, VU meter, playback)
└── config/           # Pydantic settings (NOTEFLOW_ env vars)
```

**Key patterns:**

- Hexagonal architecture: domain → application → infrastructure
- Repository pattern with Unit of Work (`SQLAlchemyUnitOfWork`)
- gRPC bidirectional streaming for audio → transcript flow
- Protocol-based DI (see `domain/ports/` and infrastructure `protocols.py` files; a minimal sketch follows)
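
A minimal sketch of the Protocol-based DI pattern (illustrative only — `MeetingRepository` here is a stand-in, not the repository's actual port definition):

```python
from dataclasses import dataclass
from typing import Protocol


@dataclass
class Meeting:
    id: str
    title: str


class MeetingRepository(Protocol):
    """Port: the persistence interface the application layer depends on."""

    async def get(self, meeting_id: str) -> Meeting: ...
    async def add(self, meeting: Meeting) -> None: ...


class InMemoryMeetingRepository:
    """Adapter: satisfies the Protocol structurally, with no inheritance."""

    def __init__(self) -> None:
        self._items: dict[str, Meeting] = {}

    async def get(self, meeting_id: str) -> Meeting:
        return self._items[meeting_id]

    async def add(self, meeting: Meeting) -> None:
        self._items[meeting.id] = meeting
```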

## Database

PostgreSQL with pgvector extension. Async SQLAlchemy with asyncpg driver.

```bash
# Alembic migrations
alembic upgrade head
alembic revision --autogenerate -m "description"
```

Connection via `NOTEFLOW_DATABASE_URL` env var or settings.
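
For example (placeholder credentials and database name):

```bash
export NOTEFLOW_DATABASE_URL="postgresql+asyncpg://user:pass@localhost:5432/noteflow"
```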

## Testing Conventions

- Test files: `test_*.py`, functions: `test_*`
- Markers: `@pytest.mark.slow` (model loading), `@pytest.mark.integration` (external services)
- Integration tests use testcontainers for PostgreSQL
- Asyncio auto-mode enabled

## Proto/gRPC

Proto definitions: `src/noteflow/grpc/proto/noteflow.proto`
Generated files excluded from lint: `*_pb2.py`, `*_pb2_grpc.py`

Regenerate after proto changes:

```bash
python -m grpc_tools.protoc -I src/noteflow/grpc/proto \
    --python_out=src/noteflow/grpc/proto \
    --grpc_python_out=src/noteflow/grpc/proto \
    src/noteflow/grpc/proto/noteflow.proto
```

## Code Style

- Python 3.12+, 100-char line length
- Strict mypy (allow `type: ignore[code]` only with a comment explaining why)
- Ruff for linting (E, W, F, I, B, C4, UP, SIM, RUF)
- Module soft limit 500 LoC, hard limit 750 LoC

## Spikes (De-risking Experiments)

`spikes/` contains validated platform experiments with `FINDINGS.md`:

- `spike_01_ui_tray_hotkeys/` - Flet + pystray + pynput (requires X11)
- `spike_02_audio_capture/` - sounddevice + PortAudio
- `spike_03_asr_latency/` - faster-whisper benchmarks (0.05x real-time)
- `spike_04_encryption/` - keyring + AES-GCM (826 MB/s throughput)

@@ -94,6 +94,32 @@ I’m writing this so engineering can start building without re‑interpreting p

* Final segments persisted to DB
* Post-meeting transcript view

**Current status:**

* Final segments are emitted and persisted; partial updates are not yet produced.

**Implementation plan (add partials end-to-end; a debouncing sketch follows this list):**

* ASR layer:
  * Extend the ASR engine interface to surface partial hypotheses at a fixed cadence (e.g., every N seconds or on each VAD speech chunk).
  * Add a lightweight streaming mode for faster-whisper (or a buffering strategy that returns interim text from recent audio while finalization waits for silence).
  * Ensure partial outputs include a stable `segment_id=0` (or temporary ID) and do not persist to DB.
* Server:
  * Emit `UPDATE_TYPE_PARTIAL` messages from the ASR loop on cadence.
  * Debounce partial updates to avoid UI churn and bandwidth spikes.
  * Keep final segment emission unchanged; partials must be overwritten by finals.
* Client/UI:
  * Render a single “live partial” row at the bottom of the transcript list (grey text), replaced in-place on each partial update.
  * Drop partials on stop or on the first final segment after a partial.
* Tests:
  * Unit tests for partial cadence and suppression of partial persistence.
  * Integration test that partials appear before finals and are cleared on final.
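
A minimal sketch of the debounced partial emission described above (`PartialDebouncer` and its interval are assumptions for illustration, not the plan's mandated API):

```python
import time


class PartialDebouncer:
    """Forward at most one partial transcript update per interval."""

    def __init__(self, min_interval_s: float = 0.5) -> None:
        self._min_interval_s = min_interval_s
        self._last_emit: float = 0.0
        self._pending_text: str | None = None

    def offer(self, text: str) -> str | None:
        """Record the latest partial; return it only if the interval elapsed."""
        self._pending_text = text
        now = time.monotonic()
        if now - self._last_emit >= self._min_interval_s:
            self._last_emit = now
            out, self._pending_text = self._pending_text, None
            return out
        return None  # caller skips this update; newer text supersedes it
```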

**Exit criteria:**

* Live view shows partial text that settles into final segments.

@@ -112,6 +138,12 @@ I’m writing this so engineering can start building without re‑interpreting p

* Export: Markdown + HTML
* Meeting library list + per-meeting search

**Gaps to close in this milestone:**

* Wire meeting library into the main UI and selection flow.
* Add per-meeting transcript search (client-side filter is acceptable for V1).
* Add `risk` annotation type end-to-end (domain enum, UI, persistence).

**Exit criteria:**

* Clicking a segment seeks audio playback to that time.

@@ -132,6 +164,12 @@ I’m writing this so engineering can start building without re‑interpreting p

* Prompt notification + snooze + suppress per-app
* Settings for sensitivity and auto-start opt-in

**Deferred to a later, tray/hotkey-focused milestone:**

* Trigger prompts that include per-app suppression, calendar stubs, and snooze presets integrated with tray/menubar UX.
* Persistent “recording/monitoring” indicator when background capture is active.

**Exit criteria:**

* Trigger prompts happen when expected and can be snoozed.

@@ -153,6 +191,27 @@ I’m writing this so engineering can start building without re‑interpreting p

* Citation verifier + “uncited drafts” handling
* Summary UI panel with clickable citations

**Implementation plan (citations enforced; a verification sketch follows this list):**

* Summarizer provider interface:
  * Define `Summarizer` protocol with `extract()` and `synthesize()` phases.
  * Provide `MockSummarizer` for tests and a cloud-backed provider behind opt-in.
* Extraction stage:
  * Segment-aware chunking (~500 tokens) with stable `segment_ids` in each chunk.
  * Extraction prompt returns structured items: quote, segment_ids, category.
* Synthesis stage:
  * Rewrite extracted items into bullets; each bullet must end with `[...]` containing segment IDs.
* Verification stage:
  * Parse bullets; suppress any uncited bullets by default.
  * Store uncited drafts separately for optional user review.
* UI:
  * Summary panel lists key points + action items with clickable citations.
  * Clicking a bullet scrolls the transcript and seeks audio to the first cited segment.
* Tests:
  * Unit tests for citation parsing, uncited suppression, and click→segment mapping.
  * Integration test for summary generation request and persisted citations.
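
A minimal sketch of the bullet parse/suppress step described above (the trailing `[1, 2]`-style citation marker follows the plan; the function names are assumptions):

```python
import re

_CITATION_RE = re.compile(r"\[([\d,\s]+)\]\s*$")  # trailing "[1, 2]" marker


def parse_citations(bullet: str) -> tuple[str, list[int]]:
    """Split a summary bullet into its text and trailing segment IDs."""
    match = _CITATION_RE.search(bullet)
    if not match:
        return bullet.strip(), []
    ids = [int(part) for part in match.group(1).split(",") if part.strip()]
    return bullet[: match.start()].strip(), ids


def suppress_uncited(bullets: list[str], valid_ids: set[int]) -> tuple[list[str], list[str]]:
    """Keep bullets whose citations all resolve; divert the rest to drafts."""
    kept: list[str] = []
    drafts: list[str] = []
    for bullet in bullets:
        text, ids = parse_citations(bullet)
        if ids and all(i in valid_ids for i in ids):
            kept.append(bullet)
        else:
            drafts.append(text)  # uncited or invalid: stored for optional review
    return kept, drafts
```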

**Exit criteria:**

* Every displayed bullet has citations.

@@ -173,6 +232,19 @@ I’m writing this so engineering can start building without re‑interpreting p

* “Check for updates” flow (manual link + version display)
* Release checklist & troubleshooting docs

**Implementation plan (delete/retention correctness; a purge sketch follows this list):**

* Meeting deletion:
  * Extend the delete flow to remove encrypted audio assets on disk.
  * Delete wrapped DEK and master key references so audio cannot be decrypted.
  * Add best-effort cleanup for orphaned files on next startup.
* Retention:
  * Scheduled job that deletes meetings older than the retention period.
  * Include DB rows, summaries, and audio assets in the purge.
* Tests:
  * Integration test that delete removes DB rows + audio file path.
  * Integration test that the retention job removes expired meetings and assets.
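
A minimal sketch of the on-disk half of the retention purge described above (DB-row and key deletion would run alongside this in the real job; the function name is an assumption):

```python
import shutil
import time
from pathlib import Path


def purge_expired(meetings_dir: Path, retention_days: int) -> list[str]:
    """Delete per-meeting asset directories older than the retention window."""
    cutoff = time.time() - retention_days * 86400
    purged: list[str] = []
    for meeting_dir in meetings_dir.iterdir():
        if meeting_dir.is_dir() and meeting_dir.stat().st_mtime < cutoff:
            shutil.rmtree(meeting_dir, ignore_errors=True)  # best-effort cleanup
            purged.append(meeting_dir.name)
    return purged
```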

**Exit criteria:**

* A signed installer (or unsigned for internal) that installs and runs on both OSs.

docs/spec.md (18 lines changed)
@@ -1,6 +1,6 @@

Below is a rewritten, end‑to‑end **Product Specification + Engineering Design Document** for **NoteFlow V1 (Minimum Lovable Product)** that merges:

-* your **revised V1 draft** (confidence-model triggers, single-process, partial/final UX, extract‑then‑synthesize citations, pragmatic typing, packaging constraints, risks table), and
+* your **revised V1 draft** (confidence-model triggers, client/server architecture, partial/final UX, extract‑then‑synthesize citations, pragmatic typing, packaging constraints, risks table), and
* the **de-risking feedback** I gave earlier (audio capture reality, diarization scope, citation enforcement, OS permissions, shipping concerns, storage/retention, update strategy, and “don’t promise what you can’t reliably ship”).

I’ve kept it “shipping-ready” by being explicit about decisions, failure modes, acceptance criteria, and what is deferred.
@@ -292,7 +292,9 @@ The system is split into two components that can run on the same machine or sepa
|
||||
**Server (Headless Backend)**
|
||||
* **ASR Engine:** faster-whisper for transcription
|
||||
* **Meeting Store:** in-memory meeting management
|
||||
* **Storage:** LanceDB for persistence + encrypted audio assets
|
||||
* **Storage:** PostgreSQL + pgvector for persistence + encrypted audio assets
|
||||
(current implementation). LanceDB is supported as an optional adapter for
|
||||
local-only deployments in single-process mode.
|
||||
* **gRPC Service:** bidirectional streaming for real-time transcription
|
||||
|
||||
**Client (GUI Application)**
|
||||
@@ -310,6 +312,8 @@ The system is split into two components that can run on the same machine or sepa
|
||||
**Deployment modes:**
|
||||
1. **Local:** Server + Client on same machine (default)
|
||||
2. **Split:** Server on headless machine, Client on workstation with audio
|
||||
3. **Local-only adapter:** Optional LanceDB-backed, single-process mode
|
||||
for development or constrained environments (feature-parity not guaranteed).
|
||||
|
||||
---
|
||||
|
||||
@@ -427,11 +431,17 @@ Supported provider modes:
|
||||
|
||||
## 9. Storage & Data Model
|
||||
|
||||
**Backend support:** The reference implementation uses PostgreSQL + pgvector.
|
||||
LanceDB is supported as an optional adapter for local-only, single-process
|
||||
deployments. The schema below describes the logical model and should be mapped
|
||||
to either backend.
|
||||
|
||||
### 9.1 On-Disk Layout (Per User)
|
||||
|
||||
* App data directory (OS standard)
|
||||
|
||||
* `db/` (LanceDB)
|
||||
* `db/` (PostgreSQL + pgvector)
|
||||
* `lancedb/` (optional local-only adapter)
|
||||
* `meetings/<meeting_id>/`
|
||||
|
||||
* `audio.<ext>` (encrypted container)
|
||||
@@ -439,7 +449,7 @@ Supported provider modes:
|
||||
* `logs/` (rotating; content-free)
|
||||
* `settings.json`
|
||||
|
||||
### 9.2 Database Schema (LanceDB)
|
||||
### 9.2 Database Schema (PostgreSQL baseline)
|
||||
|
||||
Core tables:
|
||||
|
||||
|
||||

logs/status_line.json (12093 lines; diff suppressed because it is too large)

pyproject.toml
@@ -44,6 +44,14 @@ dev = [

    "basedpyright>=1.18",
    "testcontainers[postgres]>=4.0",
]
triggers = [
    "pywinctl>=0.3",
]
summarization = [
    "ollama>=0.6.1",
    "openai>=2.13.0",
    "anthropic>=0.75.0",
]

[build-system]
requires = ["hatchling"]
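
The new extras combine with the existing `dev` extra at install time (standard pip extras syntax, a usage note rather than part of the diff):

```bash
python -m pip install -e ".[dev,triggers,summarization]"
```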

src/noteflow/application/services/__init__.py
@@ -3,5 +3,23 @@

from noteflow.application.services.export_service import ExportFormat, ExportService
from noteflow.application.services.meeting_service import MeetingService
from noteflow.application.services.recovery_service import RecoveryService
from noteflow.application.services.summarization_service import (
    SummarizationMode,
    SummarizationService,
    SummarizationServiceResult,
    SummarizationServiceSettings,
)
from noteflow.application.services.trigger_service import TriggerService, TriggerServiceSettings

-__all__ = ["ExportFormat", "ExportService", "MeetingService", "RecoveryService"]
+__all__ = [
+    "ExportFormat",
+    "ExportService",
+    "MeetingService",
+    "RecoveryService",
+    "SummarizationMode",
+    "SummarizationService",
+    "SummarizationServiceResult",
+    "SummarizationServiceSettings",
+    "TriggerService",
+    "TriggerServiceSettings",
+]

src/noteflow/application/services/summarization_service.py (new file, 323 lines)
@@ -0,0 +1,323 @@
"""Summarization orchestration service.

Coordinate provider selection, consent handling, and citation verification.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING

from noteflow.domain.summarization import (
    CitationVerificationResult,
    ProviderUnavailableError,
    SummarizationRequest,
    SummarizationResult,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from noteflow.domain.entities import Segment, Summary
    from noteflow.domain.summarization import CitationVerifier, SummarizerProvider
    from noteflow.domain.value_objects import MeetingId

logger = logging.getLogger(__name__)


class SummarizationMode(Enum):
    """Available summarization modes."""

    MOCK = "mock"
    LOCAL = "local"  # Ollama
    CLOUD = "cloud"  # OpenAI/Anthropic


@dataclass
class SummarizationServiceSettings:
    """Configuration for summarization service.

    Attributes:
        default_mode: Default summarization mode.
        cloud_consent_granted: Whether user has consented to cloud processing.
        fallback_to_local: Fall back to local if cloud unavailable.
        verify_citations: Whether to verify citations after summarization.
        filter_invalid_citations: Remove invalid citations from result.
        max_key_points: Default maximum key points.
        max_action_items: Default maximum action items.
    """

    default_mode: SummarizationMode = SummarizationMode.LOCAL
    cloud_consent_granted: bool = False
    fallback_to_local: bool = True
    verify_citations: bool = True
    filter_invalid_citations: bool = True
    max_key_points: int = 5
    max_action_items: int = 10


@dataclass
class SummarizationServiceResult:
    """Result from summarization service.

    Attributes:
        result: The raw summarization result from the provider.
        verification: Citation verification result (if verification enabled).
        filtered_summary: Summary with invalid citations removed (if filtering enabled).
        provider_used: Which provider was actually used.
        fallback_used: Whether a fallback provider was used.
    """

    result: SummarizationResult
    verification: CitationVerificationResult | None = None
    filtered_summary: Summary | None = None
    provider_used: str = ""
    fallback_used: bool = False

    @property
    def summary(self) -> Summary:
        """Get the best available summary (filtered if available)."""
        return self.filtered_summary or self.result.summary

    @property
    def has_invalid_citations(self) -> bool:
        """Check if summary has invalid citations."""
        return self.verification is not None and not self.verification.is_valid


@dataclass
class SummarizationService:
    """Orchestrate summarization with provider selection and citation verification.

    Manages provider selection based on mode and availability, handles
    cloud consent requirements, and verifies/filters citation integrity.
    """

    providers: dict[SummarizationMode, SummarizerProvider] = field(default_factory=dict)
    verifier: CitationVerifier | None = None
    settings: SummarizationServiceSettings = field(default_factory=SummarizationServiceSettings)

    def register_provider(self, mode: SummarizationMode, provider: SummarizerProvider) -> None:
        """Register a provider for a specific mode.

        Args:
            mode: The mode this provider handles.
            provider: The provider implementation.
        """
        self.providers[mode] = provider
        logger.debug("Registered %s provider: %s", mode.value, provider.provider_name)

    def set_verifier(self, verifier: CitationVerifier) -> None:
        """Set the citation verifier.

        Args:
            verifier: Citation verifier implementation.
        """
        self.verifier = verifier

    def get_available_modes(self) -> list[SummarizationMode]:
        """Get list of currently available summarization modes.

        Returns:
            List of available modes based on registered providers.
        """
        available = []
        for mode, provider in self.providers.items():
            if mode == SummarizationMode.CLOUD:
                if provider.is_available and self.settings.cloud_consent_granted:
                    available.append(mode)
            elif provider.is_available:
                available.append(mode)
        return available

    def is_mode_available(self, mode: SummarizationMode) -> bool:
        """Check if a specific mode is available.

        Args:
            mode: The mode to check.

        Returns:
            True if mode is available.
        """
        return mode in self.get_available_modes()

    def grant_cloud_consent(self) -> None:
        """Grant consent for cloud processing."""
        self.settings.cloud_consent_granted = True
        logger.info("Cloud consent granted")

    def revoke_cloud_consent(self) -> None:
        """Revoke consent for cloud processing."""
        self.settings.cloud_consent_granted = False
        logger.info("Cloud consent revoked")

    async def summarize(
        self,
        meeting_id: MeetingId,
        segments: Sequence[Segment],
        mode: SummarizationMode | None = None,
        max_key_points: int | None = None,
        max_action_items: int | None = None,
    ) -> SummarizationServiceResult:
        """Generate evidence-linked summary for meeting transcript.

        Args:
            meeting_id: The meeting ID.
            segments: Transcript segments to summarize.
            mode: Override default mode (None uses settings default).
            max_key_points: Override default max key points.
            max_action_items: Override default max action items.

        Returns:
            SummarizationServiceResult with summary and verification.

        Raises:
            SummarizationError: If summarization fails and no fallback available.
            ProviderUnavailableError: If no provider is available for the mode.
        """
        target_mode = mode or self.settings.default_mode
        fallback_used = False

        # Get provider, potentially with fallback
        provider, actual_mode = self._get_provider_with_fallback(target_mode)
        if actual_mode != target_mode:
            fallback_used = True
            logger.info(
                "Falling back from %s to %s mode",
                target_mode.value,
                actual_mode.value,
            )

        # Build request
        request = SummarizationRequest(
            meeting_id=meeting_id,
            segments=segments,
            max_key_points=max_key_points or self.settings.max_key_points,
            max_action_items=max_action_items or self.settings.max_action_items,
        )

        # Execute summarization
        logger.info(
            "Summarizing %d segments with %s provider",
            len(segments),
            provider.provider_name,
        )
        result = await provider.summarize(request)

        # Build service result
        service_result = SummarizationServiceResult(
            result=result,
            provider_used=provider.provider_name,
            fallback_used=fallback_used,
        )

        # Verify citations if enabled
        if self.settings.verify_citations and self.verifier is not None:
            verification = self.verifier.verify_citations(result.summary, list(segments))
            service_result.verification = verification

            if not verification.is_valid:
                logger.warning(
                    "Summary has %d invalid citations",
                    verification.invalid_count,
                )

                # Filter if enabled
                if self.settings.filter_invalid_citations:
                    service_result.filtered_summary = self._filter_citations(
                        result.summary, list(segments)
                    )

        return service_result

    def _get_provider_with_fallback(
        self, mode: SummarizationMode
    ) -> tuple[SummarizerProvider, SummarizationMode]:
        """Get provider for mode, with fallback if unavailable.

        Args:
            mode: Requested mode.

        Returns:
            Tuple of (provider, actual_mode).

        Raises:
            ProviderUnavailableError: If no provider available.
        """
        # Check requested mode
        if mode in self.providers:
            provider = self.providers[mode]

            # Check cloud consent
            if mode == SummarizationMode.CLOUD and not self.settings.cloud_consent_granted:
                logger.warning("Cloud mode requested but consent not granted")
                if self.settings.fallback_to_local:
                    return self._get_fallback_provider(mode)
                raise ProviderUnavailableError("Cloud consent not granted")

            if provider.is_available:
                return provider, mode

        # Provider exists but unavailable
        if self.settings.fallback_to_local and mode != SummarizationMode.MOCK:
            return self._get_fallback_provider(mode)

        raise ProviderUnavailableError(f"No provider available for mode: {mode.value}")

    def _get_fallback_provider(
        self, original_mode: SummarizationMode
    ) -> tuple[SummarizerProvider, SummarizationMode]:
        """Get fallback provider when primary unavailable.

        Fallback order: LOCAL -> MOCK

        Args:
            original_mode: The mode that was unavailable.

        Returns:
            Tuple of (provider, mode).

        Raises:
            ProviderUnavailableError: If no fallback available.
        """
        fallback_order = [SummarizationMode.LOCAL, SummarizationMode.MOCK]

        for fallback_mode in fallback_order:
            if fallback_mode == original_mode:
                continue
            if fallback_mode in self.providers:
                provider = self.providers[fallback_mode]
                if provider.is_available:
                    return provider, fallback_mode

        raise ProviderUnavailableError("No fallback provider available")

    def _filter_citations(self, summary: Summary, segments: list[Segment]) -> Summary:
        """Filter invalid citations from summary.

        Args:
            summary: Summary to filter.
            segments: Available segments.

        Returns:
            Summary with invalid citations removed.
        """
        if self.verifier is None:
            return summary

        # Use verifier's filter method if available
        if hasattr(self.verifier, "filter_invalid_citations"):
            return self.verifier.filter_invalid_citations(summary, segments)

        return summary

    def set_default_mode(self, mode: SummarizationMode) -> None:
        """Set the default summarization mode.

        Args:
            mode: New default mode.
        """
        self.settings.default_mode = mode
        logger.info("Default summarization mode set to %s", mode.value)

src/noteflow/application/services/trigger_service.py (new file, 207 lines)
@@ -0,0 +1,207 @@
"""Trigger evaluation and decision service.

Orchestrate trigger detection with rate limiting and snooze support.
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

from noteflow.domain.triggers.entities import TriggerAction, TriggerDecision, TriggerSignal

if TYPE_CHECKING:
    from noteflow.domain.triggers.ports import SignalProvider

logger = logging.getLogger(__name__)


@dataclass
class TriggerServiceSettings:
    """Configuration for trigger service.

    Attributes:
        enabled: Whether trigger detection is enabled.
        auto_start_enabled: Whether to auto-start recording at high confidence.
        rate_limit_seconds: Minimum seconds between trigger prompts.
        snooze_seconds: Default snooze duration.
        threshold_ignore: Confidence below which triggers are ignored.
        threshold_auto_start: Confidence at or above which auto-start is allowed.
    """

    enabled: bool
    auto_start_enabled: bool
    rate_limit_seconds: int
    snooze_seconds: int
    threshold_ignore: float
    threshold_auto_start: float

    def __post_init__(self) -> None:
        if self.threshold_auto_start < self.threshold_ignore:
            msg = "threshold_auto_start must be >= threshold_ignore"
            raise ValueError(msg)


class TriggerService:
    """Orchestrate trigger detection with rate limiting and snooze.

    Evaluates all signal providers and determines the appropriate action
    based on combined confidence scores, rate limits, and snooze state.

    Threshold behavior is driven by TriggerServiceSettings:
    - Confidence < threshold_ignore: IGNORE
    - Confidence >= threshold_auto_start: AUTO_START (if enabled, else NOTIFY)
    - Otherwise: NOTIFY
    """

    def __init__(
        self,
        providers: list[SignalProvider],
        settings: TriggerServiceSettings,
    ) -> None:
        """Initialize trigger service.

        Args:
            providers: List of signal providers to evaluate.
            settings: Configuration settings for trigger behavior.
        """
        self._providers = providers
        self._settings = settings
        self._last_prompt: float | None = None
        self._snoozed_until: float | None = None

    @property
    def is_enabled(self) -> bool:
        """Check if trigger service is enabled."""
        return self._settings.enabled

    @property
    def is_snoozed(self) -> bool:
        """Check if triggers are currently snoozed."""
        if self._snoozed_until is None:
            return False
        return time.monotonic() < self._snoozed_until

    @property
    def snooze_remaining_seconds(self) -> float:
        """Get remaining snooze time in seconds, or 0 if not snoozed."""
        if self._snoozed_until is None:
            return 0.0
        remaining = self._snoozed_until - time.monotonic()
        return max(0.0, remaining)

    def evaluate(self) -> TriggerDecision:
        """Evaluate all providers and determine action.

        Returns:
            TriggerDecision with action and confidence details.
        """
        now = time.monotonic()

        # Check if disabled
        if not self._settings.enabled:
            return self._make_decision(TriggerAction.IGNORE, 0.0, ())

        # Check if snoozed
        if self._snoozed_until is not None and now < self._snoozed_until:
            return self._make_decision(TriggerAction.IGNORE, 0.0, ())

        # Collect signals from all enabled providers
        signals = []
        for provider in self._providers:
            if not provider.is_enabled():
                continue
            if signal := provider.get_signal():
                signals.append(signal)

        # Calculate total confidence
        confidence = sum(s.weight for s in signals)

        # Determine action
        action = self._determine_action(confidence, now)

        # Record prompt time for rate limiting
        if action in (TriggerAction.NOTIFY, TriggerAction.AUTO_START):
            self._last_prompt = now
            logger.info(
                "Trigger %s: confidence=%.2f, signals=%d",
                action.value,
                confidence,
                len(signals),
            )

        return self._make_decision(action, confidence, tuple(signals))

    def _determine_action(self, confidence: float, now: float) -> TriggerAction:
        """Determine action based on confidence and rate limits.

        Args:
            confidence: Total confidence from all signals.
            now: Current monotonic time.

        Returns:
            TriggerAction to take.
        """
        # Check rate limit
        if self._last_prompt is not None:
            elapsed = now - self._last_prompt
            if elapsed < self._settings.rate_limit_seconds:
                return TriggerAction.IGNORE

        # Apply thresholds
        if confidence < self._settings.threshold_ignore:
            return TriggerAction.IGNORE

        if confidence >= self._settings.threshold_auto_start and self._settings.auto_start_enabled:
            return TriggerAction.AUTO_START

        return TriggerAction.NOTIFY

    def _make_decision(
        self,
        action: TriggerAction,
        confidence: float,
        signals: tuple[TriggerSignal, ...],
    ) -> TriggerDecision:
        """Create a TriggerDecision with the given parameters."""
        return TriggerDecision(
            action=action,
            confidence=confidence,
            signals=signals,
        )

    def snooze(self, seconds: int | None = None) -> None:
        """Snooze triggers for the specified duration.

        Args:
            seconds: Snooze duration in seconds (uses default if None).
        """
        duration = seconds if seconds is not None else self._settings.snooze_seconds
        self._snoozed_until = time.monotonic() + duration
        logger.info("Triggers snoozed for %d seconds", duration)

    def clear_snooze(self) -> None:
        """Clear any active snooze."""
        if self._snoozed_until is not None:
            self._snoozed_until = None
            logger.info("Trigger snooze cleared")

    def set_enabled(self, enabled: bool) -> None:
        """Enable or disable trigger detection.

        Args:
            enabled: Whether triggers should be enabled.
        """
        self._settings.enabled = enabled
        logger.info("Triggers %s", "enabled" if enabled else "disabled")

    def set_auto_start(self, enabled: bool) -> None:
        """Enable or disable auto-start on high confidence.

        Args:
            enabled: Whether auto-start should be enabled.
        """
        self._settings.auto_start_enabled = enabled
        logger.info("Auto-start %s", "enabled" if enabled else "disabled")

src/noteflow/client/app.py
@@ -7,12 +7,14 @@ Orchestrates UI components - does not contain component logic.

from __future__ import annotations

import argparse
import asyncio
import logging
import time
from typing import TYPE_CHECKING, Final

import flet as ft

from noteflow.application.services import TriggerService, TriggerServiceSettings
from noteflow.client.components import (
    AnnotationToolbarComponent,
    ConnectionPanelComponent,

@@ -23,7 +25,15 @@ from noteflow.client.components import (
    VuMeterComponent,
)
from noteflow.client.state import AppState
from noteflow.config.settings import TriggerSettings, get_trigger_settings
from noteflow.domain.triggers import TriggerAction, TriggerDecision
from noteflow.infrastructure.audio import SoundDeviceCapture, TimestampedAudio
from noteflow.infrastructure.triggers import (
    AudioActivityProvider,
    AudioActivitySettings,
    ForegroundAppProvider,
    ForegroundAppSettings,
)

if TYPE_CHECKING:
    import numpy as np

@@ -66,6 +76,14 @@ class NoteFlowClientApp:
        self._sync_controller: PlaybackSyncController | None = None
        self._annotation_toolbar: AnnotationToolbarComponent | None = None

        # Trigger detection (M5)
        self._trigger_settings: TriggerSettings | None = None
        self._trigger_service: TriggerService | None = None
        self._audio_activity: AudioActivityProvider | None = None
        self._foreground_app: ForegroundAppProvider | None = None
        self._trigger_poll_interval: float = 0.0
        self._trigger_task: asyncio.Task | None = None

        # Recording buttons
        self._record_btn: ft.ElevatedButton | None = None
        self._stop_btn: ft.ElevatedButton | None = None

@@ -89,6 +107,16 @@ class NoteFlowClientApp:
        page.add(self._build_ui())
        page.update()

        # Initialize trigger detection (M5)
        self._initialize_triggers()

        # Start trigger check loop if enabled (opt-in via settings)
        if self._state.trigger_enabled:
            self._trigger_task = page.run_task(self._trigger_check_loop)

        # Ensure background tasks are cancelled when the UI closes
        page.on_disconnect = lambda _e: self._shutdown()

    def _build_ui(self) -> ft.Column:
        """Build the main UI by composing components.

@@ -163,6 +191,80 @@ class NoteFlowClientApp:
            spacing=10,
        )

    def _initialize_triggers(self) -> None:
        """Initialize trigger settings, providers, and service."""
        self._trigger_settings = get_trigger_settings()
        self._state.trigger_enabled = self._trigger_settings.trigger_enabled
        self._trigger_poll_interval = self._trigger_settings.trigger_poll_interval_seconds

        audio_settings = AudioActivitySettings(
            enabled=self._trigger_settings.trigger_audio_enabled,
            threshold_db=self._trigger_settings.trigger_audio_threshold_db,
            window_seconds=self._trigger_settings.trigger_audio_window_seconds,
            min_active_ratio=self._trigger_settings.trigger_audio_min_active_ratio,
            min_samples=self._trigger_settings.trigger_audio_min_samples,
            max_history=self._trigger_settings.trigger_audio_max_history,
            weight=self._trigger_settings.trigger_weight_audio,
        )
        meeting_apps = {app.lower() for app in self._trigger_settings.trigger_meeting_apps}
        suppressed_apps = {app.lower() for app in self._trigger_settings.trigger_suppressed_apps}
        foreground_settings = ForegroundAppSettings(
            enabled=self._trigger_settings.trigger_foreground_enabled,
            weight=self._trigger_settings.trigger_weight_foreground,
            meeting_apps=meeting_apps,
            suppressed_apps=suppressed_apps,
        )

        self._audio_activity = AudioActivityProvider(
            self._state.level_provider,
            audio_settings,
        )
        self._foreground_app = ForegroundAppProvider(foreground_settings)
        self._trigger_service = TriggerService(
            providers=[self._audio_activity, self._foreground_app],
            settings=TriggerServiceSettings(
                enabled=self._trigger_settings.trigger_enabled,
                auto_start_enabled=self._trigger_settings.trigger_auto_start,
                rate_limit_seconds=self._trigger_settings.trigger_rate_limit_minutes * 60,
                snooze_seconds=self._trigger_settings.trigger_snooze_minutes * 60,
                threshold_ignore=self._trigger_settings.trigger_confidence_ignore,
                threshold_auto_start=self._trigger_settings.trigger_confidence_auto,
            ),
        )

    def _should_keep_capture_running(self) -> bool:
        """Return True if background audio capture should remain active."""
        if not self._trigger_settings:
            return False
        return (
            self._trigger_settings.trigger_enabled and self._trigger_settings.trigger_audio_enabled
        )

    def _ensure_audio_capture(self) -> bool:
        """Start audio capture if needed.

        Returns:
            True if audio capture is running, False if start failed.
        """
        if self._audio_capture:
            return True

        try:
            self._audio_capture = SoundDeviceCapture()
            self._audio_capture.start(
                device_id=None,
                on_frames=self._on_audio_frames,
                sample_rate=16000,
                channels=1,
                chunk_duration_ms=100,
            )
        except Exception:
            logger.exception("Failed to start audio capture")
            self._audio_capture = None
            return False

        return True

    def _on_connected(self, client: NoteFlowClient, info: ServerInfo) -> None:
        """Handle successful connection.

@@ -184,6 +286,7 @@ class NoteFlowClientApp:

    def _on_disconnected(self) -> None:
        """Handle disconnection."""
        self._shutdown()
        if self._state.recording:
            self._stop_recording()
        self._client = None

@@ -243,19 +346,8 @@ class NoteFlowClientApp:
            self._state.current_meeting = None
            return

-        # Start audio capture (REUSE existing SoundDeviceCapture)
-        try:
-            self._audio_capture = SoundDeviceCapture()
-            self._audio_capture.start(
-                device_id=None,
-                on_frames=self._on_audio_frames,
-                sample_rate=16000,
-                channels=1,
-                chunk_duration_ms=100,
-            )
-        except Exception:
-            logger.exception("Failed to start audio capture")
-            self._audio_capture = None
+        # Start audio capture (reuse existing capture if already running)
+        if not self._ensure_audio_capture():
            self._client.stop_streaming()
            self._client.stop_meeting(meeting.id)
            self._state.reset_recording_state()

@@ -285,7 +377,7 @@ class NoteFlowClientApp:
    def _stop_recording(self) -> None:
        """Stop recording audio."""
        # Stop audio capture first
-        if self._audio_capture:
+        if self._audio_capture and not self._should_keep_capture_running():
            self._audio_capture.stop()
            self._audio_capture = None

@@ -335,15 +427,20 @@ class NoteFlowClientApp:
        self._client.send_audio(frames, timestamp)

-        # Buffer for playback (estimate duration from chunk size)
-        duration = len(frames) / 16000.0  # Sample rate is 16kHz
-        self._state.session_audio_buffer.append(
-            TimestampedAudio(frames=frames.copy(), timestamp=timestamp, duration=duration)
-        )
+        if self._state.recording:
+            duration = len(frames) / 16000.0  # Sample rate is 16kHz
+            self._state.session_audio_buffer.append(
+                TimestampedAudio(frames=frames.copy(), timestamp=timestamp, duration=duration)
+            )

        # Update VU meter
        if self._vu_meter:
            self._vu_meter.on_audio_frames(frames)

        # Feed audio activity provider for trigger detection
        if self._audio_activity:
            self._audio_activity.update(frames, timestamp)

    def _on_segment_click(self, segment_index: int) -> None:
        """Handle transcript segment click - seek playback to segment.

@@ -371,6 +468,114 @@ class NoteFlowClientApp:
        # Sync controller handles segment matching internally
        _ = position  # Position tracked in state

    async def _trigger_check_loop(self) -> None:
        """Background loop to check trigger conditions.

        Runs at the configured poll interval while not recording.
        """
        check_interval = self._trigger_poll_interval
        try:
            while True:
                await asyncio.sleep(check_interval)

                # Skip if recording or trigger pending
                if self._state.recording or self._state.trigger_pending:
                    continue

                # Skip if triggers disabled
                if not self._state.trigger_enabled or not self._trigger_service:
                    continue

                # Start background audio capture only when needed for triggers
                if self._should_keep_capture_running():
                    self._ensure_audio_capture()

                # Evaluate triggers
                decision = self._trigger_service.evaluate()
                self._state.trigger_decision = decision

                if decision.action == TriggerAction.IGNORE:
                    continue

                if decision.action == TriggerAction.AUTO_START:
                    # Auto-start if connected
                    if self._state.connected:
                        logger.info("Auto-starting recording (confidence=%.2f)", decision.confidence)
                        self._start_recording()
                elif decision.action == TriggerAction.NOTIFY:
                    # Show prompt to user
                    self._show_trigger_prompt(decision)
        except asyncio.CancelledError:
            logger.debug("Trigger loop cancelled")
            raise

    def _show_trigger_prompt(self, decision: TriggerDecision) -> None:
        """Show trigger notification prompt to user.

        Args:
            decision: Trigger decision with confidence and signals.
        """
        self._state.trigger_pending = True

        # Build signal description
        signal_desc = ", ".join(s.app_name or s.source.value for s in decision.signals)

        def handle_start(_: ft.ControlEvent) -> None:
            self._state.trigger_pending = False
            if dialog.open:
                dialog.open = False
                self._state.request_update()
            if self._state.connected:
                self._start_recording()

        def handle_snooze(_: ft.ControlEvent) -> None:
            self._state.trigger_pending = False
            if self._trigger_service:
                self._trigger_service.snooze()
            if dialog.open:
                dialog.open = False
                self._state.request_update()

        def handle_dismiss(_: ft.ControlEvent) -> None:
            self._state.trigger_pending = False
            if dialog.open:
                dialog.open = False
                self._state.request_update()

        dialog = ft.AlertDialog(
            title=ft.Text("Meeting Detected"),
            content=ft.Text(
                "Detected: "
                f"{signal_desc}\n"
                f"Confidence: {decision.confidence:.0%}\n\n"
                "Start recording?"
            ),
            actions=[
                ft.TextButton("Start", on_click=handle_start),
                ft.TextButton("Snooze", on_click=handle_snooze),
                ft.TextButton("Dismiss", on_click=handle_dismiss),
            ],
            actions_alignment=ft.MainAxisAlignment.END,
        )

        if self._state._page:
            self._state._page.dialog = dialog
            dialog.open = True
            self._state.request_update()

    def _shutdown(self) -> None:
        """Stop background tasks and any capture started for triggers."""
        if self._trigger_task:
            self._trigger_task.cancel()
            self._trigger_task = None

        if self._audio_capture and not self._state.recording:
            try:
                self._audio_capture.stop()
            except Exception:
                logger.debug("Error stopping audio capture during shutdown", exc_info=True)
            self._audio_capture = None

    def _update_recording_buttons(self) -> None:
        """Update recording button states."""
        if self._record_btn:

src/noteflow/client/state.py
@@ -13,6 +13,7 @@ from dataclasses import dataclass, field

import flet as ft

# REUSE existing types - do not recreate
from noteflow.domain.triggers import TriggerDecision
from noteflow.grpc.client import AnnotationInfo, MeetingInfo, ServerInfo, TranscriptSegment
from noteflow.infrastructure.audio import (
    RmsLevelProvider,

@@ -68,6 +69,11 @@ class AppState:
    meetings: list[MeetingInfo] = field(default_factory=list)
    selected_meeting: MeetingInfo | None = None

    # Trigger state (REUSE existing TriggerDecision)
    trigger_enabled: bool = True
    trigger_pending: bool = False  # True when prompt is shown
    trigger_decision: TriggerDecision | None = None  # Last trigger decision

    # UI page reference (private)
    _page: ft.Page | None = field(default=None, repr=False)

src/noteflow/config/__init__.py
@@ -1,5 +1,5 @@

"""NoteFlow configuration module."""

-from .settings import Settings, get_settings
+from .settings import Settings, TriggerSettings, get_settings, get_trigger_settings

-__all__ = ["Settings", "get_settings"]
+__all__ = ["Settings", "TriggerSettings", "get_settings", "get_trigger_settings"]

src/noteflow/config/settings.py
@@ -2,11 +2,12 @@

from __future__ import annotations

+import json
from functools import lru_cache
from pathlib import Path
from typing import Annotated, cast

-from pydantic import Field, PostgresDsn
+from pydantic import Field, PostgresDsn, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -15,12 +16,146 @@ def _default_meetings_dir() -> Path:
    return Path.home() / ".noteflow" / "meetings"

-class Settings(BaseSettings):
+class TriggerSettings(BaseSettings):
    """Client trigger settings loaded from environment variables."""

    model_config = SettingsConfigDict(
        env_prefix="NOTEFLOW_",
        env_file=".env",
        env_file_encoding="utf-8",
        enable_decoding=False,
        extra="ignore",
    )

    # Trigger settings (client-side)
    trigger_enabled: Annotated[
        bool,
        Field(default=False, description="Enable smart recording triggers (opt-in)"),
    ]
    trigger_auto_start: Annotated[
        bool,
        Field(default=False, description="Auto-start recording on high confidence"),
    ]
    trigger_rate_limit_minutes: Annotated[
        int,
        Field(default=10, ge=1, le=60, description="Minimum minutes between trigger prompts"),
    ]
    trigger_snooze_minutes: Annotated[
        int,
        Field(default=30, ge=5, le=480, description="Default snooze duration in minutes"),
    ]
    trigger_poll_interval_seconds: Annotated[
        float,
        Field(default=2.0, ge=0.5, le=30.0, description="Trigger polling interval in seconds"),
    ]
    trigger_confidence_ignore: Annotated[
        float,
        Field(default=0.40, ge=0.0, le=1.0, description="Confidence below which to ignore"),
    ]
    trigger_confidence_auto: Annotated[
        float,
        Field(default=0.80, ge=0.0, le=1.0, description="Confidence to auto-start recording"),
    ]

    # Audio trigger tuning
    trigger_audio_enabled: Annotated[
        bool,
        Field(default=True, description="Enable audio activity detection"),
    ]
    trigger_audio_threshold_db: Annotated[
        float,
        Field(default=-40.0, ge=-60.0, le=0.0, description="Audio activity threshold in dB"),
    ]
    trigger_audio_window_seconds: Annotated[
        float,
        Field(default=5.0, ge=1.0, le=30.0, description="Audio activity window in seconds"),
    ]
    trigger_audio_min_active_ratio: Annotated[
        float,
        Field(default=0.6, ge=0.0, le=1.0, description="Minimum active ratio in window"),
    ]
    trigger_audio_min_samples: Annotated[
        int,
        Field(default=10, ge=1, le=200, description="Minimum samples before evaluating audio"),
    ]
    trigger_audio_max_history: Annotated[
        int,
        Field(default=50, ge=10, le=1000, description="Max audio activity samples to retain"),
    ]

    # Foreground app trigger tuning
    trigger_foreground_enabled: Annotated[
        bool,
        Field(default=True, description="Enable foreground app detection"),
    ]
    trigger_meeting_apps: Annotated[
        list[str],
        Field(
            default_factory=lambda: [
                "zoom",
                "teams",
                "microsoft teams",
                "meet",
                "google meet",
                "slack",
                "webex",
                "discord",
                "skype",
                "gotomeeting",
                "facetime",
                "webinar",
                "ringcentral",
            ],
            description="Meeting app name substrings to detect",
        ),
    ]
    trigger_suppressed_apps: Annotated[
        list[str],
        Field(default_factory=list, description="Meeting app substrings to ignore"),
    ]

    # Signal weights
    trigger_weight_audio: Annotated[
        float,
        Field(default=0.30, ge=0.0, le=1.0, description="Audio signal confidence weight"),
    ]
    trigger_weight_foreground: Annotated[
        float,
        Field(
            default=0.40,
            ge=0.0,
            le=1.0,
            description="Foreground app signal confidence weight",
        ),
    ]
    trigger_weight_calendar: Annotated[
        float,
        Field(default=0.30, ge=0.0, le=1.0, description="Calendar signal confidence weight"),
    ]

    @field_validator("trigger_meeting_apps", "trigger_suppressed_apps", mode="before")
    @classmethod
    def _parse_csv_list(cls, value: object) -> list[str]:
        if not isinstance(value, str):
            return [] if value is None else list(value)
        stripped = value.strip()
        if stripped.startswith("[") and stripped.endswith("]"):
            try:
                parsed = json.loads(stripped)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, list):
                return [str(item).strip() for item in parsed if str(item).strip()]
        return [item.strip() for item in value.split(",") if item.strip()]


+class Settings(TriggerSettings):
    """Application settings loaded from environment variables.

    Environment variables:
        NOTEFLOW_DATABASE_URL: PostgreSQL connection URL
-            Example: postgresql+asyncpg://user:pass@host:5432/dbname?options=-csearch_path%3Dnoteflow
+            Example: postgresql+asyncpg://user:pass@host:5432/dbname?
+                options=-csearch_path%3Dnoteflow
        NOTEFLOW_DB_POOL_SIZE: Connection pool size (default: 5)
        NOTEFLOW_DB_ECHO: Echo SQL statements (default: False)
        NOTEFLOW_ASR_MODEL_SIZE: Whisper model size (default: base)

@@ -29,13 +164,6 @@ class Settings(BaseSettings):
        NOTEFLOW_MEETINGS_DIR: Directory for meeting audio storage (default: ~/.noteflow/meetings)
    """

-    model_config = SettingsConfigDict(
-        env_prefix="NOTEFLOW_",
-        env_file=".env",
-        env_file_encoding="utf-8",
-        extra="ignore",
-    )

    # Database settings
    database_url: Annotated[
        PostgresDsn,

@@ -101,6 +229,11 @@ def _load_settings() -> Settings:
    return cast("Settings", Settings.model_validate({}))


+def _load_trigger_settings() -> TriggerSettings:
+    """Load trigger settings from environment."""
+    return cast("TriggerSettings", TriggerSettings.model_validate({}))


@lru_cache
def get_settings() -> Settings:
    """Get cached settings instance.

@@ -112,3 +245,9 @@ def get_settings() -> Settings:
        ValidationError: If required environment variables are not set.
    """
    return _load_settings()


+@lru_cache
+def get_trigger_settings() -> TriggerSettings:
+    """Get cached trigger settings instance."""
+    return _load_trigger_settings()
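
Example environment values accepted by the `_parse_csv_list` validator above — both comma-separated and JSON-array forms parse (the app names are illustrative):

```bash
export NOTEFLOW_TRIGGER_MEETING_APPS="zoom,teams,webex"
export NOTEFLOW_TRIGGER_SUPPRESSED_APPS='["slack", "discord"]'
```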

src/noteflow/domain/summarization/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""Summarization domain module.

Provides protocols and data transfer objects for meeting summarization.
"""

from noteflow.domain.summarization.ports import (
    CitationVerificationResult,
    CitationVerifier,
    InvalidResponseError,
    ProviderUnavailableError,
    SummarizationError,
    SummarizationRequest,
    SummarizationResult,
    SummarizationTimeoutError,
    SummarizerProvider,
)

__all__ = [
    "CitationVerificationResult",
    "CitationVerifier",
    "InvalidResponseError",
    "ProviderUnavailableError",
    "SummarizationError",
    "SummarizationRequest",
    "SummarizationResult",
    "SummarizationTimeoutError",
    "SummarizerProvider",
]

src/noteflow/domain/summarization/ports.py (new file, 165 lines)
@@ -0,0 +1,165 @@
|
||||
"""Summarization provider port protocols."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from noteflow.domain.entities import Segment, Summary
|
||||
from noteflow.domain.value_objects import MeetingId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SummarizationRequest:
|
||||
"""Request for meeting summarization.
|
||||
|
||||
Contains the meeting context needed for summary generation.
|
||||
"""
|
||||
|
||||
meeting_id: MeetingId
|
||||
segments: Sequence[Segment]
|
||||
max_key_points: int = 5
|
||||
max_action_items: int = 10
|
||||
|
||||
@property
|
||||
def transcript_text(self) -> str:
|
||||
"""Concatenate all segment text into a single transcript."""
|
||||
return " ".join(seg.text for seg in self.segments)
|
||||
|
||||
@property
|
||||
def segment_count(self) -> int:
|
||||
"""Number of segments in the request."""
|
||||
return len(self.segments)
|
||||
|
||||
@property
|
||||
def total_duration(self) -> float:
|
||||
"""Total duration of all segments in seconds."""
|
||||
if not self.segments:
|
||||
return 0.0
|
||||
return self.segments[-1].end_time - self.segments[0].start_time
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SummarizationResult:
|
||||
"""Result from summarization provider.
|
||||
|
||||
Contains the generated summary along with metadata.
|
||||
"""
|
||||
|
||||
summary: Summary
|
||||
model_name: str
|
||||
provider_name: str
|
||||
tokens_used: int | None = None
|
||||
latency_ms: float = 0.0
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Check if summarization succeeded with content."""
|
||||
return bool(self.summary.executive_summary)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CitationVerificationResult:
|
||||
"""Result of citation verification.
|
||||
|
||||
Identifies which citations are valid and which are invalid.
|
||||
"""
|
||||
|
||||
is_valid: bool
|
||||
invalid_key_point_indices: tuple[int, ...] = field(default_factory=tuple)
|
||||
invalid_action_item_indices: tuple[int, ...] = field(default_factory=tuple)
|
||||
missing_segment_ids: tuple[int, ...] = field(default_factory=tuple)
|
||||
|
||||
@property
|
||||
def invalid_count(self) -> int:
|
||||
"""Total number of invalid citations."""
|
||||
return len(self.invalid_key_point_indices) + len(self.invalid_action_item_indices)
|
||||
|
||||
|
||||
class SummarizerProvider(Protocol):
|
||||
"""Protocol for LLM summarization providers.
|
||||
|
||||
Implementations must provide async summarization with evidence linking.
|
||||
"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Provider identifier (e.g., 'mock', 'ollama', 'openai')."""
|
||||
...
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if provider is configured and available."""
|
||||
...
|
||||
|
||||
@property
|
||||
def requires_cloud_consent(self) -> bool:
|
||||
"""Return True if data is sent to external services.
|
||||
|
||||
Cloud providers must return True to ensure explicit user consent.
|
||||
"""
|
||||
...
|
||||
|
||||
async def summarize(self, request: SummarizationRequest) -> SummarizationResult:
|
||||
"""Generate evidence-linked summary from transcript segments.
|
||||
|
||||
Args:
|
||||
request: Summarization request with segments and constraints.
|
||||
|
||||
Returns:
|
||||
SummarizationResult with generated summary and metadata.
|
||||
|
||||
Raises:
|
||||
SummarizationError: If summarization fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class CitationVerifier(Protocol):
|
||||
"""Protocol for verifying evidence citations.
|
||||
|
||||
Validates that segment_ids in summaries reference actual segments.
|
||||
"""
|
||||
|
||||
def verify_citations(
|
||||
self,
|
||||
summary: Summary,
|
||||
segments: Sequence[Segment],
|
||||
) -> CitationVerificationResult:
|
||||
"""Verify all segment_ids exist in the transcript.
|
||||
|
||||
Args:
|
||||
summary: Summary with key points and action items to verify.
|
||||
segments: Available transcript segments.
|
||||
|
||||
Returns:
|
||||
CitationVerificationResult with validation status and details.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SummarizationError(Exception):
|
||||
"""Base exception for summarization errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProviderUnavailableError(SummarizationError):
|
||||
"""Provider is not available or not configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SummarizationTimeoutError(SummarizationError):
|
||||
"""Summarization operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidResponseError(SummarizationError):
|
||||
"""Provider returned an invalid or unparseable response."""
|
||||
|
||||
pass
|
||||
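Editor's note: a minimal wiring sketch (not part of this commit) showing how these ports compose. `run_summary` is a hypothetical helper; the consent flag is app-specific.

```python
# Hypothetical helper: any object that structurally matches SummarizerProvider
# can be passed here; Protocol typing requires no inheritance.
from noteflow.domain.summarization import (
    SummarizationRequest,
    SummarizationResult,
    SummarizerProvider,
)


async def run_summary(
    provider: SummarizerProvider,
    request: SummarizationRequest,
    cloud_consent: bool = False,
) -> SummarizationResult:
    if provider.requires_cloud_consent and not cloud_consent:
        raise PermissionError("explicit cloud consent required")
    if not provider.is_available:
        raise RuntimeError(f"provider {provider.provider_name!r} unavailable")
    return await provider.summarize(request)
```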
17
src/noteflow/domain/triggers/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""Trigger domain package."""
|
||||
|
||||
from noteflow.domain.triggers.entities import (
|
||||
TriggerAction,
|
||||
TriggerDecision,
|
||||
TriggerSignal,
|
||||
TriggerSource,
|
||||
)
|
||||
from noteflow.domain.triggers.ports import SignalProvider
|
||||
|
||||
__all__ = [
|
||||
"SignalProvider",
|
||||
"TriggerAction",
|
||||
"TriggerDecision",
|
||||
"TriggerSignal",
|
||||
"TriggerSource",
|
||||
]
|
||||
84
src/noteflow/domain/triggers/entities.py
Normal file
@@ -0,0 +1,84 @@
"""Trigger domain entities and value objects.
|
||||
|
||||
Define trigger signals, decisions, and actions for meeting detection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TriggerSource(Enum):
|
||||
"""Source of a trigger signal."""
|
||||
|
||||
AUDIO_ACTIVITY = "audio_activity"
|
||||
FOREGROUND_APP = "foreground_app"
|
||||
CALENDAR = "calendar" # Deferred - optional connector
|
||||
|
||||
|
||||
class TriggerAction(Enum):
|
||||
"""Action determined by trigger evaluation."""
|
||||
|
||||
IGNORE = "ignore" # Confidence < 0.40
|
||||
NOTIFY = "notify" # Confidence 0.40-0.79
|
||||
AUTO_START = "auto_start" # Confidence >= 0.80 (if enabled)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TriggerSignal:
|
||||
"""A signal from a single trigger source.
|
||||
|
||||
Attributes:
|
||||
source: The source that generated this signal.
|
||||
weight: Confidence contribution (0.0-1.0).
|
||||
app_name: For foreground app signals, the detected app name.
|
||||
timestamp: When the signal was generated (monotonic time).
|
||||
"""
|
||||
|
||||
source: TriggerSource
|
||||
weight: float
|
||||
app_name: str | None = None
|
||||
timestamp: float = field(default_factory=time.monotonic)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate weight is in valid range."""
|
||||
if not 0.0 <= self.weight <= 1.0:
|
||||
msg = f"Weight must be 0.0-1.0, got {self.weight}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TriggerDecision:
|
||||
"""Result of trigger evaluation.
|
||||
|
||||
Attributes:
|
||||
action: The determined action (ignore, notify, auto_start).
|
||||
confidence: Total confidence score from all signals.
|
||||
signals: The signals that contributed to this decision.
|
||||
timestamp: When the decision was made (monotonic time).
|
||||
"""
|
||||
|
||||
action: TriggerAction
|
||||
confidence: float
|
||||
signals: tuple[TriggerSignal, ...]
|
||||
timestamp: float = field(default_factory=time.monotonic)
|
||||
|
||||
@property
|
||||
def primary_signal(self) -> TriggerSignal | None:
|
||||
"""Get the signal with highest weight contribution."""
|
||||
return max(self.signals, key=lambda s: s.weight) if self.signals else None
|
||||
|
||||
@property
|
||||
def detected_app(self) -> str | None:
|
||||
"""Get the detected app name from foreground signal if present."""
|
||||
return next(
|
||||
(
|
||||
signal.app_name
|
||||
for signal in self.signals
|
||||
if signal.source == TriggerSource.FOREGROUND_APP
|
||||
and signal.app_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
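Editor's note: a sketch of how these values might combine (illustration only; the actual aggregation lives in `TriggerService`, whose logic is not shown in this hunk).

```python
from noteflow.domain.triggers import TriggerAction, TriggerDecision, TriggerSignal


def evaluate(signals: list[TriggerSignal], auto_start_enabled: bool = False) -> TriggerDecision:
    # Sum weights, capped at 1.0, then apply the thresholds documented on
    # TriggerAction: < 0.40 ignore, 0.40-0.79 notify, >= 0.80 auto-start (if enabled).
    confidence = min(1.0, sum(s.weight for s in signals))
    if confidence >= 0.80 and auto_start_enabled:
        action = TriggerAction.AUTO_START
    elif confidence >= 0.40:
        action = TriggerAction.NOTIFY
    else:
        action = TriggerAction.IGNORE
    return TriggerDecision(action=action, confidence=confidence, signals=tuple(signals))
```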
51
src/noteflow/domain/triggers/ports.py
Normal file
@@ -0,0 +1,51 @@
"""Trigger signal provider port protocol.
|
||||
|
||||
Define the interface for signal providers that detect meeting conditions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.domain.triggers.entities import TriggerSignal, TriggerSource
|
||||
|
||||
|
||||
class SignalProvider(Protocol):
|
||||
"""Protocol for trigger signal providers.
|
||||
|
||||
Signal providers detect specific conditions (audio activity, foreground app, etc.)
|
||||
and return weighted signals used in trigger evaluation.
|
||||
|
||||
Each provider:
|
||||
- Has a specific source type
|
||||
- Has a maximum weight contribution
|
||||
- Can be enabled/disabled
|
||||
- Returns a signal when conditions are met, None otherwise
|
||||
"""
|
||||
|
||||
@property
|
||||
def source(self) -> TriggerSource:
|
||||
"""Get the source type for this provider."""
|
||||
...
|
||||
|
||||
@property
|
||||
def max_weight(self) -> float:
|
||||
"""Get the maximum weight this provider can contribute."""
|
||||
...
|
||||
|
||||
def get_signal(self) -> TriggerSignal | None:
|
||||
"""Get current signal if conditions are met.
|
||||
|
||||
Returns:
|
||||
TriggerSignal if provider conditions are satisfied, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if this provider is enabled.
|
||||
|
||||
Returns:
|
||||
True if provider is enabled and can produce signals.
|
||||
"""
|
||||
...
|
||||
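Editor's note: a polling sketch against this protocol (illustration, not part of the commit).

```python
from noteflow.domain.triggers import SignalProvider, TriggerSignal


def collect_signals(providers: list[SignalProvider]) -> list[TriggerSignal]:
    # Skip disabled providers; keep only providers whose conditions are met.
    signals: list[TriggerSignal] = []
    for provider in providers:
        if provider.is_enabled() and (signal := provider.get_signal()) is not None:
            signals.append(signal)
    return signals
```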
@@ -161,5 +161,7 @@ class MeetingStore:
     def active_count(self) -> int:
         """Count of meetings in RECORDING or STOPPING state."""
         with self._lock:
-            return sum(bool(m.state in (MeetingState.RECORDING, MeetingState.STOPPING))
-                       for m in self._meetings.values())
+            return sum(
+                m.state in (MeetingState.RECORDING, MeetingState.STOPPING)
+                for m in self._meetings.values()
+            )
@@ -6,10 +6,13 @@ Provides real-time speech detection for audio streams.
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import Protocol
+from typing import TYPE_CHECKING, Protocol

-import numpy as np
-from numpy.typing import NDArray
+from noteflow.infrastructure.audio import compute_rms
+
+if TYPE_CHECKING:
+    import numpy as np
+    from numpy.typing import NDArray


 class VadEngine(Protocol):
@@ -72,7 +75,7 @@ class EnergyVad:
         Returns:
             True if speech detected, False for silence.
         """
-        energy = self._compute_rms(audio)
+        energy = compute_rms(audio)

         if self._is_speech:
             # Currently in speech - check for silence
@@ -99,11 +102,6 @@ class EnergyVad:
         self._speech_frame_count = 0
         self._silence_frame_count = 0

-    @staticmethod
-    def _compute_rms(audio: NDArray[np.float32]) -> float:
-        """Compute Root Mean Square energy of audio."""
-        return 0.0 if len(audio) == 0 else float(np.sqrt(np.mean(audio**2)))
-

 @dataclass
 class StreamingVad:
@@ -9,7 +9,7 @@ from noteflow.infrastructure.audio.dto import (
     AudioFrameCallback,
     TimestampedAudio,
 )
-from noteflow.infrastructure.audio.levels import RmsLevelProvider
+from noteflow.infrastructure.audio.levels import RmsLevelProvider, compute_rms
 from noteflow.infrastructure.audio.playback import PlaybackState, SoundDevicePlayback
 from noteflow.infrastructure.audio.protocols import (
     AudioCapture,
@@ -34,4 +34,5 @@ __all__ = [
     "SoundDevicePlayback",
     "TimestampedAudio",
     "TimestampedRingBuffer",
+    "compute_rms",
 ]
@@ -12,6 +12,21 @@ import numpy as np
 from numpy.typing import NDArray


+def compute_rms(frames: NDArray[np.float32]) -> float:
+    """Calculate Root Mean Square of audio samples.
+
+    Args:
+        frames: Audio samples as float32 array.
+
+    Returns:
+        RMS level as float (0.0 for empty array).
+    """
+    if len(frames) == 0:
+        return 0.0
+    # Use float64 for precision during squaring to avoid overflow
+    return float(np.sqrt(np.mean(frames.astype(np.float64) ** 2)))
+
+
 class RmsLevelProvider:
     """RMS-based audio level provider.
@@ -30,13 +45,8 @@ class RmsLevelProvider:
         Returns:
             RMS level normalized to 0.0-1.0 range.
         """
-        if len(frames) == 0:
-            return 0.0
-
-        # Calculate RMS: sqrt(mean(samples^2))
-        rms = float(np.sqrt(np.mean(frames.astype(np.float64) ** 2)))
-
-        # Clamp to 0.0-1.0 range
+        rms = compute_rms(frames)
+        # Clamp to 0.0-1.0 range for VU meter display
         return min(1.0, max(0.0, rms))

     def get_db(self, frames: NDArray[np.float32]) -> float:
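Editor's note: for reference, a sketch of the dB conversion implied by `get_db` (illustration only; the real implementation follows this hunk and may differ, e.g. in its silence floor).

```python
import numpy as np
from numpy.typing import NDArray

from noteflow.infrastructure.audio import compute_rms


def rms_to_dbfs(frames: NDArray[np.float32]) -> float:
    # dBFS relative to full scale 1.0; the epsilon floor avoiding log10(0)
    # is an assumption of this sketch.
    rms = compute_rms(frames)
    return float(20.0 * np.log10(max(rms, 1e-10)))
```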
@@ -125,37 +125,38 @@ class HtmlExporter:
         content_parts.append(
             f"<dt>Duration:</dt><dd>{format_timestamp(meeting.duration_seconds)}</dd>"
         )
-        content_parts.extend((f"<dt>Segments:</dt><dd>{len(segments)}</dd>", "</dl>"))
-        content_parts.extend(("</div>", "<h2>Transcript</h2>"))
-        content_parts.append('<div class="transcript">')
-
+        content_parts.extend(
+            (
+                f"<dt>Segments:</dt><dd>{len(segments)}</dd>",
+                "</dl>",
+                "</div>",
+                "<h2>Transcript</h2>",
+                '<div class="transcript">',
+            )
+        )
         for segment in segments:
             timestamp = format_timestamp(segment.start_time)
             content_parts.append('<div class="segment">')
             content_parts.append(f'<span class="timestamp">[{timestamp}]</span>')
-            content_parts.append(f"<span>{_escape(segment.text)}</span>")
-            content_parts.append("</div>")
-
+            content_parts.extend((f"<span>{_escape(segment.text)}</span>", "</div>"))
         content_parts.append("</div>")

         # Summary section (if available)
         if meeting.summary:
-            content_parts.append('<div class="summary">')
-            content_parts.append("<h2>Summary</h2>")
-
+            content_parts.extend(('<div class="summary">', "<h2>Summary</h2>"))
             if meeting.summary.executive_summary:
                 content_parts.append(f"<p>{_escape(meeting.summary.executive_summary)}</p>")

             if meeting.summary.key_points:
-                content_parts.append("<h3>Key Points</h3>")
-                content_parts.append('<ul class="key-points">')
-                for point in meeting.summary.key_points:
-                    content_parts.append(f"<li>{_escape(point.text)}</li>")
+                content_parts.extend(("<h3>Key Points</h3>", '<ul class="key-points">'))
+                content_parts.extend(
+                    f"<li>{_escape(point.text)}</li>"
+                    for point in meeting.summary.key_points
+                )
                 content_parts.append("</ul>")

             if meeting.summary.action_items:
-                content_parts.append("<h3>Action Items</h3>")
-                content_parts.append('<ul class="action-items">')
+                content_parts.extend(("<h3>Action Items</h3>", '<ul class="action-items">'))
                 for item in meeting.summary.action_items:
                     assignee = (
                         f' <span class="assignee">@{_escape(item.assignee)}</span>'
@@ -169,10 +170,11 @@ class HtmlExporter:
         # Footer
         content_parts.append("<footer>")
-        content_parts.append(
-            f"Exported from NoteFlow on {_escape(format_datetime(datetime.now()))}"
+        content_parts.extend(
+            (
+                f"Exported from NoteFlow on {_escape(format_datetime(datetime.now()))}",
+                "</footer>",
+            )
         )
-        content_parts.append("</footer>")

         content = "\n".join(content_parts)
         return _HTML_TEMPLATE.format(title=_escape(meeting.title), content=content)
@@ -61,37 +61,22 @@ class MarkdownExporter:
         if meeting.ended_at:
             lines.append(f"- **Ended:** {format_datetime(meeting.ended_at)}")
         lines.append(f"- **Duration:** {format_timestamp(meeting.duration_seconds)}")
-        lines.append(f"- **Segments:** {len(segments)}")
-        lines.append("")
-
-        # Transcript section
-        lines.append("## Transcript")
-        lines.append("")
-
+        lines.extend((f"- **Segments:** {len(segments)}", "", "## Transcript", ""))
         for segment in segments:
             timestamp = format_timestamp(segment.start_time)
-            lines.append(f"**[{timestamp}]** {segment.text}")
-            lines.append("")
-
+            lines.extend((f"**[{timestamp}]** {segment.text}", ""))
         # Summary section (if available)
         if meeting.summary:
-            lines.append("## Summary")
-            lines.append("")
-
+            lines.extend(("## Summary", ""))
             if meeting.summary.executive_summary:
-                lines.append(meeting.summary.executive_summary)
-                lines.append("")
-
+                lines.extend((meeting.summary.executive_summary, ""))
             if meeting.summary.key_points:
-                lines.append("### Key Points")
-                lines.append("")
-                for point in meeting.summary.key_points:
-                    lines.append(f"- {point.text}")
+                lines.extend(("### Key Points", ""))
+                lines.extend(f"- {point.text}" for point in meeting.summary.key_points)
                 lines.append("")

             if meeting.summary.action_items:
-                lines.append("### Action Items")
-                lines.append("")
+                lines.extend(("### Action Items", ""))
                 for item in meeting.summary.action_items:
                     assignee = f" (@{item.assignee})" if item.assignee else ""
                     lines.append(f"- [ ] {item.text}{assignee}")
@@ -1,6 +1,5 @@
 """Alembic migration environment configuration."""

-
 from __future__ import annotations

 import asyncio
22
src/noteflow/infrastructure/summarization/__init__.py
Normal file
@@ -0,0 +1,22 @@
"""Summarization infrastructure module.
|
||||
|
||||
Provides summarization provider implementations and citation verification.
|
||||
"""
|
||||
|
||||
from noteflow.infrastructure.summarization.citation_verifier import (
|
||||
SegmentCitationVerifier,
|
||||
)
|
||||
from noteflow.infrastructure.summarization.cloud_provider import (
|
||||
CloudBackend,
|
||||
CloudSummarizer,
|
||||
)
|
||||
from noteflow.infrastructure.summarization.mock_provider import MockSummarizer
|
||||
from noteflow.infrastructure.summarization.ollama_provider import OllamaSummarizer
|
||||
|
||||
__all__ = [
|
||||
"CloudBackend",
|
||||
"CloudSummarizer",
|
||||
"MockSummarizer",
|
||||
"OllamaSummarizer",
|
||||
"SegmentCitationVerifier",
|
||||
]
|
||||
134
src/noteflow/infrastructure/summarization/_parsing.py
Normal file
@@ -0,0 +1,134 @@
"""Shared parsing utilities for summarization providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.entities import ActionItem, KeyPoint, Summary
|
||||
from noteflow.domain.summarization import InvalidResponseError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.domain.summarization import SummarizationRequest
|
||||
|
||||
|
||||
# System prompt for structured summarization
|
||||
SYSTEM_PROMPT = """You are a meeting summarization assistant. Analyze the transcript and produce structured output.
|
||||
|
||||
OUTPUT FORMAT (JSON):
|
||||
{
|
||||
"executive_summary": "2-3 sentence high-level overview",
|
||||
"key_points": [
|
||||
{"text": "Key insight or decision", "segment_ids": [0, 1]}
|
||||
],
|
||||
"action_items": [
|
||||
{"text": "Action to take", "assignee": "Person name or empty string", "priority": 0, "segment_ids": [2]}
|
||||
]
|
||||
}
|
||||
|
||||
RULES:
|
||||
1. Each key_point and action_item MUST have at least one segment_id referencing the source
|
||||
2. segment_ids are integers matching the [N] markers in the transcript
|
||||
3. priority: 0=unspecified, 1=low, 2=medium, 3=high
|
||||
4. Only extract action items that clearly indicate tasks to be done
|
||||
5. Output ONLY valid JSON, no markdown or explanation"""
|
||||
|
||||
|
||||
def build_transcript_prompt(request: SummarizationRequest) -> str:
|
||||
"""Build transcript prompt with segment markers.
|
||||
|
||||
Args:
|
||||
request: Summarization request with segments.
|
||||
|
||||
Returns:
|
||||
Formatted prompt string with transcript and constraints.
|
||||
"""
|
||||
lines = [f"[{seg.segment_id}] {seg.text}" for seg in request.segments]
|
||||
constraints = ""
|
||||
if request.segments:
|
||||
valid_ids = ", ".join(str(seg.segment_id) for seg in request.segments)
|
||||
constraints = (
|
||||
"\n\nCONSTRAINTS:\n"
|
||||
f"- Maximum {request.max_key_points} key points\n"
|
||||
f"- Maximum {request.max_action_items} action items\n"
|
||||
f"- Valid segment_ids: {valid_ids}"
|
||||
)
|
||||
|
||||
return f"TRANSCRIPT:\n{chr(10).join(lines)}{constraints}"
|
||||
|
||||
|
||||
def parse_llm_response(response_text: str, request: SummarizationRequest) -> Summary:
|
||||
"""Parse JSON response into Summary entity.
|
||||
|
||||
Args:
|
||||
response_text: Raw JSON response from LLM.
|
||||
request: Original request for validation context.
|
||||
|
||||
Returns:
|
||||
Summary entity with parsed data.
|
||||
|
||||
Raises:
|
||||
InvalidResponseError: If JSON is malformed.
|
||||
"""
|
||||
# Strip markdown code fences if present
|
||||
text = response_text.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
if lines[0].startswith("```"):
|
||||
lines = lines[1:]
|
||||
if lines and lines[-1].strip() == "```":
|
||||
lines = lines[:-1]
|
||||
text = "\n".join(lines)
|
||||
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise InvalidResponseError(f"Invalid JSON response: {e}") from e
|
||||
|
||||
valid_ids = {seg.segment_id for seg in request.segments}
|
||||
|
||||
# Parse key points
|
||||
key_points: list[KeyPoint] = []
|
||||
for kp_data in data.get("key_points", [])[: request.max_key_points]:
|
||||
seg_ids = [sid for sid in kp_data.get("segment_ids", []) if sid in valid_ids]
|
||||
start_time = 0.0
|
||||
end_time = 0.0
|
||||
if seg_ids:
|
||||
if refs := [
|
||||
s for s in request.segments if s.segment_id in seg_ids
|
||||
]:
|
||||
start_time = min(s.start_time for s in refs)
|
||||
end_time = max(s.end_time for s in refs)
|
||||
key_points.append(
|
||||
KeyPoint(
|
||||
text=str(kp_data.get("text", "")),
|
||||
segment_ids=seg_ids,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
|
||||
# Parse action items
|
||||
action_items: list[ActionItem] = []
|
||||
for ai_data in data.get("action_items", [])[: request.max_action_items]:
|
||||
seg_ids = [sid for sid in ai_data.get("segment_ids", []) if sid in valid_ids]
|
||||
priority = ai_data.get("priority", 0)
|
||||
if not isinstance(priority, int) or priority not in range(4):
|
||||
priority = 0
|
||||
action_items.append(
|
||||
ActionItem(
|
||||
text=str(ai_data.get("text", "")),
|
||||
assignee=str(ai_data.get("assignee", "")),
|
||||
priority=priority,
|
||||
segment_ids=seg_ids,
|
||||
)
|
||||
)
|
||||
|
||||
return Summary(
|
||||
meeting_id=request.meeting_id,
|
||||
executive_summary=str(data.get("executive_summary", "")),
|
||||
key_points=key_points,
|
||||
action_items=action_items,
|
||||
generated_at=datetime.now(UTC),
|
||||
)
|
||||
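Editor's note: a round-trip sketch of these helpers (illustration only; `Segment` keyword arguments follow the test helpers in this commit, and the `MeetingId` constructor is assumed).

```python
from uuid import uuid4

from noteflow.domain.entities import Segment
from noteflow.domain.summarization import SummarizationRequest
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.summarization._parsing import (
    build_transcript_prompt,
    parse_llm_response,
)

segments = [Segment(segment_id=0, text="We will ship Friday.", start_time=0.0, end_time=5.0)]
request = SummarizationRequest(meeting_id=MeetingId(uuid4()), segments=segments)  # ctor assumed

print(build_transcript_prompt(request))  # "[0] We will ship Friday." plus CONSTRAINTS

fake_response = (
    '{"executive_summary": "Ship Friday.",'
    ' "key_points": [{"text": "Ship on Friday", "segment_ids": [0]}],'
    ' "action_items": []}'
)
summary = parse_llm_response(fake_response, request)
assert summary.key_points[0].segment_ids == [0]
```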
124
src/noteflow/infrastructure/summarization/citation_verifier.py
Normal file
@@ -0,0 +1,124 @@
"""Citation verification implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.summarization import CitationVerificationResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from noteflow.domain.entities import Segment, Summary
|
||||
|
||||
|
||||
class SegmentCitationVerifier:
|
||||
"""Verify that summary citations reference valid segments.
|
||||
|
||||
Checks that all segment_ids in key points and action items
|
||||
correspond to actual segments in the transcript.
|
||||
"""
|
||||
|
||||
def verify_citations(
|
||||
self,
|
||||
summary: Summary,
|
||||
segments: Sequence[Segment],
|
||||
) -> CitationVerificationResult:
|
||||
"""Verify all segment_ids exist in the transcript.
|
||||
|
||||
Args:
|
||||
summary: Summary with key points and action items to verify.
|
||||
segments: Available transcript segments.
|
||||
|
||||
Returns:
|
||||
CitationVerificationResult with validation status and details.
|
||||
"""
|
||||
# Build set of valid segment IDs
|
||||
valid_segment_ids = {seg.segment_id for seg in segments}
|
||||
|
||||
# Track invalid citations
|
||||
invalid_key_point_indices: list[int] = []
|
||||
invalid_action_item_indices: list[int] = []
|
||||
missing_segment_ids: set[int] = set()
|
||||
|
||||
# Verify key points
|
||||
for idx, key_point in enumerate(summary.key_points):
|
||||
for seg_id in key_point.segment_ids:
|
||||
if seg_id not in valid_segment_ids:
|
||||
if idx not in invalid_key_point_indices:
|
||||
invalid_key_point_indices.append(idx)
|
||||
missing_segment_ids.add(seg_id)
|
||||
|
||||
# Verify action items
|
||||
for idx, action_item in enumerate(summary.action_items):
|
||||
for seg_id in action_item.segment_ids:
|
||||
if seg_id not in valid_segment_ids:
|
||||
if idx not in invalid_action_item_indices:
|
||||
invalid_action_item_indices.append(idx)
|
||||
missing_segment_ids.add(seg_id)
|
||||
|
||||
is_valid = not invalid_key_point_indices and not invalid_action_item_indices
|
||||
|
||||
return CitationVerificationResult(
|
||||
is_valid=is_valid,
|
||||
invalid_key_point_indices=tuple(invalid_key_point_indices),
|
||||
invalid_action_item_indices=tuple(invalid_action_item_indices),
|
||||
missing_segment_ids=tuple(sorted(missing_segment_ids)),
|
||||
)
|
||||
|
||||
def filter_invalid_citations(
|
||||
self,
|
||||
summary: Summary,
|
||||
segments: Sequence[Segment],
|
||||
) -> Summary:
|
||||
"""Return a copy of the summary with invalid citations removed.
|
||||
|
||||
Invalid segment_ids are removed from key points and action items.
|
||||
Items with no remaining citations keep empty segment_ids lists.
|
||||
|
||||
Args:
|
||||
summary: Summary to filter.
|
||||
segments: Available transcript segments.
|
||||
|
||||
Returns:
|
||||
New Summary with invalid citations removed.
|
||||
"""
|
||||
valid_segment_ids = {seg.segment_id for seg in segments}
|
||||
|
||||
# Filter key point citations
|
||||
from noteflow.domain.entities import ActionItem, KeyPoint
|
||||
from noteflow.domain.entities import Summary as SummaryEntity
|
||||
|
||||
filtered_key_points = [
|
||||
KeyPoint(
|
||||
text=kp.text,
|
||||
segment_ids=[sid for sid in kp.segment_ids if sid in valid_segment_ids],
|
||||
start_time=kp.start_time,
|
||||
end_time=kp.end_time,
|
||||
db_id=kp.db_id,
|
||||
)
|
||||
for kp in summary.key_points
|
||||
]
|
||||
|
||||
# Filter action item citations
|
||||
filtered_action_items = [
|
||||
ActionItem(
|
||||
text=ai.text,
|
||||
assignee=ai.assignee,
|
||||
due_date=ai.due_date,
|
||||
priority=ai.priority,
|
||||
segment_ids=[sid for sid in ai.segment_ids if sid in valid_segment_ids],
|
||||
db_id=ai.db_id,
|
||||
)
|
||||
for ai in summary.action_items
|
||||
]
|
||||
|
||||
return SummaryEntity(
|
||||
meeting_id=summary.meeting_id,
|
||||
executive_summary=summary.executive_summary,
|
||||
key_points=filtered_key_points,
|
||||
action_items=filtered_action_items,
|
||||
generated_at=summary.generated_at,
|
||||
model_version=summary.model_version,
|
||||
db_id=summary.db_id,
|
||||
)
|
||||
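Editor's note: typical use, sketched against the entities in this commit (illustration only; `MeetingId` construction is assumed).

```python
from datetime import UTC, datetime
from uuid import uuid4

from noteflow.domain.entities import KeyPoint, Segment, Summary
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.summarization import SegmentCitationVerifier

segments = [Segment(segment_id=0, text="Hello", start_time=0.0, end_time=5.0)]
summary = Summary(
    meeting_id=MeetingId(uuid4()),  # constructor assumed
    executive_summary="Hi.",
    key_points=[KeyPoint(text="Bad cite", segment_ids=[99])],  # 99 does not exist
    generated_at=datetime.now(UTC),
)

verifier = SegmentCitationVerifier()
result = verifier.verify_citations(summary, segments)
assert not result.is_valid and result.missing_segment_ids == (99,)

# Drop citations pointing at nonexistent segments but keep the item text.
cleaned = verifier.filter_invalid_citations(summary, segments)
assert cleaned.key_points[0].segment_ids == []
```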
249
src/noteflow/infrastructure/summarization/cloud_provider.py
Normal file
@@ -0,0 +1,249 @@
"""Cloud summarization provider for OpenAI/Anthropic APIs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from noteflow.domain.entities import Summary
|
||||
from noteflow.domain.summarization import (
|
||||
InvalidResponseError,
|
||||
ProviderUnavailableError,
|
||||
SummarizationRequest,
|
||||
SummarizationResult,
|
||||
SummarizationTimeoutError,
|
||||
)
|
||||
from noteflow.infrastructure.summarization._parsing import (
|
||||
SYSTEM_PROMPT,
|
||||
build_transcript_prompt,
|
||||
parse_llm_response,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import anthropic
|
||||
import openai
|
||||
|
||||
|
||||
class CloudBackend(Enum):
|
||||
"""Supported cloud LLM backends."""
|
||||
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
|
||||
|
||||
class CloudSummarizer:
|
||||
"""Cloud-based LLM summarizer using OpenAI or Anthropic.
|
||||
|
||||
Requires explicit user consent as data is sent to external services.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: CloudBackend = CloudBackend.OPENAI,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
timeout_seconds: float = 60.0,
|
||||
) -> None:
|
||||
"""Initialize cloud summarizer.
|
||||
|
||||
Args:
|
||||
backend: Cloud provider backend (OpenAI or Anthropic).
|
||||
api_key: API key (defaults to env var if not provided).
|
||||
model: Model name (defaults per backend if not provided).
|
||||
timeout_seconds: Request timeout in seconds.
|
||||
"""
|
||||
self._backend = backend
|
||||
self._api_key = api_key
|
||||
self._timeout = timeout_seconds
|
||||
self._client: openai.OpenAI | anthropic.Anthropic | None = None
|
||||
|
||||
# Set default models per backend
|
||||
if model is None:
|
||||
self._model = (
|
||||
"gpt-4o-mini" if backend == CloudBackend.OPENAI else "claude-3-haiku-20240307"
|
||||
)
|
||||
else:
|
||||
self._model = model
|
||||
|
||||
def _get_openai_client(self) -> openai.OpenAI:
|
||||
"""Get or create OpenAI client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
import openai
|
||||
|
||||
self._client = openai.OpenAI(api_key=self._api_key, timeout=self._timeout)
|
||||
except ImportError as e:
|
||||
raise ProviderUnavailableError(
|
||||
"openai package not installed. Install with: pip install openai"
|
||||
) from e
|
||||
return cast(openai.OpenAI, self._client)
|
||||
|
||||
def _get_anthropic_client(self) -> anthropic.Anthropic:
|
||||
"""Get or create Anthropic client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
self._client = anthropic.Anthropic(api_key=self._api_key, timeout=self._timeout)
|
||||
except ImportError as e:
|
||||
raise ProviderUnavailableError(
|
||||
"anthropic package not installed. Install with: pip install anthropic"
|
||||
) from e
|
||||
return cast(anthropic.Anthropic, self._client)
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Provider identifier."""
|
||||
return self._backend.value
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if cloud provider is configured with an API key."""
|
||||
import os
|
||||
|
||||
if self._api_key:
|
||||
return True
|
||||
|
||||
# Check environment variables
|
||||
if self._backend == CloudBackend.OPENAI:
|
||||
return bool(os.environ.get("OPENAI_API_KEY"))
|
||||
return bool(os.environ.get("ANTHROPIC_API_KEY"))
|
||||
|
||||
@property
|
||||
def requires_cloud_consent(self) -> bool:
|
||||
"""Cloud providers require explicit user consent."""
|
||||
return True
|
||||
|
||||
async def summarize(self, request: SummarizationRequest) -> SummarizationResult:
|
||||
"""Generate evidence-linked summary using cloud LLM.
|
||||
|
||||
Args:
|
||||
request: Summarization request with segments.
|
||||
|
||||
Returns:
|
||||
SummarizationResult with generated summary.
|
||||
|
||||
Raises:
|
||||
ProviderUnavailableError: If provider not configured.
|
||||
SummarizationTimeoutError: If request times out.
|
||||
InvalidResponseError: If response cannot be parsed.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
# Handle empty segments
|
||||
if not request.segments:
|
||||
return SummarizationResult(
|
||||
summary=Summary(
|
||||
meeting_id=request.meeting_id,
|
||||
executive_summary="No transcript segments to summarize.",
|
||||
key_points=[],
|
||||
action_items=[],
|
||||
generated_at=datetime.now(UTC),
|
||||
model_version=self._model,
|
||||
),
|
||||
model_name=self._model,
|
||||
provider_name=self.provider_name,
|
||||
tokens_used=None,
|
||||
latency_ms=0.0,
|
||||
)
|
||||
|
||||
user_prompt = build_transcript_prompt(request)
|
||||
|
||||
if self._backend == CloudBackend.OPENAI:
|
||||
content, tokens_used = await asyncio.to_thread(self._call_openai, user_prompt)
|
||||
else:
|
||||
content, tokens_used = await asyncio.to_thread(self._call_anthropic, user_prompt)
|
||||
|
||||
# Parse into Summary
|
||||
summary = parse_llm_response(content, request)
|
||||
summary = Summary(
|
||||
meeting_id=summary.meeting_id,
|
||||
executive_summary=summary.executive_summary,
|
||||
key_points=summary.key_points,
|
||||
action_items=summary.action_items,
|
||||
generated_at=summary.generated_at,
|
||||
model_version=self._model,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
|
||||
return SummarizationResult(
|
||||
summary=summary,
|
||||
model_name=self._model,
|
||||
provider_name=self.provider_name,
|
||||
tokens_used=tokens_used,
|
||||
latency_ms=elapsed_ms,
|
||||
)
|
||||
|
||||
def _call_openai(self, user_prompt: str) -> tuple[str, int | None]:
|
||||
"""Call OpenAI API and return (content, tokens_used)."""
|
||||
try:
|
||||
client = self._get_openai_client()
|
||||
except ProviderUnavailableError:
|
||||
raise
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
temperature=0.3,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise SummarizationTimeoutError(f"OpenAI request timed out: {e}") from e
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
if "api key" in err_str or "authentication" in err_str:
|
||||
raise ProviderUnavailableError(f"OpenAI authentication failed: {e}") from e
|
||||
if "rate limit" in err_str:
|
||||
raise SummarizationTimeoutError(f"OpenAI rate limited: {e}") from e
|
||||
raise InvalidResponseError(f"OpenAI error: {e}") from e
|
||||
|
||||
content = response.choices[0].message.content or ""
|
||||
if not content:
|
||||
raise InvalidResponseError("Empty response from OpenAI")
|
||||
|
||||
tokens_used = response.usage.total_tokens if response.usage else None
|
||||
return content, tokens_used
|
||||
|
||||
def _call_anthropic(self, user_prompt: str) -> tuple[str, int | None]:
|
||||
"""Call Anthropic API and return (content, tokens_used)."""
|
||||
try:
|
||||
client = self._get_anthropic_client()
|
||||
except ProviderUnavailableError:
|
||||
raise
|
||||
|
||||
try:
|
||||
response = client.messages.create(
|
||||
model=self._model,
|
||||
max_tokens=4096,
|
||||
system=SYSTEM_PROMPT,
|
||||
messages=[{"role": "user", "content": user_prompt}],
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise SummarizationTimeoutError(f"Anthropic request timed out: {e}") from e
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
if "api key" in err_str or "authentication" in err_str:
|
||||
raise ProviderUnavailableError(f"Anthropic authentication failed: {e}") from e
|
||||
if "rate limit" in err_str:
|
||||
raise SummarizationTimeoutError(f"Anthropic rate limited: {e}") from e
|
||||
raise InvalidResponseError(f"Anthropic error: {e}") from e
|
||||
|
||||
content = "".join(
|
||||
block.text for block in response.content if hasattr(block, "text")
|
||||
)
|
||||
if not content:
|
||||
raise InvalidResponseError("Empty response from Anthropic")
|
||||
|
||||
tokens_used = None
|
||||
if hasattr(response, "usage"):
|
||||
tokens_used = response.usage.input_tokens + response.usage.output_tokens
|
||||
|
||||
return content, tokens_used
|
||||
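Editor's note: a consent-gated usage sketch (illustration only; `request` is a SummarizationRequest built as in earlier examples, and the consent flag comes from app settings).

```python
import asyncio

from noteflow.infrastructure.summarization import CloudBackend, CloudSummarizer

# With no api_key argument, is_available falls back to the OPENAI_API_KEY /
# ANTHROPIC_API_KEY environment variables.
provider = CloudSummarizer(backend=CloudBackend.OPENAI)
if provider.is_available:
    assert provider.requires_cloud_consent  # UI must collect consent before this call
    result = asyncio.run(provider.summarize(request))  # request built elsewhere
    print(result.model_name, result.tokens_used, f"{result.latency_ms:.0f} ms")
```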
115
src/noteflow/infrastructure/summarization/mock_provider.py
Normal file
@@ -0,0 +1,115 @@
"""Mock summarization provider for testing."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.entities import ActionItem, KeyPoint, Summary
|
||||
from noteflow.domain.summarization import (
|
||||
SummarizationRequest,
|
||||
SummarizationResult,
|
||||
)
|
||||
|
||||
|
||||
class MockSummarizer:
|
||||
"""Deterministic mock summarizer for testing.
|
||||
|
||||
Generates predictable summaries based on input segments without
|
||||
requiring an actual LLM. Useful for unit tests and development.
|
||||
"""
|
||||
|
||||
def __init__(self, latency_ms: float = 10.0) -> None:
|
||||
"""Initialize mock summarizer.
|
||||
|
||||
Args:
|
||||
latency_ms: Simulated latency in milliseconds.
|
||||
"""
|
||||
self._latency_ms = latency_ms
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Provider identifier."""
|
||||
return "mock"
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Mock provider is always available."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def requires_cloud_consent(self) -> bool:
|
||||
"""Mock provider does not send data externally."""
|
||||
return False
|
||||
|
||||
async def summarize(self, request: SummarizationRequest) -> SummarizationResult:
|
||||
"""Generate deterministic mock summary.
|
||||
|
||||
Creates key points and action items based on segment content,
|
||||
with proper evidence linking to segment_ids.
|
||||
|
||||
Args:
|
||||
request: Summarization request with segments.
|
||||
|
||||
Returns:
|
||||
SummarizationResult with mock summary.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
# Generate executive summary
|
||||
segment_count = request.segment_count
|
||||
total_duration = request.total_duration
|
||||
executive_summary = (
|
||||
f"Meeting with {segment_count} segments spanning {total_duration:.1f} seconds."
|
||||
)
|
||||
|
||||
# Generate key points from segments (up to max_key_points)
|
||||
key_points: list[KeyPoint] = []
|
||||
for i, segment in enumerate(request.segments[: request.max_key_points]):
|
||||
# Truncate text for key point
|
||||
text = f"{segment.text[:100]}..." if len(segment.text) > 100 else segment.text
|
||||
key_points.append(
|
||||
KeyPoint(
|
||||
text=f"Point {i + 1}: {text}",
|
||||
segment_ids=[segment.segment_id],
|
||||
start_time=segment.start_time,
|
||||
end_time=segment.end_time,
|
||||
)
|
||||
)
|
||||
|
||||
# Generate action items from segments containing action words
|
||||
action_items: list[ActionItem] = []
|
||||
action_keywords = {"todo", "action", "will", "should", "must", "need to"}
|
||||
for segment in request.segments:
|
||||
text_lower = segment.text.lower()
|
||||
if any(kw in text_lower for kw in action_keywords):
|
||||
if len(action_items) >= request.max_action_items:
|
||||
break
|
||||
action_items.append(
|
||||
ActionItem(
|
||||
text=f"Action: {segment.text[:80]}",
|
||||
assignee="", # Mock doesn't extract assignees
|
||||
segment_ids=[segment.segment_id],
|
||||
)
|
||||
)
|
||||
|
||||
summary = Summary(
|
||||
meeting_id=request.meeting_id,
|
||||
executive_summary=executive_summary,
|
||||
key_points=key_points,
|
||||
action_items=action_items,
|
||||
generated_at=datetime.now(UTC),
|
||||
model_version="mock-1.0",
|
||||
)
|
||||
|
||||
elapsed = (time.monotonic() - start) * 1000 + self._latency_ms
|
||||
|
||||
return SummarizationResult(
|
||||
summary=summary,
|
||||
model_name="mock-1.0",
|
||||
provider_name=self.provider_name,
|
||||
tokens_used=None,
|
||||
latency_ms=elapsed,
|
||||
)
|
||||
176
src/noteflow/infrastructure/summarization/ollama_provider.py
Normal file
@@ -0,0 +1,176 @@
"""Ollama summarization provider for local LLM inference."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.entities import Summary
|
||||
from noteflow.domain.summarization import (
|
||||
InvalidResponseError,
|
||||
ProviderUnavailableError,
|
||||
SummarizationRequest,
|
||||
SummarizationResult,
|
||||
SummarizationTimeoutError,
|
||||
)
|
||||
from noteflow.infrastructure.summarization._parsing import (
|
||||
SYSTEM_PROMPT,
|
||||
build_transcript_prompt,
|
||||
parse_llm_response,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ollama
|
||||
|
||||
|
||||
class OllamaSummarizer:
|
||||
"""Ollama-based local LLM summarizer.
|
||||
|
||||
Uses a local Ollama server for privacy-preserving summarization.
|
||||
No data is sent to external cloud services.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "llama3.2",
|
||||
host: str = "http://localhost:11434",
|
||||
timeout_seconds: float = 120.0,
|
||||
) -> None:
|
||||
"""Initialize Ollama summarizer.
|
||||
|
||||
Args:
|
||||
model: Ollama model name (e.g., 'llama3.2', 'mistral').
|
||||
host: Ollama server URL.
|
||||
timeout_seconds: Request timeout in seconds.
|
||||
"""
|
||||
self._model = model
|
||||
self._host = host
|
||||
self._timeout = timeout_seconds
|
||||
self._client: ollama.Client | None = None
|
||||
|
||||
def _get_client(self) -> ollama.Client:
|
||||
"""Lazy-load Ollama client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
import ollama
|
||||
|
||||
self._client = ollama.Client(host=self._host)
|
||||
except ImportError as e:
|
||||
raise ProviderUnavailableError(
|
||||
"ollama package not installed. Install with: pip install ollama"
|
||||
) from e
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Provider identifier."""
|
||||
return "ollama"
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Ollama server is reachable."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
# Try to list models to verify connectivity
|
||||
client.list()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def requires_cloud_consent(self) -> bool:
|
||||
"""Ollama runs locally, no cloud consent required."""
|
||||
return False
|
||||
|
||||
async def summarize(self, request: SummarizationRequest) -> SummarizationResult:
|
||||
"""Generate evidence-linked summary using Ollama.
|
||||
|
||||
Args:
|
||||
request: Summarization request with segments.
|
||||
|
||||
Returns:
|
||||
SummarizationResult with generated summary.
|
||||
|
||||
Raises:
|
||||
ProviderUnavailableError: If Ollama is not accessible.
|
||||
SummarizationTimeoutError: If request times out.
|
||||
InvalidResponseError: If response cannot be parsed.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
# Handle empty segments
|
||||
if not request.segments:
|
||||
return SummarizationResult(
|
||||
summary=Summary(
|
||||
meeting_id=request.meeting_id,
|
||||
executive_summary="No transcript segments to summarize.",
|
||||
key_points=[],
|
||||
action_items=[],
|
||||
generated_at=datetime.now(UTC),
|
||||
model_version=self._model,
|
||||
),
|
||||
model_name=self._model,
|
||||
provider_name=self.provider_name,
|
||||
tokens_used=None,
|
||||
latency_ms=0.0,
|
||||
)
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
except ProviderUnavailableError:
|
||||
raise
|
||||
|
||||
user_prompt = build_transcript_prompt(request)
|
||||
|
||||
try:
|
||||
# Offload blocking call to a worker thread to avoid blocking the event loop
|
||||
response = await asyncio.to_thread(
|
||||
client.chat,
|
||||
model=self._model,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
options={"temperature": 0.3},
|
||||
format="json",
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise SummarizationTimeoutError(f"Ollama request timed out: {e}") from e
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
if "connection" in err_str or "refused" in err_str:
|
||||
raise ProviderUnavailableError(f"Cannot connect to Ollama: {e}") from e
|
||||
raise InvalidResponseError(f"Ollama error: {e}") from e
|
||||
|
||||
# Extract response text
|
||||
content = response.get("message", {}).get("content", "")
|
||||
if not content:
|
||||
raise InvalidResponseError("Empty response from Ollama")
|
||||
|
||||
# Parse into Summary
|
||||
summary = parse_llm_response(content, request)
|
||||
summary = Summary(
|
||||
meeting_id=summary.meeting_id,
|
||||
executive_summary=summary.executive_summary,
|
||||
key_points=summary.key_points,
|
||||
action_items=summary.action_items,
|
||||
generated_at=summary.generated_at,
|
||||
model_version=self._model,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
|
||||
# Extract token usage if available
|
||||
tokens_used = None
|
||||
if "eval_count" in response:
|
||||
tokens_used = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
|
||||
|
||||
return SummarizationResult(
|
||||
summary=summary,
|
||||
model_name=self._model,
|
||||
provider_name=self.provider_name,
|
||||
tokens_used=tokens_used,
|
||||
latency_ms=elapsed_ms,
|
||||
)
|
||||
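Editor's note: a local-first fallback sketch (illustration only; `request` as built in earlier examples).

```python
import asyncio

from noteflow.infrastructure.summarization import MockSummarizer, OllamaSummarizer

# Prefer the local Ollama server; fall back to the deterministic mock offline.
provider = OllamaSummarizer(model="llama3.2")  # is_available pings the server
summarizer = provider if provider.is_available else MockSummarizer()
result = asyncio.run(summarizer.summarize(request))  # request built elsewhere
print(result.provider_name, result.summary.executive_summary)
```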
20
src/noteflow/infrastructure/triggers/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""Trigger infrastructure module.
|
||||
|
||||
Provide signal providers for meeting detection triggers.
|
||||
"""
|
||||
|
||||
from noteflow.infrastructure.triggers.audio_activity import (
|
||||
AudioActivityProvider,
|
||||
AudioActivitySettings,
|
||||
)
|
||||
from noteflow.infrastructure.triggers.foreground_app import (
|
||||
ForegroundAppProvider,
|
||||
ForegroundAppSettings,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AudioActivityProvider",
|
||||
"AudioActivitySettings",
|
||||
"ForegroundAppProvider",
|
||||
"ForegroundAppSettings",
|
||||
]
|
||||
143
src/noteflow/infrastructure/triggers/audio_activity.py
Normal file
@@ -0,0 +1,143 @@
"""Audio activity signal provider.
|
||||
|
||||
Detect sustained audio activity using existing RmsLevelProvider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.triggers.entities import TriggerSignal, TriggerSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from noteflow.infrastructure.audio import RmsLevelProvider
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioActivitySettings:
|
||||
"""Configuration for audio activity detection.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether audio activity detection is enabled.
|
||||
threshold_db: Minimum dB level to consider as activity (default -40 dB).
|
||||
window_seconds: Time window for sustained activity detection.
|
||||
min_active_ratio: Minimum ratio of active samples in window (0.0-1.0).
|
||||
min_samples: Minimum samples required before evaluation.
|
||||
max_history: Maximum samples retained in history.
|
||||
weight: Confidence weight contributed by this provider.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
threshold_db: float
|
||||
window_seconds: float
|
||||
min_active_ratio: float
|
||||
min_samples: int
|
||||
max_history: int
|
||||
weight: float
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.min_samples > self.max_history:
|
||||
msg = "min_samples must be <= max_history"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class AudioActivityProvider:
|
||||
"""Detect sustained audio activity using existing RmsLevelProvider.
|
||||
|
||||
Reuses RmsLevelProvider from infrastructure/audio for dB calculation.
|
||||
Tracks activity history over a sliding window and generates signals
|
||||
when sustained speech activity is detected.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
level_provider: RmsLevelProvider,
|
||||
settings: AudioActivitySettings,
|
||||
) -> None:
|
||||
"""Initialize audio activity provider.
|
||||
|
||||
Args:
|
||||
level_provider: Existing RmsLevelProvider instance to reuse.
|
||||
settings: Configuration settings for audio activity detection.
|
||||
"""
|
||||
self._level_provider = level_provider
|
||||
self._settings = settings
|
||||
self._history: deque[tuple[float, bool]] = deque(maxlen=self._settings.max_history)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def source(self) -> TriggerSource:
|
||||
"""Get the source type for this provider."""
|
||||
return TriggerSource.AUDIO_ACTIVITY
|
||||
|
||||
@property
|
||||
def max_weight(self) -> float:
|
||||
"""Get the maximum weight this provider can contribute."""
|
||||
return self._settings.weight
|
||||
|
||||
def update(self, frames: NDArray[np.float32], timestamp: float) -> None:
|
||||
"""Update activity history with new audio frames.
|
||||
|
||||
Call this from the audio capture callback to feed new samples.
|
||||
|
||||
Args:
|
||||
frames: Audio samples as float32 array.
|
||||
timestamp: Monotonic timestamp of the audio chunk.
|
||||
"""
|
||||
if not self._settings.enabled:
|
||||
return
|
||||
|
||||
db = self._level_provider.get_db(frames)
|
||||
is_active = db >= self._settings.threshold_db
|
||||
with self._lock:
|
||||
self._history.append((timestamp, is_active))
|
||||
|
||||
def get_signal(self) -> TriggerSignal | None:
|
||||
"""Get current signal if sustained activity detected.
|
||||
|
||||
Returns:
|
||||
TriggerSignal if activity ratio exceeds threshold, None otherwise.
|
||||
"""
|
||||
if not self._settings.enabled:
|
||||
return None
|
||||
|
||||
# Need minimum samples before we can evaluate
|
||||
with self._lock:
|
||||
history = list(self._history)
|
||||
|
||||
if len(history) < self._settings.min_samples:
|
||||
return None
|
||||
|
||||
# Prune old samples outside window
|
||||
now = time.monotonic()
|
||||
cutoff = now - self._settings.window_seconds
|
||||
recent = [(ts, active) for ts, active in history if ts >= cutoff]
|
||||
|
||||
if len(recent) < self._settings.min_samples:
|
||||
return None
|
||||
|
||||
# Calculate activity ratio
|
||||
active_count = sum(bool(active)
|
||||
for _, active in recent)
|
||||
ratio = active_count / len(recent)
|
||||
|
||||
if ratio < self._settings.min_active_ratio:
|
||||
return None
|
||||
|
||||
return TriggerSignal(source=self.source, weight=self.max_weight)
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if this provider is enabled."""
|
||||
return self._settings.enabled
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""Clear activity history. Useful when recording starts."""
|
||||
with self._lock:
|
||||
self._history.clear()
|
||||
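Editor's note: a wiring sketch (illustration only; the settings values are illustrative, not defaults, and the zero-argument `RmsLevelProvider()` construction is an assumption).

```python
import time

import numpy as np

from noteflow.infrastructure.audio import RmsLevelProvider
from noteflow.infrastructure.triggers import AudioActivityProvider, AudioActivitySettings

settings = AudioActivitySettings(
    enabled=True, threshold_db=-40.0, window_seconds=10.0,
    min_active_ratio=0.3, min_samples=5, max_history=100, weight=0.5,
)
provider = AudioActivityProvider(RmsLevelProvider(), settings)  # ctor assumed


def on_audio_frames(frames: np.ndarray) -> None:
    # Audio capture callback: feed each chunk with a monotonic timestamp.
    provider.update(frames.astype(np.float32), time.monotonic())

# Elsewhere, on a timer: provider.get_signal() returns a TriggerSignal once
# >= 30% of the samples in the last 10 s exceed -40 dB.
```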
157
src/noteflow/infrastructure/triggers/foreground_app.py
Normal file
@@ -0,0 +1,157 @@
"""Foreground app detection using PyWinCtl.
|
||||
|
||||
Detect meeting applications in the foreground window.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from noteflow.domain.triggers.entities import TriggerSignal, TriggerSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForegroundAppSettings:
|
||||
"""Configuration for foreground app detection.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether foreground app detection is enabled.
|
||||
weight: Confidence weight contributed by this provider.
|
||||
meeting_apps: Set of app name substrings to match (lowercase).
|
||||
suppressed_apps: Apps to ignore even if they match meeting_apps.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
weight: float
|
||||
meeting_apps: set[str] = field(default_factory=set)
|
||||
suppressed_apps: set[str] = field(default_factory=set)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.meeting_apps = {app.lower() for app in self.meeting_apps}
|
||||
self.suppressed_apps = {app.lower() for app in self.suppressed_apps}
|
||||
|
||||
|
||||
class ForegroundAppProvider:
|
||||
"""Detect meeting apps in foreground using PyWinCtl.
|
||||
|
||||
PyWinCtl provides cross-platform active window detection for
|
||||
Linux (X11/Wayland), macOS, and Windows.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: ForegroundAppSettings) -> None:
|
||||
"""Initialize foreground app provider.
|
||||
|
||||
Args:
|
||||
settings: Configuration settings for foreground app detection.
|
||||
"""
|
||||
self._settings = settings
|
||||
self._available: bool | None = None
|
||||
|
||||
@property
|
||||
def source(self) -> TriggerSource:
|
||||
"""Get the source type for this provider."""
|
||||
return TriggerSource.FOREGROUND_APP
|
||||
|
||||
@property
|
||||
def max_weight(self) -> float:
|
||||
"""Get the maximum weight this provider can contribute."""
|
||||
return self._settings.weight
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if this provider is enabled and available."""
|
||||
return self._settings.enabled and self._is_available()
|
||||
|
||||
def _is_available(self) -> bool:
|
||||
"""Check if PyWinCtl is available and working."""
|
||||
if self._available is not None:
|
||||
return self._available
|
||||
|
||||
try:
|
||||
import pywinctl
|
||||
|
||||
# Try to get active window to verify it works
|
||||
_ = pywinctl.getActiveWindow()
|
||||
self._available = True
|
||||
logger.debug("PyWinCtl available for foreground detection")
|
||||
except ImportError:
|
||||
self._available = False
|
||||
logger.warning("PyWinCtl not installed - foreground detection disabled")
|
||||
except Exception as e:
|
||||
self._available = False
|
||||
logger.warning("PyWinCtl unavailable: %s - foreground detection disabled", e)
|
||||
|
||||
return self._available
|
||||
|
||||
def get_signal(self) -> TriggerSignal | None:
|
||||
"""Get current signal if meeting app is in foreground.
|
||||
|
||||
Returns:
|
||||
TriggerSignal if a meeting app is detected, None otherwise.
|
||||
"""
|
||||
if not self.is_enabled():
|
||||
return None
|
||||
|
||||
try:
|
||||
import pywinctl
|
||||
|
||||
window = pywinctl.getActiveWindow()
|
||||
if not window:
|
||||
return None
|
||||
|
||||
title = window.title
|
||||
if not title:
|
||||
return None
|
||||
|
||||
title_lower = title.lower()
|
||||
|
||||
# Check if app is suppressed
|
||||
for suppressed in self._settings.suppressed_apps:
|
||||
if suppressed in title_lower:
|
||||
return None
|
||||
|
||||
# Check if it's a meeting app
|
||||
for app in self._settings.meeting_apps:
|
||||
if app in title_lower:
|
||||
return TriggerSignal(
|
||||
source=self.source,
|
||||
weight=self.max_weight,
|
||||
app_name=title,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Foreground detection error: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
def suppress_app(self, app_name: str) -> None:
|
||||
"""Add an app to the suppression list.
|
||||
|
||||
Args:
|
||||
app_name: App name substring to suppress (will be lowercased).
|
||||
"""
|
||||
self._settings.suppressed_apps.add(app_name.lower())
|
||||
logger.info("Suppressed app: %s", app_name)
|
||||
|
||||
def unsuppress_app(self, app_name: str) -> None:
|
||||
"""Remove an app from the suppression list.
|
||||
|
||||
Args:
|
||||
app_name: App name substring to unsuppress.
|
||||
"""
|
||||
self._settings.suppressed_apps.discard(app_name.lower())
|
||||
|
||||
def add_meeting_app(self, app_name: str) -> None:
|
||||
"""Add an app to the meeting apps list.
|
||||
|
||||
Args:
|
||||
app_name: App name substring to add (will be lowercased).
|
||||
"""
|
||||
self._settings.meeting_apps.add(app_name.lower())
|
||||
|
||||
@property
|
||||
def suppressed_apps(self) -> frozenset[str]:
|
||||
"""Get current suppressed apps."""
|
||||
return frozenset(self._settings.suppressed_apps)
|
||||
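Editor's note: a configuration sketch (illustration only; the app-name substrings are illustrative, not shipped defaults).

```python
from noteflow.infrastructure.triggers import ForegroundAppProvider, ForegroundAppSettings

settings = ForegroundAppSettings(
    enabled=True,
    weight=0.5,
    meeting_apps={"zoom", "teams", "meet", "webex"},  # matched against window titles
)
provider = ForegroundAppProvider(settings)

signal = provider.get_signal()  # None unless a matching window is focused
if signal:
    print(f"Meeting app detected: {signal.app_name}")
    provider.suppress_app(signal.app_name)  # e.g., after a user dismissal
```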
412
tests/application/test_summarization_service.py
Normal file
@@ -0,0 +1,412 @@
"""Tests for summarization service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.application.services import (
|
||||
SummarizationMode,
|
||||
SummarizationService,
|
||||
SummarizationServiceSettings,
|
||||
)
|
||||
from noteflow.domain.entities import KeyPoint, Segment, Summary
|
||||
from noteflow.domain.summarization import (
|
||||
CitationVerificationResult,
|
||||
ProviderUnavailableError,
|
||||
SummarizationRequest,
|
||||
SummarizationResult,
|
||||
)
|
||||
from noteflow.domain.value_objects import MeetingId
|
||||
|
||||
|
||||
def _segment(segment_id: int, text: str = "Test") -> Segment:
|
||||
"""Create a test segment."""
|
||||
return Segment(
|
||||
segment_id=segment_id,
|
||||
text=text,
|
||||
start_time=segment_id * 5.0,
|
||||
end_time=(segment_id + 1) * 5.0,
|
||||
)
|
||||
|
||||
|
||||
class MockProvider:
|
||||
"""Mock summarizer provider for testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "mock",
|
||||
available: bool = True,
|
||||
requires_consent: bool = False,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._available = available
|
||||
self._requires_consent = requires_consent
|
||||
self.call_count = 0
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return self._available
|
||||
|
||||
@property
|
||||
def requires_cloud_consent(self) -> bool:
|
||||
return self._requires_consent
|
||||
|
||||
async def summarize(self, request: SummarizationRequest) -> SummarizationResult:
|
||||
self.call_count += 1
|
||||
summary = Summary(
|
||||
meeting_id=request.meeting_id,
|
||||
executive_summary=f"Summary from {self._name}",
|
||||
key_points=[KeyPoint(text=f"Point from {self._name}", segment_ids=[0])],
|
||||
generated_at=datetime.now(UTC),
|
||||
)
|
||||
return SummarizationResult(
|
||||
summary=summary,
|
||||
model_name=f"{self._name}-model",
|
||||
provider_name=self._name,
|
||||
)
|
||||
|
||||
|
||||
class MockVerifier:
|
||||
"""Mock citation verifier for testing."""
|
||||
|
||||
def __init__(self, is_valid: bool = True) -> None:
|
||||
self._is_valid = is_valid
|
||||
self.verify_call_count = 0
|
||||
self.filter_call_count = 0
|
||||
|
||||
def verify_citations(
|
||||
self, summary: Summary, segments: list[Segment]
|
||||
) -> CitationVerificationResult:
|
||||
self.verify_call_count += 1
|
||||
if self._is_valid:
|
||||
return CitationVerificationResult(is_valid=True)
|
||||
return CitationVerificationResult(
|
||||
is_valid=False,
|
||||
invalid_key_point_indices=(0,),
|
||||
missing_segment_ids=(99,),
|
||||
)
|
||||
|
||||
def filter_invalid_citations(self, summary: Summary, segments: list[Segment]) -> Summary:
|
||||
self.filter_call_count += 1
|
||||
# Return summary with empty segment_ids for key points
|
||||
return Summary(
|
||||
meeting_id=summary.meeting_id,
|
||||
executive_summary=summary.executive_summary,
|
||||
key_points=[KeyPoint(text=kp.text, segment_ids=[]) for kp in summary.key_points],
|
||||
action_items=[],
|
||||
generated_at=summary.generated_at,
|
||||
)
|
||||
|
||||
|
||||
class TestSummarizationServiceConfiguration:
|
||||
"""Tests for SummarizationService configuration."""
|
||||
|
||||
def test_register_provider(self) -> None:
|
||||
"""Provider should be registered for mode."""
|
||||
service = SummarizationService()
|
||||
provider = MockProvider()
|
||||
|
||||
service.register_provider(SummarizationMode.LOCAL, provider)
|
||||
|
||||
assert SummarizationMode.LOCAL in service.providers
|
||||
|
||||
def test_set_verifier(self) -> None:
|
||||
"""Verifier should be set."""
|
||||
service = SummarizationService()
|
||||
verifier = MockVerifier()
|
||||
|
||||
service.set_verifier(verifier)
|
||||
|
||||
assert service.verifier is verifier
|
||||
|
||||
def test_get_available_modes_with_local(self) -> None:
|
||||
"""Available modes should include local when provider is available."""
|
||||
service = SummarizationService()
|
||||
service.register_provider(SummarizationMode.LOCAL, MockProvider())
|
||||
|
||||
available = service.get_available_modes()
|
||||
|
||||
assert SummarizationMode.LOCAL in available
|
||||
|
||||
def test_get_available_modes_excludes_unavailable(self) -> None:
|
||||
"""Unavailable providers should not be in available modes."""
|
||||
service = SummarizationService()
|
||||
service.register_provider(SummarizationMode.LOCAL, MockProvider(available=False))
|
||||
|
||||
available = service.get_available_modes()
|
||||
|
||||
assert SummarizationMode.LOCAL not in available
|
||||
|
||||
def test_cloud_requires_consent(self) -> None:
|
||||
"""Cloud mode should require consent to be available."""
|
||||
service = SummarizationService()
|
||||
service.register_provider(
|
||||
SummarizationMode.CLOUD,
|
||||
MockProvider(name="cloud", requires_consent=True),
|
||||
)
|
||||
|
||||
available_without_consent = service.get_available_modes()
|
||||
service.grant_cloud_consent()
|
||||
available_with_consent = service.get_available_modes()
|
||||
|
||||
assert SummarizationMode.CLOUD not in available_without_consent
|
||||
assert SummarizationMode.CLOUD in available_with_consent
|
||||
|
||||
def test_revoke_cloud_consent(self) -> None:
|
||||
"""Revoking consent should remove cloud from available modes."""
|
||||
service = SummarizationService()
|
||||
service.register_provider(
|
||||
SummarizationMode.CLOUD,
|
||||
MockProvider(name="cloud", requires_consent=True),
|
||||
)
|
||||
service.grant_cloud_consent()
|
||||
|
||||
service.revoke_cloud_consent()
|
||||
available = service.get_available_modes()
|
||||
|
||||
assert SummarizationMode.CLOUD not in available
|
||||
|
||||
|
||||
class TestSummarizationServiceSummarize:
|
||||
"""Tests for SummarizationService.summarize method."""
|
||||
|
||||
@pytest.fixture
|
||||
def meeting_id(self) -> MeetingId:
|
||||
"""Create test meeting ID."""
|
||||
return MeetingId(uuid4())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_uses_default_mode(self, meeting_id: MeetingId) -> None:
|
||||
"""Summarize should use default mode when not specified."""
|
||||
provider = MockProvider()
|
||||
service = SummarizationService(
|
||||
settings=SummarizationServiceSettings(default_mode=SummarizationMode.LOCAL)
|
||||
)
|
||||
service.register_provider(SummarizationMode.LOCAL, provider)
|
||||
|
||||
segments = [_segment(0)]
|
||||
result = await service.summarize(meeting_id, segments)
|
||||
|
||||
assert result.provider_used == "mock"
|
||||
assert provider.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_uses_specified_mode(self, meeting_id: MeetingId) -> None:
|
||||
"""Summarize should use specified mode."""
|
||||
local_provider = MockProvider(name="local")
|
||||
mock_provider = MockProvider(name="mock")
|
||||
service = SummarizationService()
|
||||
service.register_provider(SummarizationMode.LOCAL, local_provider)
|
||||
service.register_provider(SummarizationMode.MOCK, mock_provider)
|
||||
|
||||
segments = [_segment(0)]
|
||||
result = await service.summarize(meeting_id, segments, mode=SummarizationMode.MOCK)
|
||||
|
||||
assert result.provider_used == "mock"
|
||||
assert mock_provider.call_count == 1
|
||||
assert local_provider.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_falls_back_on_unavailable(self, meeting_id: MeetingId) -> None:
|
||||
"""Should fall back to available provider when primary unavailable."""
|
||||
unavailable = MockProvider(name="cloud", available=False)
|
||||
fallback = MockProvider(name="local")
|
||||
service = SummarizationService(
|
||||
settings=SummarizationServiceSettings(
|
||||
fallback_to_local=True,
|
||||
cloud_consent_granted=True,
|
||||
)
|
||||
)
|
||||
service.register_provider(SummarizationMode.CLOUD, unavailable)
|
||||
service.register_provider(SummarizationMode.LOCAL, fallback)
|
||||
|
||||
segments = [_segment(0)]
|
||||
result = await service.summarize(meeting_id, segments, mode=SummarizationMode.CLOUD)
|
||||
|
||||
assert result.provider_used == "local"
|
||||
assert result.fallback_used is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_raises_when_no_fallback(self, meeting_id: MeetingId) -> None:
|
||||
"""Should raise error when no fallback available."""
|
||||
unavailable = MockProvider(name="local", available=False)
|
||||
service = SummarizationService(
|
||||
settings=SummarizationServiceSettings(fallback_to_local=False)
|
||||
)
|
||||
service.register_provider(SummarizationMode.LOCAL, unavailable)
|
||||
|
||||
segments = [_segment(0)]
|
||||
with pytest.raises(ProviderUnavailableError):
|
||||
await service.summarize(meeting_id, segments, mode=SummarizationMode.LOCAL)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_verifies_citations(self, meeting_id: MeetingId) -> None:
|
||||
"""Citations should be verified when enabled."""
|
||||
provider = MockProvider()
|
||||
verifier = MockVerifier(is_valid=True)
|
||||
service = SummarizationService(settings=SummarizationServiceSettings(verify_citations=True))
|
||||
service.register_provider(SummarizationMode.LOCAL, provider)
|
||||
service.set_verifier(verifier)
|
||||
|
||||
segments = [_segment(0)]
|
||||
result = await service.summarize(meeting_id, segments)
|
||||
|
||||
assert verifier.verify_call_count == 1
|
||||
assert result.verification is not None
|
||||
assert result.verification.is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_filters_invalid_citations(self, meeting_id: MeetingId) -> None:
|
||||
"""Invalid citations should be filtered when enabled."""
|
||||
provider = MockProvider()
|
||||
verifier = MockVerifier(is_valid=False)
|
||||
service = SummarizationService(
|
||||
settings=SummarizationServiceSettings(
|
||||
verify_citations=True,
|
||||
filter_invalid_citations=True,
|
||||
)
|
||||
)
|
||||
service.register_provider(SummarizationMode.LOCAL, provider)
|
||||
service.set_verifier(verifier)
|
||||
|
||||
segments = [_segment(0)]
|
||||
result = await service.summarize(meeting_id, segments)
|
||||
|
||||
assert verifier.filter_call_count == 1
|
||||
assert result.filtered_summary is not None
|
||||
assert result.has_invalid_citations is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_passes_max_limits(self, meeting_id: MeetingId) -> None:
|
||||
"""Max limits should be passed to provider."""
|
||||
captured_request: SummarizationRequest | None = None
|
||||
|
||||
class CapturingProvider(MockProvider):
|
||||
async def summarize(self, request: SummarizationRequest) -> SummarizationResult:
|
||||
nonlocal captured_request
|
||||
captured_request = request
|
||||
return await super().summarize(request)
|
||||
|
||||
provider = CapturingProvider()
|
||||
service = SummarizationService()
|
||||
service.register_provider(SummarizationMode.LOCAL, provider)
|
||||
|
||||
segments = [_segment(0)]
|
||||
await service.summarize(meeting_id, segments, max_key_points=3, max_action_items=5)
|
||||
|
||||
assert captured_request is not None
|
||||
assert captured_request.max_key_points == 3
|
||||
assert captured_request.max_action_items == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_requires_cloud_consent(self, meeting_id: MeetingId) -> None:
|
||||
"""Cloud mode should require consent."""
|
||||
cloud = MockProvider(name="cloud", requires_consent=True)
|
||||
fallback = MockProvider(name="local")
|
||||
service = SummarizationService(
|
||||
settings=SummarizationServiceSettings(
|
||||
cloud_consent_granted=False, fallback_to_local=True
|
||||
)
|
||||
)
|
||||
service.register_provider(SummarizationMode.CLOUD, cloud)
|
||||
service.register_provider(SummarizationMode.LOCAL, fallback)
|
||||
|
||||
segments = [_segment(0)]
|
||||
result = await service.summarize(meeting_id, segments, mode=SummarizationMode.CLOUD)
|
||||
|
||||
assert result.provider_used == "local"
|
||||
assert result.fallback_used is True
|
||||
assert cloud.call_count == 0
|
||||
|
||||
|
||||
class TestSummarizationServiceResult:
|
||||
"""Tests for SummarizationServiceResult."""
|
||||
|
||||
def test_summary_returns_filtered_when_available(self) -> None:
|
||||
"""summary property should return filtered_summary if available."""
|
||||
from noteflow.application.services import SummarizationServiceResult
|
||||
|
||||
original = Summary(
|
||||
meeting_id=MeetingId(uuid4()),
|
||||
executive_summary="Original",
|
||||
key_points=[KeyPoint(text="Point", segment_ids=[99])],
|
||||
)
|
||||
filtered = Summary(
|
||||
meeting_id=original.meeting_id,
|
||||
executive_summary="Original",
|
||||
key_points=[KeyPoint(text="Point", segment_ids=[])],
|
||||
)
|
||||
result = SummarizationServiceResult(
|
||||
result=SummarizationResult(
|
||||
summary=original,
|
||||
model_name="test",
|
||||
provider_name="test",
|
||||
),
|
||||
filtered_summary=filtered,
|
||||
)
|
||||
|
||||
assert result.summary is filtered
|
||||
|
||||
def test_summary_returns_original_when_no_filter(self) -> None:
|
||||
"""summary property should return original when no filter applied."""
|
||||
from noteflow.application.services import SummarizationServiceResult
|
||||
|
||||
original = Summary(
|
||||
meeting_id=MeetingId(uuid4()),
|
||||
executive_summary="Original",
|
||||
key_points=[],
|
||||
)
|
||||
result = SummarizationServiceResult(
|
||||
result=SummarizationResult(
|
||||
summary=original,
|
||||
model_name="test",
|
||||
provider_name="test",
|
||||
),
|
||||
)
|
||||
|
||||
assert result.summary is original
|
||||
|
||||
def test_has_invalid_citations_true(self) -> None:
|
||||
"""has_invalid_citations should be True when verification fails."""
|
||||
from noteflow.application.services import SummarizationServiceResult
|
||||
|
||||
result = SummarizationServiceResult(
|
||||
result=SummarizationResult(
|
||||
summary=Summary(
|
||||
meeting_id=MeetingId(uuid4()),
|
||||
executive_summary="Test",
|
||||
key_points=[],
|
||||
),
|
||||
model_name="test",
|
||||
provider_name="test",
|
||||
),
|
||||
verification=CitationVerificationResult(is_valid=False, invalid_key_point_indices=(0,)),
|
||||
)
|
||||
|
||||
assert result.has_invalid_citations is True
|
||||
|
||||
def test_has_invalid_citations_false_when_valid(self) -> None:
|
||||
"""has_invalid_citations should be False when verification passes."""
|
||||
from noteflow.application.services import SummarizationServiceResult
|
||||
|
||||
result = SummarizationServiceResult(
|
||||
result=SummarizationResult(
|
||||
summary=Summary(
|
||||
meeting_id=MeetingId(uuid4()),
|
||||
executive_summary="Test",
|
||||
key_points=[],
|
||||
),
|
||||
model_name="test",
|
||||
provider_name="test",
|
||||
),
|
||||
verification=CitationVerificationResult(is_valid=True),
|
||||
)
|
||||
|
||||
assert result.has_invalid_citations is False
|
||||
151
tests/application/test_trigger_service.py
Normal file
@@ -0,0 +1,151 @@
"""Tests for TriggerService application logic."""

from __future__ import annotations

import time
from dataclasses import dataclass

import pytest

from noteflow.application.services.trigger_service import (
    TriggerService,
    TriggerServiceSettings,
)
from noteflow.domain.triggers import TriggerAction, TriggerSignal, TriggerSource


@dataclass
class FakeProvider:
    """Simple signal provider for testing."""

    signal: TriggerSignal | None
    enabled: bool = True
    calls: int = 0

    @property
    def source(self) -> TriggerSource:
        return TriggerSource.AUDIO_ACTIVITY

    @property
    def max_weight(self) -> float:
        return 1.0

    def is_enabled(self) -> bool:
        return self.enabled

    def get_signal(self) -> TriggerSignal | None:
        self.calls += 1
        return self.signal


def _settings(
    *,
    enabled: bool = True,
    auto_start: bool = False,
    rate_limit_seconds: int = 60,
    snooze_seconds: int = 30,
    threshold_ignore: float = 0.2,
    threshold_auto: float = 0.8,
) -> TriggerServiceSettings:
    return TriggerServiceSettings(
        enabled=enabled,
        auto_start_enabled=auto_start,
        rate_limit_seconds=rate_limit_seconds,
        snooze_seconds=snooze_seconds,
        threshold_ignore=threshold_ignore,
        threshold_auto_start=threshold_auto,
    )


def test_trigger_service_disabled_skips_providers() -> None:
    """Disabled trigger service should ignore without evaluating providers."""
    provider = FakeProvider(signal=TriggerSignal(TriggerSource.AUDIO_ACTIVITY, weight=0.5))
    service = TriggerService([provider], settings=_settings(enabled=False))

    decision = service.evaluate()

    assert decision.action == TriggerAction.IGNORE
    assert decision.confidence == 0.0
    assert decision.signals == ()
    assert provider.calls == 0


def test_trigger_service_snooze_ignores_signals(monkeypatch: pytest.MonkeyPatch) -> None:
    """Snoozed trigger service ignores signals until snooze expires."""
    provider = FakeProvider(signal=TriggerSignal(TriggerSource.AUDIO_ACTIVITY, weight=0.5))
    service = TriggerService([provider], settings=_settings())

    monkeypatch.setattr(time, "monotonic", lambda: 100.0)
    service.snooze(seconds=20)

    monkeypatch.setattr(time, "monotonic", lambda: 110.0)
    decision = service.evaluate()
    assert decision.action == TriggerAction.IGNORE

    monkeypatch.setattr(time, "monotonic", lambda: 130.0)
    decision = service.evaluate()
    assert decision.action == TriggerAction.NOTIFY


def test_trigger_service_rate_limit(monkeypatch: pytest.MonkeyPatch) -> None:
    """TriggerService enforces rate limit between prompts."""
    provider = FakeProvider(signal=TriggerSignal(TriggerSource.AUDIO_ACTIVITY, weight=0.5))
    service = TriggerService([provider], settings=_settings(rate_limit_seconds=60))

    monkeypatch.setattr(time, "monotonic", lambda: 100.0)
    first = service.evaluate()
    assert first.action == TriggerAction.NOTIFY

    monkeypatch.setattr(time, "monotonic", lambda: 120.0)
    second = service.evaluate()
    assert second.action == TriggerAction.IGNORE

    monkeypatch.setattr(time, "monotonic", lambda: 200.0)
    third = service.evaluate()
    assert third.action == TriggerAction.NOTIFY


def test_trigger_service_auto_start(monkeypatch: pytest.MonkeyPatch) -> None:
    """Auto-start fires when confidence passes threshold and auto-start is enabled."""
    provider = FakeProvider(signal=TriggerSignal(TriggerSource.AUDIO_ACTIVITY, weight=0.9))
    service = TriggerService([provider], settings=_settings(auto_start=True, threshold_auto=0.8))

    monkeypatch.setattr(time, "monotonic", lambda: 100.0)
    decision = service.evaluate()

    assert decision.action == TriggerAction.AUTO_START


def test_trigger_service_auto_start_disabled_notifies(monkeypatch: pytest.MonkeyPatch) -> None:
    """High confidence should still notify when auto-start is disabled."""
    provider = FakeProvider(signal=TriggerSignal(TriggerSource.AUDIO_ACTIVITY, weight=0.9))
    service = TriggerService([provider], settings=_settings(auto_start=False, threshold_auto=0.8))

    monkeypatch.setattr(time, "monotonic", lambda: 100.0)
    decision = service.evaluate()

    assert decision.action == TriggerAction.NOTIFY


def test_trigger_service_below_ignore_threshold(monkeypatch: pytest.MonkeyPatch) -> None:
    """Signals below ignore threshold should be ignored."""
    provider = FakeProvider(signal=TriggerSignal(TriggerSource.AUDIO_ACTIVITY, weight=0.1))
    service = TriggerService([provider], settings=_settings(threshold_ignore=0.2))

    monkeypatch.setattr(time, "monotonic", lambda: 100.0)
    decision = service.evaluate()

    assert decision.action == TriggerAction.IGNORE


def test_trigger_service_threshold_validation() -> None:
    """Invalid threshold ordering should raise."""
    with pytest.raises(ValueError, match="threshold_auto_start"):
        TriggerServiceSettings(
            enabled=True,
            auto_start_enabled=False,
            rate_limit_seconds=10,
            snooze_seconds=5,
            threshold_ignore=0.9,
            threshold_auto_start=0.2,
        )
@@ -2,8 +2,7 @@
 
 from __future__ import annotations
 
-from datetime import datetime
-from datetime import timedelta
+from datetime import datetime, timedelta
 
 import pytest
 
41
tests/domain/test_triggers.py
Normal file
@@ -0,0 +1,41 @@
"""Tests for trigger domain entities."""

from __future__ import annotations

import pytest

from noteflow.domain.triggers import TriggerAction, TriggerDecision, TriggerSignal, TriggerSource


def test_trigger_signal_weight_bounds() -> None:
    """TriggerSignal enforces weight bounds."""
    with pytest.raises(ValueError, match="Weight must be 0.0-1.0"):
        TriggerSignal(source=TriggerSource.AUDIO_ACTIVITY, weight=-0.1)

    with pytest.raises(ValueError, match="Weight must be 0.0-1.0"):
        TriggerSignal(source=TriggerSource.AUDIO_ACTIVITY, weight=1.1)

    signal = TriggerSignal(source=TriggerSource.AUDIO_ACTIVITY, weight=0.5)
    assert signal.weight == 0.5


def test_trigger_decision_primary_signal_and_detected_app() -> None:
    """TriggerDecision exposes primary signal and detected app."""
    audio = TriggerSignal(source=TriggerSource.AUDIO_ACTIVITY, weight=0.2)
    foreground = TriggerSignal(
        source=TriggerSource.FOREGROUND_APP,
        weight=0.4,
        app_name="Zoom Meeting",
    )
    decision = TriggerDecision(
        action=TriggerAction.NOTIFY,
        confidence=0.6,
        signals=(audio, foreground),
    )

    assert decision.primary_signal == foreground
    assert decision.detected_app == "Zoom Meeting"

    empty = TriggerDecision(action=TriggerAction.IGNORE, confidence=0.0, signals=())
    assert empty.primary_signal is None
    assert empty.detected_app is None
@@ -126,43 +126,6 @@ class TestEnergyVadReset:
         assert vad._silence_frame_count == 0
 
 
-class TestEnergyVadRms:
-    """Tests for RMS computation."""
-
-    def test_rms_zeros(self) -> None:
-        """RMS of zeros is zero."""
-        audio = np.zeros(100, dtype=np.float32)
-
-        result = EnergyVad._compute_rms(audio)
-
-        assert result == 0.0
-
-    def test_rms_ones(self) -> None:
-        """RMS of all ones is one."""
-        audio = np.ones(100, dtype=np.float32)
-
-        result = EnergyVad._compute_rms(audio)
-
-        assert result == 1.0
-
-    def test_rms_empty(self) -> None:
-        """RMS of empty array is zero."""
-        audio = np.array([], dtype=np.float32)
-
-        result = EnergyVad._compute_rms(audio)
-
-        assert result == 0.0
-
-    def test_rms_sine_wave(self) -> None:
-        """RMS of sine wave is ~0.707."""
-        t = np.linspace(0, 2 * np.pi, 1000, dtype=np.float32)
-        audio = np.sin(t).astype(np.float32)
-
-        result = EnergyVad._compute_rms(audio)
-
-        assert 0.7 < result < 0.72
-
-
 class TestStreamingVad:
     """Tests for StreamingVad wrapper."""
 
@@ -1,4 +1,4 @@
-"""Tests for RmsLevelProvider."""
+"""Tests for RmsLevelProvider and compute_rms."""
 
 from __future__ import annotations
 
@@ -8,12 +8,43 @@ from typing import TYPE_CHECKING
 import numpy as np
 import pytest
 
-from noteflow.infrastructure.audio import RmsLevelProvider
+from noteflow.infrastructure.audio import RmsLevelProvider, compute_rms
 
 if TYPE_CHECKING:
     from numpy.typing import NDArray
 
 
+class TestComputeRms:
+    """Tests for compute_rms function."""
+
+    def test_empty_array_returns_zero(self) -> None:
+        """RMS of empty array is zero."""
+        frames = np.array([], dtype=np.float32)
+        assert compute_rms(frames) == 0.0
+
+    def test_zeros_returns_zero(self) -> None:
+        """RMS of zeros is zero."""
+        frames = np.zeros(100, dtype=np.float32)
+        assert compute_rms(frames) == 0.0
+
+    def test_ones_returns_one(self) -> None:
+        """RMS of all ones is one."""
+        frames = np.ones(100, dtype=np.float32)
+        assert compute_rms(frames) == 1.0
+
+    def test_half_amplitude_returns_half(self) -> None:
+        """RMS of constant 0.5 is 0.5."""
+        frames = np.full(100, 0.5, dtype=np.float32)
+        assert compute_rms(frames) == 0.5
+
+    def test_sine_wave_returns_sqrt_half(self) -> None:
+        """RMS of sine wave is approximately 1/sqrt(2)."""
+        t = np.linspace(0, 2 * np.pi, 1000, dtype=np.float32)
+        frames = np.sin(t).astype(np.float32)
+        result = compute_rms(frames)
+        assert 0.7 < result < 0.72  # ~0.707
+
+
 class TestRmsLevelProvider:
     """Tests for RmsLevelProvider class."""
 
220
tests/infrastructure/summarization/test_citation_verifier.py
Normal file
@@ -0,0 +1,220 @@
"""Tests for citation verification."""

from __future__ import annotations

from uuid import uuid4

import pytest

from noteflow.domain.entities import ActionItem, KeyPoint, Segment, Summary
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.summarization import SegmentCitationVerifier


def _segment(segment_id: int, text: str = "Test") -> Segment:
    """Create a test segment."""
    return Segment(
        segment_id=segment_id,
        text=text,
        start_time=segment_id * 5.0,
        end_time=(segment_id + 1) * 5.0,
    )


def _key_point(text: str, segment_ids: list[int]) -> KeyPoint:
    """Create a test key point."""
    return KeyPoint(text=text, segment_ids=segment_ids)


def _action_item(text: str, segment_ids: list[int]) -> ActionItem:
    """Create a test action item."""
    return ActionItem(text=text, segment_ids=segment_ids)


def _summary(
    key_points: list[KeyPoint] | None = None,
    action_items: list[ActionItem] | None = None,
) -> Summary:
    """Create a test summary."""
    return Summary(
        meeting_id=MeetingId(uuid4()),
        executive_summary="Test summary",
        key_points=key_points or [],
        action_items=action_items or [],
    )


class TestSegmentCitationVerifier:
    """Tests for SegmentCitationVerifier."""

    @pytest.fixture
    def verifier(self) -> SegmentCitationVerifier:
        """Create verifier instance."""
        return SegmentCitationVerifier()

    def test_verify_valid_citations(self, verifier: SegmentCitationVerifier) -> None:
        """All citations valid should return is_valid=True."""
        segments = [_segment(0), _segment(1), _segment(2)]
        summary = _summary(
            key_points=[_key_point("Point 1", [0, 1])],
            action_items=[_action_item("Action 1", [2])],
        )

        result = verifier.verify_citations(summary, segments)

        assert result.is_valid is True
        assert result.invalid_key_point_indices == ()
        assert result.invalid_action_item_indices == ()
        assert result.missing_segment_ids == ()

    def test_verify_invalid_key_point_citation(self, verifier: SegmentCitationVerifier) -> None:
        """Invalid segment_id in key point should be detected."""
        segments = [_segment(0), _segment(1)]
        summary = _summary(
            key_points=[_key_point("Point 1", [0, 99])],  # 99 doesn't exist
        )

        result = verifier.verify_citations(summary, segments)

        assert result.is_valid is False
        assert result.invalid_key_point_indices == (0,)
        assert result.invalid_action_item_indices == ()
        assert result.missing_segment_ids == (99,)

    def test_verify_invalid_action_item_citation(self, verifier: SegmentCitationVerifier) -> None:
        """Invalid segment_id in action item should be detected."""
        segments = [_segment(0), _segment(1)]
        summary = _summary(
            action_items=[_action_item("Action 1", [50])],  # 50 doesn't exist
        )

        result = verifier.verify_citations(summary, segments)

        assert result.is_valid is False
        assert result.invalid_key_point_indices == ()
        assert result.invalid_action_item_indices == (0,)
        assert result.missing_segment_ids == (50,)

    def test_verify_multiple_invalid_citations(self, verifier: SegmentCitationVerifier) -> None:
        """Multiple invalid citations should all be detected."""
        segments = [_segment(0)]
        summary = _summary(
            key_points=[
                _key_point("Point 1", [0]),
                _key_point("Point 2", [1]),  # Invalid
                _key_point("Point 3", [2]),  # Invalid
            ],
            action_items=[
                _action_item("Action 1", [3]),  # Invalid
            ],
        )

        result = verifier.verify_citations(summary, segments)

        assert result.is_valid is False
        assert result.invalid_key_point_indices == (1, 2)
        assert result.invalid_action_item_indices == (0,)
        assert result.missing_segment_ids == (1, 2, 3)

    def test_verify_empty_summary(self, verifier: SegmentCitationVerifier) -> None:
        """Empty summary should be valid."""
        segments = [_segment(0)]
        summary = _summary()

        result = verifier.verify_citations(summary, segments)

        assert result.is_valid is True

    def test_verify_empty_segments(self, verifier: SegmentCitationVerifier) -> None:
        """Summary with citations but no segments should be invalid."""
        segments: list[Segment] = []
        summary = _summary(key_points=[_key_point("Point 1", [0])])

        result = verifier.verify_citations(summary, segments)

        assert result.is_valid is False
        assert result.missing_segment_ids == (0,)

    def test_verify_empty_citations(self, verifier: SegmentCitationVerifier) -> None:
        """Key points/actions with empty segment_ids should be valid."""
        segments = [_segment(0)]
        summary = _summary(
            key_points=[_key_point("Point 1", [])],  # No citations
            action_items=[_action_item("Action 1", [])],  # No citations
        )

        result = verifier.verify_citations(summary, segments)

        assert result.is_valid is True

    def test_invalid_count_property(self, verifier: SegmentCitationVerifier) -> None:
        """invalid_count should sum key point and action item invalid counts."""
        segments = [_segment(0)]
        summary = _summary(
            key_points=[
                _key_point("Point 1", [1]),  # Invalid
                _key_point("Point 2", [2]),  # Invalid
            ],
            action_items=[
                _action_item("Action 1", [3]),  # Invalid
            ],
        )

        result = verifier.verify_citations(summary, segments)

        assert result.invalid_count == 3


class TestFilterInvalidCitations:
    """Tests for filter_invalid_citations method."""

    @pytest.fixture
    def verifier(self) -> SegmentCitationVerifier:
        """Create verifier instance."""
        return SegmentCitationVerifier()

    def test_filter_removes_invalid_segment_ids(self, verifier: SegmentCitationVerifier) -> None:
        """Invalid segment_ids should be removed from citations."""
        segments = [_segment(0), _segment(1)]
        summary = _summary(
            key_points=[_key_point("Point 1", [0, 1, 99])],  # 99 invalid
            action_items=[_action_item("Action 1", [1, 50])],  # 50 invalid
        )

        filtered = verifier.filter_invalid_citations(summary, segments)

        assert filtered.key_points[0].segment_ids == [0, 1]
        assert filtered.action_items[0].segment_ids == [1]

    def test_filter_preserves_valid_citations(self, verifier: SegmentCitationVerifier) -> None:
        """Valid citations should be preserved."""
        segments = [_segment(0), _segment(1), _segment(2)]
        summary = _summary(
            key_points=[_key_point("Point 1", [0, 1])],
            action_items=[_action_item("Action 1", [2])],
        )

        filtered = verifier.filter_invalid_citations(summary, segments)

        assert filtered.key_points[0].segment_ids == [0, 1]
        assert filtered.action_items[0].segment_ids == [2]

    def test_filter_preserves_other_fields(self, verifier: SegmentCitationVerifier) -> None:
        """Non-citation fields should be preserved."""
        segments = [_segment(0)]
        summary = Summary(
            meeting_id=MeetingId(uuid4()),
            executive_summary="Important meeting",
            key_points=[KeyPoint(text="Key point", segment_ids=[0], start_time=1.0, end_time=2.0)],
            action_items=[ActionItem(text="Action", segment_ids=[0], assignee="Alice", priority=2)],
            model_version="test-1.0",
        )

        filtered = verifier.filter_invalid_citations(summary, segments)

        assert filtered.executive_summary == "Important meeting"
        assert filtered.key_points[0].text == "Key point"
        assert filtered.key_points[0].start_time == 1.0
        assert filtered.action_items[0].assignee == "Alice"
        assert filtered.action_items[0].priority == 2
        assert filtered.model_version == "test-1.0"
458
tests/infrastructure/summarization/test_cloud_provider.py
Normal file
@@ -0,0 +1,458 @@
"""Tests for cloud summarization provider."""

from __future__ import annotations

import json
import sys
import types
from typing import Any
from uuid import uuid4

import pytest

from noteflow.domain.entities import Segment
from noteflow.domain.summarization import (
    InvalidResponseError,
    ProviderUnavailableError,
    SummarizationRequest,
)
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.summarization import CloudBackend


def _segment(
    segment_id: int,
    text: str,
    start: float = 0.0,
    end: float = 5.0,
) -> Segment:
    """Create a test segment."""
    return Segment(
        segment_id=segment_id,
        text=text,
        start_time=start,
        end_time=end,
    )


def _valid_json_response(
    summary: str = "Test summary.",
    key_points: list[dict[str, Any]] | None = None,
    action_items: list[dict[str, Any]] | None = None,
) -> str:
    """Build a valid JSON response string."""
    return json.dumps(
        {
            "executive_summary": summary,
            "key_points": key_points or [],
            "action_items": action_items or [],
        }
    )


class TestCloudSummarizerProperties:
    """Tests for CloudSummarizer properties."""

    def test_provider_name_openai(self) -> None:
        """Provider name should be 'openai' for OpenAI backend."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(backend=CloudBackend.OPENAI)
        assert summarizer.provider_name == "openai"

    def test_provider_name_anthropic(self) -> None:
        """Provider name should be 'anthropic' for Anthropic backend."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(backend=CloudBackend.ANTHROPIC)
        assert summarizer.provider_name == "anthropic"

    def test_requires_cloud_consent_true(self) -> None:
        """Cloud providers should require consent."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer()
        assert summarizer.requires_cloud_consent is True

    def test_is_available_with_api_key(self) -> None:
        """is_available should be True when API key is provided."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key")
        assert summarizer.is_available is True

    def test_is_available_without_api_key(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """is_available should be False without API key or env var."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer()
        assert summarizer.is_available is False

    def test_is_available_with_openai_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """is_available should be True with OPENAI_API_KEY env var."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test")

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(backend=CloudBackend.OPENAI)
        assert summarizer.is_available is True

    def test_is_available_with_anthropic_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """is_available should be True with ANTHROPIC_API_KEY env var."""
        monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test")

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(backend=CloudBackend.ANTHROPIC)
        assert summarizer.is_available is True

    def test_default_model_openai(self) -> None:
        """Default model for OpenAI should be gpt-4o-mini."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(backend=CloudBackend.OPENAI)
        assert summarizer._model == "gpt-4o-mini"

    def test_default_model_anthropic(self) -> None:
        """Default model for Anthropic should be claude-3-haiku."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(backend=CloudBackend.ANTHROPIC)
        assert summarizer._model == "claude-3-haiku-20240307"

    def test_custom_model(self) -> None:
        """Custom model should override default."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(model="gpt-4-turbo")
        assert summarizer._model == "gpt-4-turbo"


class TestCloudSummarizerOpenAI:
    """Tests for CloudSummarizer with OpenAI backend."""

    @pytest.fixture
    def meeting_id(self) -> MeetingId:
        """Create test meeting ID."""
        return MeetingId(uuid4())

    @pytest.fixture
    def mock_openai(self, monkeypatch: pytest.MonkeyPatch) -> types.ModuleType:
        """Mock openai module."""

        def create_response(content: str, tokens: int = 100) -> types.SimpleNamespace:
            """Create mock OpenAI response."""
            return types.SimpleNamespace(
                choices=[types.SimpleNamespace(message=types.SimpleNamespace(content=content))],
                usage=types.SimpleNamespace(total_tokens=tokens),
            )

        mock_client = types.SimpleNamespace(
            chat=types.SimpleNamespace(
                completions=types.SimpleNamespace(
                    create=lambda **_: create_response(_valid_json_response())
                )
            )
        )
        mock_module = types.ModuleType("openai")
        mock_module.OpenAI = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "openai", mock_module)
        return mock_module

    @pytest.mark.asyncio
    async def test_summarize_empty_segments(
        self, meeting_id: MeetingId, mock_openai: types.ModuleType
    ) -> None:
        """Empty segments should return empty summary."""
        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key")
        request = SummarizationRequest(meeting_id=meeting_id, segments=[])

        result = await summarizer.summarize(request)

        assert result.summary.key_points == []
        assert result.summary.action_items == []

    @pytest.mark.asyncio
    async def test_summarize_returns_result(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Summarize should return SummarizationResult."""
        response_content = _valid_json_response(
            summary="Project meeting summary.",
            key_points=[{"text": "Key point", "segment_ids": [0]}],
            action_items=[{"text": "Action", "assignee": "Bob", "priority": 1, "segment_ids": [1]}],
        )

        def create_response(**_: Any) -> types.SimpleNamespace:
            return types.SimpleNamespace(
                choices=[
                    types.SimpleNamespace(message=types.SimpleNamespace(content=response_content))
                ],
                usage=types.SimpleNamespace(total_tokens=150),
            )

        mock_client = types.SimpleNamespace(
            chat=types.SimpleNamespace(completions=types.SimpleNamespace(create=create_response))
        )
        mock_module = types.ModuleType("openai")
        mock_module.OpenAI = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "openai", mock_module)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key", backend=CloudBackend.OPENAI)
        segments = [_segment(0, "Key point"), _segment(1, "Action item")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert result.provider_name == "openai"
        assert result.summary.executive_summary == "Project meeting summary."
        assert result.tokens_used == 150

    @pytest.mark.asyncio
    async def test_raises_unavailable_on_auth_error(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise ProviderUnavailableError on auth failure."""

        def raise_auth_error(**_: Any) -> None:
            raise Exception("Invalid API key provided")

        mock_client = types.SimpleNamespace(
            chat=types.SimpleNamespace(completions=types.SimpleNamespace(create=raise_auth_error))
        )
        mock_module = types.ModuleType("openai")
        mock_module.OpenAI = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "openai", mock_module)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="bad-key")
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(ProviderUnavailableError, match="authentication failed"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_raises_invalid_response_on_empty_content(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise InvalidResponseError on empty response."""

        def create_empty_response(**_: Any) -> types.SimpleNamespace:
            return types.SimpleNamespace(
                choices=[types.SimpleNamespace(message=types.SimpleNamespace(content=""))],
                usage=None,
            )

        mock_client = types.SimpleNamespace(
            chat=types.SimpleNamespace(
                completions=types.SimpleNamespace(create=create_empty_response)
            )
        )
        mock_module = types.ModuleType("openai")
        mock_module.OpenAI = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "openai", mock_module)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key")
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(InvalidResponseError, match="Empty response"):
            await summarizer.summarize(request)


class TestCloudSummarizerAnthropic:
    """Tests for CloudSummarizer with Anthropic backend."""

    @pytest.fixture
    def meeting_id(self) -> MeetingId:
        """Create test meeting ID."""
        return MeetingId(uuid4())

    @pytest.mark.asyncio
    async def test_summarize_returns_result(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Summarize should return SummarizationResult."""
        response_content = _valid_json_response(
            summary="Anthropic summary.",
            key_points=[{"text": "Point", "segment_ids": [0]}],
        )

        def create_response(**_: Any) -> types.SimpleNamespace:
            return types.SimpleNamespace(
                content=[types.SimpleNamespace(text=response_content)],
                usage=types.SimpleNamespace(input_tokens=50, output_tokens=100),
            )

        mock_client = types.SimpleNamespace(messages=types.SimpleNamespace(create=create_response))
        mock_module = types.ModuleType("anthropic")
        mock_module.Anthropic = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "anthropic", mock_module)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key", backend=CloudBackend.ANTHROPIC)
        segments = [_segment(0, "Test point")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert result.provider_name == "anthropic"
        assert result.summary.executive_summary == "Anthropic summary."
        assert result.tokens_used == 150

    @pytest.mark.asyncio
    async def test_raises_unavailable_when_package_missing(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise ProviderUnavailableError when package not installed."""
        monkeypatch.delitem(sys.modules, "anthropic", raising=False)

        import builtins

        original_import = builtins.__import__

        def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "anthropic":
                raise ImportError("No module named 'anthropic'")
            return original_import(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", mock_import)

        from noteflow.infrastructure.summarization import cloud_provider

        summarizer = cloud_provider.CloudSummarizer(
            api_key="test-key", backend=CloudBackend.ANTHROPIC
        )
        summarizer._client = None

        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(ProviderUnavailableError, match="anthropic package"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_raises_invalid_response_on_empty_content(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise InvalidResponseError on empty response."""

        def create_empty_response(**_: Any) -> types.SimpleNamespace:
            return types.SimpleNamespace(
                content=[],
                usage=types.SimpleNamespace(input_tokens=10, output_tokens=0),
            )

        mock_client = types.SimpleNamespace(
            messages=types.SimpleNamespace(create=create_empty_response)
        )
        mock_module = types.ModuleType("anthropic")
        mock_module.Anthropic = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "anthropic", mock_module)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key", backend=CloudBackend.ANTHROPIC)
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(InvalidResponseError, match="Empty response"):
            await summarizer.summarize(request)


class TestCloudSummarizerFiltering:
    """Tests for response filtering in CloudSummarizer."""

    @pytest.fixture
    def meeting_id(self) -> MeetingId:
        """Create test meeting ID."""
        return MeetingId(uuid4())

    @pytest.mark.asyncio
    async def test_filters_invalid_segment_ids(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Invalid segment_ids should be filtered from response."""
        response_content = _valid_json_response(
            summary="Test",
            key_points=[{"text": "Point", "segment_ids": [0, 99, 100]}],
        )

        def create_response(**_: Any) -> types.SimpleNamespace:
            return types.SimpleNamespace(
                choices=[
                    types.SimpleNamespace(message=types.SimpleNamespace(content=response_content))
                ],
                usage=None,
            )

        mock_client = types.SimpleNamespace(
            chat=types.SimpleNamespace(completions=types.SimpleNamespace(create=create_response))
        )
        mock_module = types.ModuleType("openai")
        mock_module.OpenAI = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "openai", mock_module)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key")
        segments = [_segment(0, "Only valid segment")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert result.summary.key_points[0].segment_ids == [0]

    @pytest.mark.asyncio
    async def test_respects_max_limits(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Max limits should truncate response items."""
        response_content = _valid_json_response(
            summary="Test",
            key_points=[{"text": f"Point {i}", "segment_ids": [0]} for i in range(10)],
            action_items=[{"text": f"Action {i}", "segment_ids": [0]} for i in range(10)],
        )

        def create_response(**_: Any) -> types.SimpleNamespace:
            return types.SimpleNamespace(
                choices=[
                    types.SimpleNamespace(message=types.SimpleNamespace(content=response_content))
                ],
                usage=None,
            )

        mock_client = types.SimpleNamespace(
            chat=types.SimpleNamespace(completions=types.SimpleNamespace(create=create_response))
        )
        mock_module = types.ModuleType("openai")
        mock_module.OpenAI = lambda **_: mock_client
        monkeypatch.setitem(sys.modules, "openai", mock_module)

        from noteflow.infrastructure.summarization import CloudSummarizer

        summarizer = CloudSummarizer(api_key="test-key")
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(
            meeting_id=meeting_id,
            segments=segments,
            max_key_points=2,
            max_action_items=3,
        )

        result = await summarizer.summarize(request)

        assert len(result.summary.key_points) == 2
        assert len(result.summary.action_items) == 3
196
tests/infrastructure/summarization/test_mock_provider.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Tests for mock summarization provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.domain.entities import Segment
|
||||
from noteflow.domain.summarization import SummarizationRequest
|
||||
from noteflow.domain.value_objects import MeetingId
|
||||
from noteflow.infrastructure.summarization import MockSummarizer
|
||||
|
||||
|
||||
def _segment(
|
||||
segment_id: int,
|
||||
text: str,
|
||||
start: float = 0.0,
|
||||
end: float = 5.0,
|
||||
) -> Segment:
|
||||
"""Create a test segment."""
|
||||
return Segment(
|
||||
segment_id=segment_id,
|
||||
text=text,
|
||||
start_time=start,
|
||||
end_time=end,
|
||||
)
|
||||
|
||||
|
||||
class TestMockSummarizer:
|
||||
"""Tests for MockSummarizer."""
|
||||
|
||||
@pytest.fixture
|
||||
def summarizer(self) -> MockSummarizer:
|
||||
"""Create MockSummarizer instance."""
|
||||
return MockSummarizer(latency_ms=0.0)
|
||||
|
||||
@pytest.fixture
|
||||
def meeting_id(self) -> MeetingId:
|
||||
"""Create a test meeting ID."""
|
||||
return MeetingId(uuid4())
|
||||
|
||||
def test_provider_name(self, summarizer: MockSummarizer) -> None:
|
||||
"""Provider name should be 'mock'."""
|
||||
assert summarizer.provider_name == "mock"
|
||||
|
||||
def test_is_available(self, summarizer: MockSummarizer) -> None:
|
||||
"""Mock provider should always be available."""
|
||||
assert summarizer.is_available is True
|
||||
|
||||
def test_requires_cloud_consent(self, summarizer: MockSummarizer) -> None:
|
||||
"""Mock provider should not require cloud consent."""
|
||||
assert summarizer.requires_cloud_consent is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_returns_result(
|
||||
self,
|
||||
summarizer: MockSummarizer,
|
||||
meeting_id: MeetingId,
|
||||
) -> None:
|
||||
"""Summarize should return a SummarizationResult."""
|
||||
segments = [
|
||||
_segment(0, "First segment text.", 0.0, 5.0),
|
||||
_segment(1, "Second segment text.", 5.0, 10.0),
|
||||
]
|
||||
request = SummarizationRequest(meeting_id=meeting_id, segments=segments)
|
||||
|
||||
result = await summarizer.summarize(request)
|
||||
|
||||
assert result.provider_name == "mock"
|
||||
assert result.model_name == "mock-1.0"
|
||||
assert result.summary.meeting_id == meeting_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_generates_executive_summary(
|
||||
self,
        summarizer: MockSummarizer,
        meeting_id: MeetingId,
    ) -> None:
        """Summarize should generate executive summary with segment count."""
        segments = [
            _segment(0, "Hello", 0.0, 5.0),
            _segment(1, "World", 5.0, 10.0),
            _segment(2, "Test", 10.0, 15.0),
        ]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert "3 segments" in result.summary.executive_summary
        assert "15.0 seconds" in result.summary.executive_summary

    @pytest.mark.asyncio
    async def test_summarize_generates_key_points_with_citations(
        self,
        summarizer: MockSummarizer,
        meeting_id: MeetingId,
    ) -> None:
        """Key points should have valid segment_id citations."""
        segments = [
            _segment(0, "First point", 0.0, 5.0),
            _segment(1, "Second point", 5.0, 10.0),
        ]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert len(result.summary.key_points) == 2
        assert result.summary.key_points[0].segment_ids == [0]
        assert result.summary.key_points[1].segment_ids == [1]

    @pytest.mark.asyncio
    async def test_summarize_respects_max_key_points(
        self,
        summarizer: MockSummarizer,
        meeting_id: MeetingId,
    ) -> None:
        """Key points should be limited to max_key_points."""
        segments = [_segment(i, f"Segment {i}", i * 5.0, (i + 1) * 5.0) for i in range(10)]
        request = SummarizationRequest(
            meeting_id=meeting_id,
            segments=segments,
            max_key_points=3,
        )

        result = await summarizer.summarize(request)

        assert len(result.summary.key_points) == 3

    @pytest.mark.asyncio
    async def test_summarize_extracts_action_items(
        self,
        summarizer: MockSummarizer,
        meeting_id: MeetingId,
    ) -> None:
        """Action items should be extracted from segments with action keywords."""
        segments = [
            _segment(0, "General discussion", 0.0, 5.0),
            _segment(1, "We need to fix the bug", 5.0, 10.0),
            _segment(2, "TODO: Review the code", 10.0, 15.0),
            _segment(3, "The meeting went well", 15.0, 20.0),
        ]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert len(result.summary.action_items) == 2
        assert result.summary.action_items[0].segment_ids == [1]
        assert result.summary.action_items[1].segment_ids == [2]

    @pytest.mark.asyncio
    async def test_summarize_respects_max_action_items(
        self,
        summarizer: MockSummarizer,
        meeting_id: MeetingId,
    ) -> None:
        """Action items should be limited to max_action_items."""
        segments = [_segment(i, f"TODO: task {i}", i * 5.0, (i + 1) * 5.0) for i in range(10)]
        request = SummarizationRequest(
            meeting_id=meeting_id,
            segments=segments,
            max_action_items=2,
        )

        result = await summarizer.summarize(request)

        assert len(result.summary.action_items) == 2

    @pytest.mark.asyncio
    async def test_summarize_sets_generated_at(
        self,
        summarizer: MockSummarizer,
        meeting_id: MeetingId,
    ) -> None:
        """Summary should have generated_at timestamp."""
        segments = [_segment(0, "Test", 0.0, 5.0)]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert result.summary.generated_at is not None

    @pytest.mark.asyncio
    async def test_summarize_empty_segments(
        self,
        summarizer: MockSummarizer,
        meeting_id: MeetingId,
    ) -> None:
        """Summarize should handle empty segments list."""
        request = SummarizationRequest(meeting_id=meeting_id, segments=[])

        result = await summarizer.summarize(request)

        assert result.summary.key_points == []
        assert result.summary.action_items == []
        assert "0 segments" in result.summary.executive_summary
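The contract these tests pin down is fully deterministic: the executive summary embeds the segment count and total duration, each segment yields one key point citing its own segment_id (capped at `max_key_points`), and action items come from keyword matches (capped at `max_action_items`). A self-contained sketch of that contract, using local stand-in dataclasses rather than the real `noteflow.domain.summarization` types, and an assumed keyword list:

```python
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Protocol


class _SegmentLike(Protocol):
    """Duck type for the fields this sketch reads off a transcript segment."""

    segment_id: int
    text: str
    end_time: float


@dataclass
class Cited:
    """A summary item carrying evidence links back to transcript segments."""

    text: str
    segment_ids: list[int]


@dataclass
class SketchSummary:
    executive_summary: str
    key_points: list[Cited]
    action_items: list[Cited]
    generated_at: datetime = field(default_factory=lambda: datetime.now(UTC))


ACTION_KEYWORDS = ("todo", "need to")  # assumption: the mock's actual keyword list may differ


def mock_summarize(
    segments: list[_SegmentLike],
    max_key_points: int = 5,
    max_action_items: int = 5,
) -> SketchSummary:
    """Deterministic summary mirroring the behavior the tests above assert."""
    duration = max((s.end_time for s in segments), default=0.0)
    executive = f"Meeting covered {len(segments)} segments over {duration:.1f} seconds."
    key_points = [Cited(s.text, [s.segment_id]) for s in segments[:max_key_points]]
    action_items = [
        Cited(s.text, [s.segment_id])
        for s in segments
        if any(keyword in s.text.lower() for keyword in ACTION_KEYWORDS)
    ][:max_action_items]
    return SketchSummary(executive, key_points, action_items)
```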
436
tests/infrastructure/summarization/test_ollama_provider.py
Normal file
@@ -0,0 +1,436 @@
"""Tests for Ollama summarization provider."""

from __future__ import annotations

import json
import sys
import types
from typing import Any
from uuid import uuid4

import pytest

from noteflow.domain.entities import Segment
from noteflow.domain.summarization import (
    InvalidResponseError,
    ProviderUnavailableError,
    SummarizationRequest,
)
from noteflow.domain.value_objects import MeetingId


def _segment(
    segment_id: int,
    text: str,
    start: float = 0.0,
    end: float = 5.0,
) -> Segment:
    """Create a test segment."""
    return Segment(
        segment_id=segment_id,
        text=text,
        start_time=start,
        end_time=end,
    )


def _valid_json_response(
    summary: str = "Test summary.",
    key_points: list[dict[str, Any]] | None = None,
    action_items: list[dict[str, Any]] | None = None,
) -> str:
    """Build a valid JSON response string."""
    return json.dumps(
        {
            "executive_summary": summary,
            "key_points": key_points or [],
            "action_items": action_items or [],
        }
    )


class TestOllamaSummarizerProperties:
    """Tests for OllamaSummarizer properties."""

    @pytest.fixture
    def mock_ollama_module(self, monkeypatch: pytest.MonkeyPatch) -> types.ModuleType:
        """Mock ollama module."""
        mock_client = types.SimpleNamespace(
            list=lambda: {"models": []},
            chat=lambda **_: {"message": {"content": _valid_json_response()}},
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)
        return mock_module

    def test_provider_name(self, mock_ollama_module: types.ModuleType) -> None:
        """Provider name should be 'ollama'."""
        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        assert summarizer.provider_name == "ollama"

    def test_requires_cloud_consent_false(self, mock_ollama_module: types.ModuleType) -> None:
        """Ollama should not require cloud consent (local processing)."""
        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        assert summarizer.requires_cloud_consent is False

    def test_is_available_when_server_responds(self, mock_ollama_module: types.ModuleType) -> None:
        """is_available should be True when server responds."""
        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        assert summarizer.is_available is True

    def test_is_available_false_when_connection_fails(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """is_available should be False when server unreachable."""

        def raise_error() -> None:
            raise ConnectionError("Connection refused")

        mock_client = types.SimpleNamespace(list=raise_error)
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        assert summarizer.is_available is False


class TestOllamaSummarizerSummarize:
    """Tests for OllamaSummarizer.summarize method."""

    @pytest.fixture
    def meeting_id(self) -> MeetingId:
        """Create test meeting ID."""
        return MeetingId(uuid4())

    @pytest.mark.asyncio
    async def test_summarize_empty_segments(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Empty segments should return empty summary without calling LLM."""
        call_count = 0

        def mock_chat(**_: Any) -> dict[str, Any]:
            nonlocal call_count
            call_count += 1
            return {"message": {"content": _valid_json_response()}}

        mock_client = types.SimpleNamespace(list=lambda: {}, chat=mock_chat)
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        request = SummarizationRequest(meeting_id=meeting_id, segments=[])

        result = await summarizer.summarize(request)

        assert result.summary.key_points == []
        assert result.summary.action_items == []
        assert call_count == 0

    @pytest.mark.asyncio
    async def test_summarize_returns_result(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Summarize should return SummarizationResult."""
        response = _valid_json_response(
            summary="Meeting discussed project updates.",
            key_points=[{"text": "Project on track", "segment_ids": [0]}],
            action_items=[
                {"text": "Review code", "assignee": "Alice", "priority": 2, "segment_ids": [1]}
            ],
        )

        mock_client = types.SimpleNamespace(
            list=lambda: {},
            chat=lambda **_: {"message": {"content": response}},
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        segments = [
            _segment(0, "Project is on track.", 0.0, 5.0),
            _segment(1, "Alice needs to review the code.", 5.0, 10.0),
        ]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert result.provider_name == "ollama"
        assert result.summary.meeting_id == meeting_id
        assert result.summary.executive_summary == "Meeting discussed project updates."
        assert len(result.summary.key_points) == 1
        assert result.summary.key_points[0].segment_ids == [0]
        assert len(result.summary.action_items) == 1
        assert result.summary.action_items[0].assignee == "Alice"
        assert result.summary.action_items[0].priority == 2

    @pytest.mark.asyncio
    async def test_summarize_filters_invalid_segment_ids(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Invalid segment_ids in response should be filtered out."""
        response = _valid_json_response(
            summary="Test",
            key_points=[{"text": "Point", "segment_ids": [0, 99, 100]}],  # 99, 100 invalid
        )

        mock_client = types.SimpleNamespace(
            list=lambda: {},
            chat=lambda **_: {"message": {"content": response}},
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        segments = [_segment(0, "Only segment")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert result.summary.key_points[0].segment_ids == [0]

    @pytest.mark.asyncio
    async def test_summarize_respects_max_limits(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Response items exceeding max limits should be truncated."""
        response = _valid_json_response(
            summary="Test",
            key_points=[{"text": f"Point {i}", "segment_ids": [0]} for i in range(10)],
            action_items=[{"text": f"Action {i}", "segment_ids": [0]} for i in range(10)],
        )

        mock_client = types.SimpleNamespace(
            list=lambda: {},
            chat=lambda **_: {"message": {"content": response}},
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        segments = [_segment(0, "Test segment")]
        request = SummarizationRequest(
            meeting_id=meeting_id,
            segments=segments,
            max_key_points=3,
            max_action_items=2,
        )

        result = await summarizer.summarize(request)

        assert len(result.summary.key_points) == 3
        assert len(result.summary.action_items) == 2

    @pytest.mark.asyncio
    async def test_summarize_handles_markdown_fenced_json(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Markdown code fences around JSON should be stripped."""
        json_content = _valid_json_response(summary="Fenced response")
        response = f"```json\n{json_content}\n```"

        mock_client = types.SimpleNamespace(
            list=lambda: {},
            chat=lambda **_: {"message": {"content": response}},
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        result = await summarizer.summarize(request)

        assert result.summary.executive_summary == "Fenced response"


class TestOllamaSummarizerErrors:
    """Tests for OllamaSummarizer error handling."""

    @pytest.fixture
    def meeting_id(self) -> MeetingId:
        """Create test meeting ID."""
        return MeetingId(uuid4())

    @pytest.mark.asyncio
    async def test_raises_unavailable_when_package_missing(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise ProviderUnavailableError when ollama not installed."""
        # Remove ollama from sys.modules if present
        monkeypatch.delitem(sys.modules, "ollama", raising=False)

        # Make import fail
        import builtins

        original_import = builtins.__import__

        def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "ollama":
                raise ImportError("No module named 'ollama'")
            return original_import(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", mock_import)

        # Need to reload the module to trigger fresh import
        from noteflow.infrastructure.summarization import ollama_provider

        # Create fresh instance that will try to import
        summarizer = ollama_provider.OllamaSummarizer()
        summarizer._client = None  # Force re-import attempt

        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(ProviderUnavailableError, match="ollama package not installed"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_raises_unavailable_on_connection_error(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise ProviderUnavailableError on connection failure."""

        def raise_connection_error(**_: Any) -> None:
            raise ConnectionRefusedError("Connection refused")

        mock_client = types.SimpleNamespace(
            list=lambda: {},
            chat=raise_connection_error,
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(ProviderUnavailableError, match="Cannot connect"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_raises_invalid_response_on_bad_json(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise InvalidResponseError on malformed JSON."""
        mock_client = types.SimpleNamespace(
            list=lambda: {},
            chat=lambda **_: {"message": {"content": "not valid json {{{"}},
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(InvalidResponseError, match="Invalid JSON"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_raises_invalid_response_on_empty_content(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise InvalidResponseError on empty response."""
        mock_client = types.SimpleNamespace(
            list=lambda: {},
            chat=lambda **_: {"message": {"content": ""}},
        )
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer()
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        with pytest.raises(InvalidResponseError, match="Empty response"):
            await summarizer.summarize(request)


class TestOllamaSummarizerConfiguration:
    """Tests for OllamaSummarizer configuration."""

    @pytest.mark.asyncio
    async def test_custom_model_name(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Custom model name should be used."""
        captured_model = None

        def capture_chat(**kwargs: Any) -> dict[str, Any]:
            nonlocal captured_model
            captured_model = kwargs.get("model")
            return {"message": {"content": _valid_json_response()}}

        mock_client = types.SimpleNamespace(list=lambda: {}, chat=capture_chat)
        mock_module = types.ModuleType("ollama")
        mock_module.Client = lambda host: mock_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer(model="mistral")
        meeting_id = MeetingId(uuid4())
        segments = [_segment(0, "Test")]
        request = SummarizationRequest(meeting_id=meeting_id, segments=segments)

        await summarizer.summarize(request)

        assert captured_model == "mistral"

    def test_custom_host(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Custom host should be passed to client."""
        captured_host = None

        def capture_client(host: str) -> types.SimpleNamespace:
            nonlocal captured_host
            captured_host = host
            return types.SimpleNamespace(
                list=lambda: {},
                chat=lambda **_: {"message": {"content": _valid_json_response()}},
            )

        mock_module = types.ModuleType("ollama")
        mock_module.Client = capture_client
        monkeypatch.setitem(sys.modules, "ollama", mock_module)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer(host="http://custom:8080")
        _ = summarizer.is_available

        assert captured_host == "http://custom:8080"
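The fenced-JSON and error-path tests above imply a parsing step that strips optional Markdown fences, rejects empty content, and surfaces malformed JSON as a typed error. A minimal sketch of that step under those assumptions (`parse_llm_json` is a hypothetical helper; it raises ValueError where the real provider raises its domain errors):

```python
import json
from typing import Any


def parse_llm_json(content: str) -> dict[str, Any]:
    """Parse an LLM reply whose JSON payload may be wrapped in Markdown fences."""
    text = content.strip()
    if not text:
        raise ValueError("Empty response from model")
    if text.startswith("```"):
        lines = text.splitlines()
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]  # drop the closing fence
        text = "\n".join(lines[1:])  # drop the opening fence (and optional language tag)
    try:
        payload = json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in model response: {exc}") from exc
    if not isinstance(payload, dict):
        raise ValueError("Expected a JSON object at the top level")
    return payload
```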
57
tests/infrastructure/triggers/conftest.py
Normal file
@@ -0,0 +1,57 @@
"""Test fixtures for trigger infrastructure tests."""

from __future__ import annotations

import sys
import types
from collections.abc import Callable
from dataclasses import dataclass

import pytest


@dataclass
class DummyWindow:
    """Mock window object for pywinctl tests."""

    title: str | None


@pytest.fixture
def mock_pywinctl(monkeypatch: pytest.MonkeyPatch) -> Callable[[str | None], None]:
    """Factory fixture to install a mocked pywinctl module.

    Usage:
        mock_pywinctl("Zoom Meeting")  # Window with title
        mock_pywinctl(None)            # No active window
    """

    def _install(title: str | None) -> None:
        window = DummyWindow(title) if title is not None else None
        module = types.SimpleNamespace(getActiveWindow=lambda: window)
        monkeypatch.setitem(sys.modules, "pywinctl", module)

    return _install


@pytest.fixture
def mock_pywinctl_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
    """Make pywinctl unimportable.

    Setting a sys.modules entry to None makes any subsequent
    `import pywinctl` raise ImportError.
    """
    monkeypatch.setitem(sys.modules, "pywinctl", None)


@pytest.fixture
def mock_pywinctl_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Install pywinctl mock that raises RuntimeError on getActiveWindow."""

    def raise_runtime_error() -> None:
        msg = "No display available"
        raise RuntimeError(msg)

    module = types.SimpleNamespace(getActiveWindow=raise_runtime_error)
    monkeypatch.setitem(sys.modules, "pywinctl", module)
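For reference, a consumer of the factory fixture installs the fake window before importing pywinctl. A sketch (test name and body are illustrative, assuming this conftest is in scope):

```python
def test_active_window_title(mock_pywinctl) -> None:
    mock_pywinctl("Zoom Meeting")  # installs the fake module into sys.modules

    import pywinctl  # resolves to the SimpleNamespace installed above

    assert pywinctl.getActiveWindow().title == "Zoom Meeting"
```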
185
tests/infrastructure/triggers/test_audio_activity.py
Normal file
@@ -0,0 +1,185 @@
"""Tests for audio activity trigger provider."""

from __future__ import annotations

import time

import numpy as np
import pytest

from noteflow.infrastructure.audio import RmsLevelProvider
from noteflow.infrastructure.triggers.audio_activity import (
    AudioActivityProvider,
    AudioActivitySettings,
)


def _settings(**overrides: object) -> AudioActivitySettings:
    defaults: dict[str, object] = {
        "enabled": True,
        "threshold_db": -20.0,
        "window_seconds": 10.0,
        "min_active_ratio": 0.6,
        "min_samples": 3,
        "max_history": 10,
        "weight": 0.3,
    } | overrides
    return AudioActivitySettings(**defaults)


def test_audio_activity_settings_validation() -> None:
    """Settings should reject min_samples greater than max_history."""
    with pytest.raises(ValueError, match="min_samples"):
        AudioActivitySettings(
            enabled=True,
            threshold_db=-20.0,
            window_seconds=5.0,
            min_active_ratio=0.5,
            min_samples=11,
            max_history=10,
            weight=0.3,
        )


def test_audio_activity_provider_disabled_ignores_updates() -> None:
    """Disabled provider should not emit signals."""
    provider = AudioActivityProvider(RmsLevelProvider(), _settings(enabled=False))
    frames = np.ones(10, dtype=np.float32)

    provider.update(frames, timestamp=1.0)

    assert provider.get_signal() is None


def test_audio_activity_provider_emits_signal(monkeypatch: pytest.MonkeyPatch) -> None:
    """Provider emits a signal when sustained activity passes ratio threshold."""
    provider = AudioActivityProvider(RmsLevelProvider(), _settings())
    active = np.ones(10, dtype=np.float32)
    inactive = np.zeros(10, dtype=np.float32)

    provider.update(active, timestamp=1.0)
    provider.update(active, timestamp=2.0)
    provider.update(inactive, timestamp=3.0)

    monkeypatch.setattr(time, "monotonic", lambda: 4.0)
    signal = provider.get_signal()

    assert signal is not None
    assert signal.weight == pytest.approx(0.3)


def test_audio_activity_provider_window_excludes_old_samples(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Samples outside the window should not contribute to activity ratio."""
    provider = AudioActivityProvider(RmsLevelProvider(), _settings(window_seconds=2.0))
    active = np.ones(10, dtype=np.float32)

    provider.update(active, timestamp=1.0)
    provider.update(active, timestamp=2.0)
    provider.update(active, timestamp=3.0)

    monkeypatch.setattr(time, "monotonic", lambda: 10.0)
    assert provider.get_signal() is None


def test_audio_activity_provider_source_property() -> None:
    """Provider source should be AUDIO_ACTIVITY."""
    from noteflow.domain.triggers.entities import TriggerSource

    provider = AudioActivityProvider(RmsLevelProvider(), _settings())
    assert provider.source == TriggerSource.AUDIO_ACTIVITY


def test_audio_activity_provider_max_weight_property() -> None:
    """Provider max_weight should reflect configured weight."""
    provider = AudioActivityProvider(RmsLevelProvider(), _settings(weight=0.5))
    assert provider.max_weight == pytest.approx(0.5)


def test_audio_activity_provider_is_enabled_reflects_settings() -> None:
    """is_enabled should reflect settings.enabled."""
    enabled_provider = AudioActivityProvider(RmsLevelProvider(), _settings(enabled=True))
    disabled_provider = AudioActivityProvider(RmsLevelProvider(), _settings(enabled=False))

    assert enabled_provider.is_enabled() is True
    assert disabled_provider.is_enabled() is False


def test_audio_activity_provider_clear_history() -> None:
    """clear_history should reset the activity history."""
    provider = AudioActivityProvider(RmsLevelProvider(), _settings())
    active = np.ones(10, dtype=np.float32)

    provider.update(active, timestamp=1.0)
    provider.update(active, timestamp=2.0)
    provider.update(active, timestamp=3.0)

    provider.clear_history()

    # After clearing, signal should be None due to insufficient samples
    assert provider.get_signal() is None


def test_audio_activity_provider_insufficient_samples() -> None:
    """Provider should return None when history has fewer than min_samples."""
    provider = AudioActivityProvider(RmsLevelProvider(), _settings(min_samples=5))
    active = np.ones(10, dtype=np.float32)

    # Add only 3 samples (fewer than min_samples=5)
    provider.update(active, timestamp=1.0)
    provider.update(active, timestamp=2.0)
    provider.update(active, timestamp=3.0)

    assert provider.get_signal() is None


def test_audio_activity_provider_below_activity_ratio() -> None:
    """Provider should return None when active ratio < min_active_ratio."""
    provider = AudioActivityProvider(RmsLevelProvider(), _settings(min_active_ratio=0.7))
    active = np.ones(10, dtype=np.float32)
    inactive = np.zeros(10, dtype=np.float32)

    # Add 3 active, 7 inactive = 30% active ratio (below 70% threshold)
    provider.update(active, timestamp=1.0)
    provider.update(active, timestamp=2.0)
    provider.update(active, timestamp=3.0)
    provider.update(inactive, timestamp=4.0)
    provider.update(inactive, timestamp=5.0)
    provider.update(inactive, timestamp=6.0)
    provider.update(inactive, timestamp=7.0)
    provider.update(inactive, timestamp=8.0)
    provider.update(inactive, timestamp=9.0)
    provider.update(inactive, timestamp=10.0)

    assert provider.get_signal() is None


def test_audio_activity_provider_boundary_activity_ratio(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Provider should emit signal when ratio exactly equals min_active_ratio."""
    provider = AudioActivityProvider(
        RmsLevelProvider(),
        _settings(min_active_ratio=0.6, min_samples=5, max_history=10),
    )
    active = np.ones(10, dtype=np.float32)
    inactive = np.zeros(10, dtype=np.float32)

    # Add 6 active, 4 inactive = 60% active ratio (exactly at threshold)
    provider.update(active, timestamp=1.0)
    provider.update(active, timestamp=2.0)
    provider.update(active, timestamp=3.0)
    provider.update(active, timestamp=4.0)
    provider.update(active, timestamp=5.0)
    provider.update(active, timestamp=6.0)
    provider.update(inactive, timestamp=7.0)
    provider.update(inactive, timestamp=8.0)
    provider.update(inactive, timestamp=9.0)
    provider.update(inactive, timestamp=10.0)

    monkeypatch.setattr(time, "monotonic", lambda: 11.0)
    signal = provider.get_signal()

    assert signal is not None
    assert signal.weight == pytest.approx(0.3)
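What these tests collectively specify is a sliding-window vote: each update records whether the frame cleared `threshold_db`, and `get_signal` fires only when the window holds at least `min_samples` recent entries and the active fraction is >= `min_active_ratio` (note the boundary test: equality passes). A standalone sketch of that bookkeeping, with the dB computation elided since the real provider delegates it to `RmsLevelProvider`:

```python
import time
from collections import deque


class ActivityWindowSketch:
    """Sliding-window activity vote; field names mirror AudioActivitySettings."""

    def __init__(
        self,
        threshold_db: float = -20.0,
        window_seconds: float = 10.0,
        min_active_ratio: float = 0.6,
        min_samples: int = 3,
        max_history: int = 10,
    ) -> None:
        if min_samples > max_history:
            raise ValueError("min_samples cannot exceed max_history")
        self.threshold_db = threshold_db
        self.window_seconds = window_seconds
        self.min_active_ratio = min_active_ratio
        self.min_samples = min_samples
        self._history: deque[tuple[float, bool]] = deque(maxlen=max_history)

    def update(self, level_db: float, timestamp: float) -> None:
        """Record one observation: was this frame above the activity threshold?"""
        self._history.append((timestamp, level_db >= self.threshold_db))

    def clear_history(self) -> None:
        self._history.clear()

    def is_active(self, now: float | None = None) -> bool:
        """Vote over the observations that still fall inside the window."""
        now = time.monotonic() if now is None else now
        recent = [active for ts, active in self._history if now - ts <= self.window_seconds]
        if len(recent) < self.min_samples:
            return False  # not enough evidence yet (or history was cleared)
        return sum(recent) / len(recent) >= self.min_active_ratio
```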
215
tests/infrastructure/triggers/test_foreground_app.py
Normal file
@@ -0,0 +1,215 @@
"""Tests for foreground app trigger provider."""

from __future__ import annotations

import sys
import types

import pytest

from noteflow.domain.triggers.entities import TriggerSource
from noteflow.infrastructure.triggers.foreground_app import (
    ForegroundAppProvider,
    ForegroundAppSettings,
)


class DummyWindow:
    """Mock window object for pywinctl tests."""

    def __init__(self, title: str | None) -> None:
        self.title = title


def _install_pywinctl(monkeypatch: pytest.MonkeyPatch, title: str | None) -> None:
    """Install mocked pywinctl with specified window title."""
    window = DummyWindow(title) if title is not None else None
    module = types.SimpleNamespace(getActiveWindow=lambda: window)
    monkeypatch.setitem(sys.modules, "pywinctl", module)


def _settings(**overrides: object) -> ForegroundAppSettings:
    """Create ForegroundAppSettings with defaults and overrides."""
    defaults: dict[str, object] = {
        "enabled": True,
        "weight": 0.4,
        "meeting_apps": {"zoom"},
        "suppressed_apps": set(),
    } | overrides
    return ForegroundAppSettings(**defaults)


# --- Existing Tests ---


def test_foreground_app_provider_emits_signal(monkeypatch: pytest.MonkeyPatch) -> None:
    """Provider emits signal when a meeting app is in foreground."""
    _install_pywinctl(monkeypatch, "Zoom Meeting")
    provider = ForegroundAppProvider(_settings())

    signal = provider.get_signal()

    assert signal is not None
    assert signal.weight == pytest.approx(0.4)
    assert signal.app_name == "Zoom Meeting"


def test_foreground_app_provider_suppressed(monkeypatch: pytest.MonkeyPatch) -> None:
    """Suppressed apps should not emit signals."""
    _install_pywinctl(monkeypatch, "Zoom Meeting")
    provider = ForegroundAppProvider(_settings(suppressed_apps={"zoom"}))

    assert provider.get_signal() is None


def test_foreground_app_provider_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
    """Unavailable provider should report disabled."""
    provider = ForegroundAppProvider(_settings())
    monkeypatch.setattr(provider, "_is_available", lambda: False)

    assert provider.is_enabled() is False


# --- New Tests ---


def test_foreground_app_provider_source_property() -> None:
    """Provider source should be FOREGROUND_APP."""
    provider = ForegroundAppProvider(_settings())
    assert provider.source == TriggerSource.FOREGROUND_APP


def test_foreground_app_provider_max_weight_property() -> None:
    """Provider max_weight should reflect configured weight."""
    provider = ForegroundAppProvider(_settings(weight=0.5))
    assert provider.max_weight == pytest.approx(0.5)


def test_foreground_app_settings_lowercases_apps() -> None:
    """Settings __post_init__ should lowercase meeting_apps and suppressed_apps."""
    settings = ForegroundAppSettings(
        enabled=True,
        weight=0.4,
        meeting_apps={"ZOOM", "Teams", "GoToMeeting"},
        suppressed_apps={"SLACK", "Discord"},
    )

    assert "zoom" in settings.meeting_apps
    assert "teams" in settings.meeting_apps
    assert "gotomeeting" in settings.meeting_apps
    assert "slack" in settings.suppressed_apps
    assert "discord" in settings.suppressed_apps
    # Original case should not be present
    assert "ZOOM" not in settings.meeting_apps
    assert "SLACK" not in settings.suppressed_apps


def test_foreground_app_provider_disabled_returns_none(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Provider should return None when enabled=False."""
    _install_pywinctl(monkeypatch, "Zoom Meeting")
    provider = ForegroundAppProvider(_settings(enabled=False))

    assert provider.get_signal() is None


def test_foreground_app_provider_no_window_returns_none(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Provider should return None when getActiveWindow() returns None."""
    _install_pywinctl(monkeypatch, None)
    provider = ForegroundAppProvider(_settings())
    # Force availability check to succeed
    provider._available = True

    assert provider.get_signal() is None


def test_foreground_app_provider_empty_title_returns_none(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Provider should return None when window title is empty string."""
    _install_pywinctl(monkeypatch, "")
    provider = ForegroundAppProvider(_settings())
    provider._available = True

    assert provider.get_signal() is None


def test_foreground_app_provider_non_meeting_app_returns_none(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Provider should return None when foreground app is not a meeting app."""
    _install_pywinctl(monkeypatch, "Firefox Browser")
    provider = ForegroundAppProvider(_settings(meeting_apps={"zoom", "teams"}))
    provider._available = True

    assert provider.get_signal() is None


def test_foreground_app_provider_suppress_app() -> None:
    """suppress_app should add lowercased app to suppressed_apps."""
    provider = ForegroundAppProvider(_settings(suppressed_apps=set()))

    provider.suppress_app("ZOOM")
    provider.suppress_app("Teams")

    assert "zoom" in provider.suppressed_apps
    assert "teams" in provider.suppressed_apps


def test_foreground_app_provider_unsuppress_app() -> None:
    """unsuppress_app should remove app from suppressed_apps."""
    provider = ForegroundAppProvider(_settings(suppressed_apps={"zoom", "teams"}))

    provider.unsuppress_app("zoom")

    assert "zoom" not in provider.suppressed_apps
    assert "teams" in provider.suppressed_apps


def test_foreground_app_provider_add_meeting_app() -> None:
    """add_meeting_app should add lowercased app to meeting_apps."""
    provider = ForegroundAppProvider(_settings(meeting_apps={"zoom"}))

    provider.add_meeting_app("WEBEX")
    provider.add_meeting_app("RingCentral")

    assert "webex" in provider._settings.meeting_apps
    assert "ringcentral" in provider._settings.meeting_apps


def test_foreground_app_provider_suppressed_apps_property() -> None:
    """suppressed_apps property should return frozenset."""
    provider = ForegroundAppProvider(_settings(suppressed_apps={"zoom", "teams"}))

    result = provider.suppressed_apps

    assert isinstance(result, frozenset)
    assert "zoom" in result
    assert "teams" in result


def test_foreground_app_provider_case_insensitive_matching(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Provider should match meeting apps case-insensitively."""
    _install_pywinctl(monkeypatch, "ZOOM MEETING - Conference Room")
    provider = ForegroundAppProvider(_settings(meeting_apps={"zoom"}))
    provider._available = True

    signal = provider.get_signal()

    assert signal is not None
    assert signal.app_name == "ZOOM MEETING - Conference Room"


def test_foreground_app_provider_is_enabled_when_enabled_and_available(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """is_enabled should return True when both enabled and available."""
    _install_pywinctl(monkeypatch, "Some Window")
    provider = ForegroundAppProvider(_settings(enabled=True))

    assert provider.is_enabled() is True
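Read together, the matching rules these tests fix are: compare lowercased app names against the lowercased window title, let suppression win over matching, and treat any configured name appearing as a substring of the title as a hit. A sketch of that decision as a free function (a stand-in, not the provider's actual method):

```python
def match_meeting_app(
    title: str | None,
    meeting_apps: frozenset[str],
    suppressed_apps: frozenset[str],
) -> str | None:
    """Return the window title on a match, or None when no signal should fire."""
    if not title:
        return None  # no active window, or an empty title
    lowered = title.lower()
    if any(app in lowered for app in suppressed_apps):
        return None  # suppression takes precedence over matching
    if any(app in lowered for app in meeting_apps):
        return title  # the emitted signal keeps the original-case title
    return None
```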
@@ -3,8 +3,8 @@
 from __future__ import annotations
 
 from datetime import UTC, datetime
-from uuid import uuid4
+from typing import TYPE_CHECKING
 from uuid import uuid4
 
 import pytest
29
tests/integration/test_trigger_settings.py
Normal file
@@ -0,0 +1,29 @@
"""Integration tests for trigger settings loading."""

from __future__ import annotations

import pytest

from noteflow.config.settings import get_trigger_settings

pytestmark = pytest.mark.integration


@pytest.fixture(autouse=True)
def _clear_trigger_settings_cache() -> None:
    """Clear the cached settings so each test reads a fresh environment."""
    get_trigger_settings.cache_clear()


def test_trigger_settings_env_parsing(monkeypatch: pytest.MonkeyPatch) -> None:
    """TriggerSettings should parse CSV lists from environment variables."""
    monkeypatch.setenv("NOTEFLOW_TRIGGER_MEETING_APPS", "zoom, teams")
    monkeypatch.setenv("NOTEFLOW_TRIGGER_SUPPRESSED_APPS", "spotify")
    monkeypatch.setenv("NOTEFLOW_TRIGGER_AUDIO_MIN_SAMPLES", "5")
    monkeypatch.setenv("NOTEFLOW_TRIGGER_POLL_INTERVAL_SECONDS", "1.5")

    settings = get_trigger_settings()

    assert settings.trigger_meeting_apps == ["zoom", "teams"]
    assert settings.trigger_suppressed_apps == ["spotify"]
    assert settings.trigger_audio_min_samples == 5
    assert settings.trigger_poll_interval_seconds == pytest.approx(1.5)
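The parsing contract this test pins down, stripped of any settings framework: CSV env vars become trimmed string lists, and numeric vars are coerced to int or float. A framework-free sketch of that behavior (default values here are assumptions, not the project's actual defaults):

```python
import os


def _csv(name: str) -> list[str]:
    """Split a comma-separated env var into a list of trimmed, non-empty items."""
    raw = os.environ.get(name, "")
    return [item.strip() for item in raw.split(",") if item.strip()]


def load_trigger_settings() -> dict[str, object]:
    """Hand-rolled equivalent of the parsing asserted above."""
    return {
        "trigger_meeting_apps": _csv("NOTEFLOW_TRIGGER_MEETING_APPS"),
        "trigger_suppressed_apps": _csv("NOTEFLOW_TRIGGER_SUPPRESSED_APPS"),
        "trigger_audio_min_samples": int(
            os.environ.get("NOTEFLOW_TRIGGER_AUDIO_MIN_SAMPLES", "3")
        ),
        "trigger_poll_interval_seconds": float(
            os.environ.get("NOTEFLOW_TRIGGER_POLL_INTERVAL_SECONDS", "2.0")
        ),
    }
```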