From 40f9adf27f1f323f019e0e9e42ef4be6bd69106f Mon Sep 17 00:00:00 2001 From: Travis Vasceannie Date: Tue, 20 Jan 2026 02:40:24 +0000 Subject: [PATCH] feat: introduce GLiNER and spaCy backends for NER, refactor NER infrastructure, and update client-side entity extraction. --- .claude/ralph-loop.local.md | 4 +- .claudectx/codefixes.md | 602 ++++++++++-------- .mcp.json | 12 + client/src-tauri/scripts/code_quality.sh | 6 +- .../src-tauri/src/audio/windows_loopback.rs | 12 +- client/src-tauri/src/events/mod.rs | 18 +- .../src-tauri/src/grpc/client/annotations.rs | 4 +- client/src-tauri/src/identity/mod.rs | 50 +- client/src-tauri/src/lib.rs | 11 +- client/src-tauri/src/state/app_state.rs | 5 +- .../features/calendar/upcoming-meetings.tsx | 7 +- client/src/hooks/auth/use-auth-flow.ts | 33 +- client/src/hooks/auth/use-oauth-flow.ts | 25 +- .../src/hooks/processing/use-diarization.ts | 10 +- .../hooks/processing/use-entity-extraction.ts | 11 +- .../hooks/recording/use-recording-session.ts | 16 + client/src/hooks/sync/use-calendar-sync.ts | 10 +- client/src/lib/cache/meeting-cache.ts | 18 +- client/src/lib/observability/client.ts | 25 +- client/src/pages/NotFound.tsx | 6 - .../sprint-organization/api.md | 0 .../sprint-organization/components.md | 0 .../sprint-organization/hooks.md | 0 .../sprint-organization/lib.md | 0 pyproject.toml | 12 +- .../services/asr_config/_job_manager.py | 43 +- .../application/services/ner/service.py | 6 +- src/noteflow/cli/models/_registry.py | 17 +- src/noteflow/config/constants/__init__.py | 4 + src/noteflow/config/constants/domain.py | 10 + src/noteflow/config/settings/_main.py | 46 +- src/noteflow/domain/entities/named_entity.py | 18 +- src/noteflow/grpc/mixins/_model_status.py | 5 +- src/noteflow/grpc/mixins/_task_callbacks.py | 29 +- .../streaming/_processing/_congestion.py | 14 +- src/noteflow/grpc/server/__init__.py | 21 +- src/noteflow/grpc/servicer/mixins.py | 5 +- src/noteflow/grpc/startup/services.py | 28 +- src/noteflow/infrastructure/CLAUDE.md | 352 ++++++++++ .../infrastructure/asr/pytorch_engine.py | 7 +- src/noteflow/infrastructure/audio/capture.py | 8 +- src/noteflow/infrastructure/ner/__init__.py | 5 +- .../infrastructure/ner/backends/__init__.py | 5 + .../ner/backends/gliner_backend.py | 122 ++++ .../ner/backends/spacy_backend.py | 152 +++++ .../infrastructure/ner/backends/types.py | 56 ++ src/noteflow/infrastructure/ner/engine.py | 309 ++------- src/noteflow/infrastructure/ner/mapper.py | 65 ++ .../infrastructure/ner/post_processing.py | 243 +++++++ .../r2s3t4u5v6w7_fix_segment_ids_jsonb.py | 27 +- ...6w7x8_fix_diarization_speaker_ids_jsonb.py | 31 +- .../infrastructure/security/crypto/_base.py | 26 + .../infrastructure/security/crypto/_reader.py | 3 +- .../infrastructure/security/crypto/_writer.py | 14 +- tests/application/test_asr_config_service.py | 57 +- .../audio/test_partial_buffer.py | 11 +- tests/infrastructure/ner/test_engine.py | 37 +- .../infrastructure/ner/test_gliner_backend.py | 219 +++++++ tests/infrastructure/ner/test_mapper.py | 100 +++ .../ner/test_post_processing.py | 158 +++++ tests/integration/test_hf_token_grpc.py | 6 +- typings/gliner/__init__.pyi | 35 + 62 files changed, 2495 insertions(+), 696 deletions(-) rename docs/sprints/{phase-ongoing => .archive}/sprint-organization/api.md (100%) rename docs/sprints/{phase-ongoing => .archive}/sprint-organization/components.md (100%) rename docs/sprints/{phase-ongoing => .archive}/sprint-organization/hooks.md (100%) rename docs/sprints/{phase-ongoing => 
.archive}/sprint-organization/lib.md (100%) create mode 100644 src/noteflow/infrastructure/CLAUDE.md create mode 100644 src/noteflow/infrastructure/ner/backends/__init__.py create mode 100644 src/noteflow/infrastructure/ner/backends/gliner_backend.py create mode 100644 src/noteflow/infrastructure/ner/backends/spacy_backend.py create mode 100644 src/noteflow/infrastructure/ner/backends/types.py create mode 100644 src/noteflow/infrastructure/ner/mapper.py create mode 100644 src/noteflow/infrastructure/ner/post_processing.py create mode 100644 src/noteflow/infrastructure/security/crypto/_base.py create mode 100644 tests/infrastructure/ner/test_gliner_backend.py create mode 100644 tests/infrastructure/ner/test_mapper.py create mode 100644 tests/infrastructure/ner/test_post_processing.py create mode 100644 typings/gliner/__init__.pyi diff --git a/.claude/ralph-loop.local.md b/.claude/ralph-loop.local.md index ce5db08..81163c6 100644 --- a/.claude/ralph-loop.local.md +++ b/.claude/ralph-loop.local.md @@ -1,9 +1,11 @@ --- active: true -iteration: 3 +iteration: 1 max_iterations: 0 completion_promise: null started_at: "2026-01-20T02:31:55Z" +started_at: "2026-01-20T02:31:55Z" --- proceed with the plan, i have also documented a copy in @.claudectx/codefixes.md. please use your agents iteratively to manage context and speed, however you must review the accuracy and value of each doc before moving to the next +proceed with the plan, i have also documented a copy in @.claudectx/codefixes.md. please use your agents iteratively to manage context and speed, however you must review the accuracy and value of each doc before moving to the next diff --git a/.claudectx/codefixes.md b/.claudectx/codefixes.md index f3f8864..0febd02 100644 --- a/.claudectx/codefixes.md +++ b/.claudectx/codefixes.md @@ -1,287 +1,377 @@ -# Entity extraction robustness notes +╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌ + Strategic CLAUDE.md Placement Analysis for NoteFlow -## Library options that can do heavy lifting -- GLiNER (Apache-2.0): generalist NER model that supports arbitrary labels without retraining; API via `GLiNER.from_pretrained(...).predict_entities(text, labels, threshold=0.5)`. - - Source: https://github.com/urchade/GLiNER -- Presidio (MIT): pipeline framework for detection + redaction focused on PII; includes NLP + regex recognizers, and extensible recognizer registry; useful for sensitive entity categories rather than general “meeting entities”. - - Source: https://github.com/microsoft/presidio -- Stanza (Apache-2.0): full NLP pipeline with NER component; supports many languages; requires model download and can add CoreNLP features. - - Source: https://stanfordnlp.github.io/stanza/ -- spaCy (open-source): production-oriented NLP toolkit with NER + entity ruler; offers fast pipelines and rule-based patterns; easy to combine ML + rules for domain entities. - - Source: https://spacy.io/usage/spacy-101 + Executive Summary -## Recommendation for NoteFlow -- Best “heavy lifting” candidate: GLiNER for flexible, domain-specific entity types without retraining. It can be fed label sets like [person, company, product, app, location, time, duration, event] and will return per-label spans. -- If PII detection is a requirement: layer Presidio on top of GLiNER outputs (or run Presidio first for sensitive categories). Presidio is not optimized for general entity relevance outside PII. 
-- If you prefer classic pipelines or need dependency parsing: spaCy or Stanza; spaCy is easier to extend with rule-based entity ruler and custom merge policies. + This document analyzes optimal placement of CLAUDE.md files throughout the NoteFlow codebase to provide meaningful context for AI assistants. The analysis considers both + constrained (strategic) and unlimited scenarios. -## Quick wins for robustness (model-agnostic) -- Normalize text: capitalization restoration, punctuation cleanup, ASR filler removal before NER. -- Post-process: merge adjacent time phrases ("last night" + "20 minutes"), dedupe casing, drop single-token low-signal entities. -- Context scoring: down-rank profanity/generic nouns unless in quoted/explicit phrases. -- Speaker-aware aggregation: extract per speaker, then merge and rank by frequency and context windows. + --- + Current State: Existing Documentation Files -## Meeting entity schema (proposed) -- person -- org -- product -- app -- location -- time -- time_relative -- duration -- event -- task -- decision -- topic + 10 CLAUDE.md/AGENTS.md Files Already Present + ┌─────────────────────┬───────────┬──────────────────────────────────────────────────────────┐ + │ Location │ File │ Focus │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ / │ CLAUDE.md │ Root orchestration, parallel execution, project overview │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ / │ AGENTS.md │ Architecture for non-Claude AI assistants │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /src/ │ CLAUDE.md │ Python backend entry point │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /src/ │ AGENTS.md │ Python backend for other AIs │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /src/noteflow/ │ CLAUDE.md │ Detailed Python standards (line limits, typing, modules) │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /src/noteflow/grpc/ │ CLAUDE.md │ gRPC security patterns │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /src/noteflow/grpc/ │ AGENTS.md │ gRPC security (duplicate) │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /client/ │ CLAUDE.md │ Tauri + React development │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /client/src/ │ CLAUDE.md │ TypeScript security rules │ + ├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤ + │ /docker/ │ CLAUDE.md │ Docker security and build patterns │ + └─────────────────────┴───────────┴──────────────────────────────────────────────────────────┘ + --- + Part 1: Strategic Placement (Constrained Resources) -## spaCy label mapping (to meeting schema) -- PERSON -> person -- ORG -> org -- GPE/LOC/FAC -> location -- PRODUCT -> product -- EVENT -> event -- DATE/TIME -> time -- CARDINAL/QUANTITY -> duration (only if units present: min/hour/day/week/month) -- WORK_OF_ART/LAW -> topic (fallback) -- NORP -> topic (fallback) + If limited to 5-7 additional files, prioritize these high-impact locations: -## GLiNER label set (meeting tuned) -- person, org, product, app, location, time, time_relative, duration, event, task, decision, topic + Tier 
1: Critical Gaps (Add These First) -## Eval content (meeting-like snippets) -1) "We should ping Pixel about the shipment, maybe next week." - - gold: Pixel=product, next week=time_relative -2) "Let’s meet at Navi around 3ish tomorrow." - - gold: Navi=location, 3ish=time, tomorrow=time_relative -3) "Paris is on the agenda, but it’s just a placeholder." - - gold: Paris=location, agenda=topic -4) "The demo for Figma went well last night." - - gold: demo=event, Figma=product, last night=time_relative -5) "I spent 20 minutes fixing the auth issue." - - gold: 20 minutes=duration, auth issue=topic -6) "Can you file a Jira for the login bug?" - - gold: Jira=product, login bug=task -7) "Sam and Priya agreed to ship v2 by Friday." - - gold: Sam=person, Priya=person, Friday=time -8) "We decided to drop the old onboarding flow." - - gold: decision=decision, onboarding flow=topic -9) "Let’s sync after the standup." - - gold: sync=event, standup=event -10) "I heard Maya mention Notion for notes." - - gold: Maya=person, Notion=product, notes=topic -11) "Zoom audio was bad during the call." - - gold: Zoom=product, call=event -12) "We should ask Alex about the budget." - - gold: Alex=person, budget=topic -13) "Let’s revisit this in two weeks." - - gold: two weeks=duration -14) "The patch for Safari is done." - - gold: Safari=product, patch=task -15) "We’ll meet at the cafe near Union Station." - - gold: cafe=location, Union Station=location + 1. /src/noteflow/infrastructure/CLAUDE.md -## Data-backed comparison (local eval) -- Setup: spaCy `en_core_web_sm` with mapping rules above vs GLiNER `urchade/gliner_medium-v2.1` with label set above. -- Dataset: 15 meeting-like snippets (see eval content section). -- Results: - - spaCy: precision 0.34, recall 0.19 (tp=7, fp=10, fn=26) - - GLiNER: precision 0.36, recall 0.36 (tp=11, fp=17, fn=22) + Why: Infrastructure layer has 15+ adapters with distinct patterns (ASR, diarization, NER, summarization, calendar, webhooks, persistence). No unified guidance exists. -## Recommendation based on results -- GLiNER outperformed spaCy on recall (0.36 vs 0.19) while maintaining slightly higher precision, which is more aligned with meeting transcripts where relevance is tied to custom labels and informal mentions. -- Use GLiNER as the primary extractor, then apply a post-processing filter for low-signal entities and a normalization layer for time/duration to improve precision. + 2. /src/noteflow/domain/CLAUDE.md -## Implementation plan: GLiNER + post-processing with swap-friendly architecture (detailed + references) + Why: Domain layer defines entities, ports, rules, and value objects. Understanding DDD boundaries prevents architectural violations. -### Overview -Goal: Add GLiNER while keeping a clean boundary so switching back to spaCy or another backend is a config-only change. This uses three layers: -1) Backend (GLiNER/spaCy) returns `RawEntity` only. -2) Post-processor normalizes and filters `RawEntity`. -3) Mapper converts `RawEntity` -> domain `NamedEntity`. + 3. /src/noteflow/application/services/CLAUDE.md -### Step 0: Identify the existing entry points -- Current NER engine: `src/noteflow/infrastructure/ner/engine.py`. -- Domain entity type: `src/noteflow/domain/entities/named_entity.py` (for `EntityCategory`, `NamedEntity`). -- Service creation: `src/noteflow/grpc/startup/services.py` (selects NER engine). -- Feature flag: `src/noteflow/config/settings/_features.py` (`ner_enabled`). + Why: 12+ services with distinct responsibilities. 
Service-level guidance prevents duplication and clarifies orchestration patterns. -### Step 1: Add backend-agnostic models -Create new file: `src/noteflow/infrastructure/ner/backends/types.py` -- Add `RawEntity` dataclass: - - `text: str`, `label: str`, `start: int`, `end: int`, `confidence: float | None` - - `label` must be lowercase (e.g., `person`, `time_relative`). -- Add `NerBackend` protocol: - - `def extract(self, text: str) -> list[RawEntity]: ...` + 4. /client/src/hooks/CLAUDE.md -### Step 2: Implement GLiNER backend -Create `src/noteflow/infrastructure/ner/backends/gliner_backend.py` -- Lazy-load model to avoid startup delay (pattern copied from `src/noteflow/infrastructure/ner/engine.py`). -- Import: `from gliner import GLiNER`. -- Model: `GLiNER.from_pretrained("urchade/gliner_medium-v2.1")`. -- Labels come from settings; if missing, default to the meeting label list in this doc. -- Convert GLiNER entity dicts into `RawEntity`: - - `text`: `entity["text"]` - - `label`: `entity["label"].lower()` - - `start`: `entity["start"]` - - `end`: `entity["end"]` - - `confidence`: `entity.get("score")` -- Keep thresholds configurable (`NOTEFLOW_NER_GLINER_THRESHOLD`), default `0.5`. + Why: 7 hook directories (audio, auth, data, processing, recording, sync, ui) with complex interdependencies. Prevents reinventing existing hooks. -### Step 3: Move spaCy logic into a backend -Create `src/noteflow/infrastructure/ner/backends/spacy_backend.py` -- Move current spaCy loading logic from `src/noteflow/infrastructure/ner/engine.py`. -- Keep the model constants in `src/noteflow/config/constants.py`. -- Use the existing spaCy mapping and skip rules; move them into this backend or a shared module. -- Convert `Doc.ents` into `RawEntity` with lowercase labels. + 5. /client/src-tauri/src/CLAUDE.md -### Step 4: Create shared post-processing pipeline -Create `src/noteflow/infrastructure/ner/post_processing.py` (pure functions only, no class): -- `normalize_text(text: str) -> str` - - `lower()`, collapse whitespace, trim. -- `dedupe_entities(entities: list[RawEntity]) -> list[RawEntity]` - - key: normalized text; keep highest confidence if both present. -- `drop_low_signal_entities(entities: list[RawEntity], text: str) -> list[RawEntity]` - - drop profanity, short tokens (`len <= 2`), numeric-only entities. - - list should live in `src/noteflow/domain/constants/` or `src/noteflow/config/constants.py`. -- `merge_time_phrases(entities: list[RawEntity], text: str) -> list[RawEntity]` - - If two time-like entities are adjacent in the original text, merge into one span. -- `infer_duration(entities: list[RawEntity]) -> list[RawEntity]` - - If entity contains duration units (`minute`, `hour`, `day`, `week`, `month`), set label to `duration`. + Why: Rust backend has commands, gRPC client, audio processing, state management. No Rust-specific guidance currently exists. -### Step 5: Mapping to domain entities -Create `src/noteflow/infrastructure/ner/mapper.py` -- `map_raw_to_named(entities: list[RawEntity]) -> list[NamedEntity]`. -- Convert labels to `EntityCategory` from `src/noteflow/domain/entities/named_entity.py`. -- Confidence rules: - - If `RawEntity.confidence` exists, use it. - - Else use default 0.8 (same as current spaCy fallback in `engine.py`). + Tier 2: High Value (Add Next) -### Step 6: Update NerEngine to use composition -Edit `src/noteflow/infrastructure/ner/engine.py`: -- Remove direct spaCy dependency from `NerEngine`. 
-- Add constructor params: - - `backend: NerBackend`, `post_processor: Callable`, `mapper: NerMapper`. -- Flow in `extract`: - 1) `raw = backend.extract(text)` - 2) `raw = post_process(raw, text)` - 3) `entities = mapper.map_raw_to_named(raw)` - 4) Deduplicate by normalized text (use current logic, or move to post-processing). + 6. /tests/CLAUDE.md -### Step 7: Configurable backend selection -Edit `src/noteflow/grpc/startup/services.py`: -- Add config: `NOTEFLOW_NER_BACKEND` (default `spacy`). -- If value is `gliner`, instantiate `GLiNERBackend`; else `SpacyBackend`. -- Continue to respect `get_feature_flags().ner_enabled` in `create_ner_service`. + Why: Testing conventions (fixtures, markers, quality gates) are scattered. Centralized guidance improves test quality. -### Step 8: Tests (behavior + user flow + response quality) -Create the following tests and use global fixtures from `tests/conftest.py` (do not redefine them): -- Use `mock_uow`, `sample_meeting`, `meeting_id`, `sample_datetime`, `approx_sequence` where helpful. -- Avoid loops/conditionals in tests; use `@pytest.mark.parametrize`. + 7. /src/noteflow/infrastructure/persistence/CLAUDE.md -**Unit tests** -- `tests/ner/test_post_processing.py` - - Validate normalization, dedupe, time merge, duration inference, and profanity drop. - - Use small, deterministic inputs and assert exact `RawEntity` outputs. -- `tests/ner/test_mapper.py` - - Map every label to `EntityCategory` explicitly. - - Assert confidence fallback when `RawEntity.confidence=None`. + Why: UnitOfWork pattern, repository hierarchy, capability flags, migrations are complex. Prevents incorrect persistence patterns. -**Behavior tests** -- `tests/ner/test_engine_behavior.py` - - Use a stub backend to simulate GLiNER outputs with overlapping spans. - - Verify `NerEngine.extract` returns stable, deduped, correctly categorized entities. + --- + Part 2: Unlimited Placement (Comprehensive Coverage) -**User-flow tests** -- `tests/ner/test_service_flow.py` - - Use `mock_uow` and `sample_meeting` to simulate service flow that calls NER after a meeting completes. - - Assert entities are persisted with expected categories and segment IDs. + With no constraints, here's the complete list of 25+ locations where CLAUDE.md would add value: -**Response quality tests** -- `tests/ner/test_quality_contract.py` - - Feed short, meeting-like snippets and assert: - - No profanity entities survive. - - Time phrases are merged. - - Product vs person is not mis-typed for known examples (Pixel, Notion, Zoom). 
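To make the parametrization concrete, here is a minimal sketch of the quality-contract test, assuming the `NerEngine` composition from Step 6 is wired up behind an `ner_engine` fixture (the fixture name and the `category.value` accessor are illustrative assumptions, not existing code):

```python
# tests/ner/test_quality_contract.py -- illustrative sketch only.
# Assumes: an `ner_engine` fixture composing backend + post_process + mapper,
# and that NamedEntity exposes `text` plus an EntityCategory enum via `category`.
import pytest


@pytest.mark.parametrize(
    ("text", "expected"),
    [
        (
            "We should ping Pixel about the shipment, maybe next week.",
            {("pixel", "product"), ("next week", "time_relative")},
        ),
        (
            "I spent 20 minutes fixing the auth issue.",
            {("20 minutes", "duration"), ("auth issue", "topic")},
        ),
    ],
)
def test_quality_contract(ner_engine, text: str, expected: set[tuple[str, str]]) -> None:
    entities = ner_engine.extract(text)
    got = {(e.text.lower(), e.category.value) for e in entities}
    # Exact-set equality asserts both labels and counts, so hidden extras fail.
    assert got == expected


@pytest.mark.parametrize("text", ["This is shit, totally broken."])
def test_profanity_never_survives(ner_engine, text: str) -> None:
    assert all("shit" not in e.text.lower() for e in ner_engine.extract(text))
```

This mirrors the assertion guidance further down: exact normalized text + label pairs, exact counts, and an explicit profanity check.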
+ Python Backend (src/noteflow/) + ┌────────────────────────────────────────┬─────────────────────────────────────────────────────────┐ + │ Path │ Content Focus │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ domain/CLAUDE.md │ DDD entities, ports, value objects, rules engine │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ domain/entities/CLAUDE.md │ Entity relationships, state machines, invariants │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ domain/ports/CLAUDE.md │ Repository protocols, capability contracts │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ domain/rules/CLAUDE.md │ Rule modes (SIMPLE→EXPRESSION), registry, evaluation │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ application/CLAUDE.md │ Use case organization, service boundaries │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ application/services/CLAUDE.md │ Service catalog, dependency patterns │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/CLAUDE.md │ Adapter patterns, external integrations │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/asr/CLAUDE.md │ Whisper, VAD, segmentation, streaming │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/diarization/CLAUDE.md │ Job lifecycle, streaming vs offline, speaker assignment │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/ner/CLAUDE.md │ Backend abstraction, mapper, post-processing │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/summarization/CLAUDE.md │ Provider protocols, consent workflow, citation linking │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/persistence/CLAUDE.md │ UnitOfWork, repositories, migrations │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/calendar/CLAUDE.md │ OAuth flow, sync patterns, trigger detection │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ infrastructure/webhooks/CLAUDE.md │ Delivery, signing, retry logic │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ grpc/mixins/CLAUDE.md │ Mixin composition, streaming handlers │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ grpc/startup/CLAUDE.md │ Service initialization, dependency injection │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ config/CLAUDE.md │ Settings cascade, feature flags, environment loading │ + ├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ cli/CLAUDE.md │ Command patterns, model management │ + └────────────────────────────────────────┴─────────────────────────────────────────────────────────┘ + Client 
(client/) + ┌──────────────────────────────────┬─────────────────────────────────────────────────────────┐ + │ Path │ Content Focus │ + ├──────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ src/api/CLAUDE.md │ Adapter pattern, transport abstraction, type generation │ + ├──────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ src/components/CLAUDE.md │ Component hierarchy, feature organization │ + ├──────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ src/hooks/CLAUDE.md │ Hook catalog, composition patterns, state management │ + ├──────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ src/lib/CLAUDE.md │ Utility catalog, AI providers, audio processing │ + ├──────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ src-tauri/src/CLAUDE.md │ Rust patterns, command handlers, state │ + ├──────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ src-tauri/src/commands/CLAUDE.md │ IPC contract, audio commands, recording session │ + ├──────────────────────────────────┼─────────────────────────────────────────────────────────┤ + │ src-tauri/src/grpc/CLAUDE.md │ gRPC client wrapper, type conversions │ + └──────────────────────────────────┴─────────────────────────────────────────────────────────┘ + Testing (tests/) + ┌─────────────────────────────┬────────────────────────────────────────────────────┐ + │ Path │ Content Focus │ + ├─────────────────────────────┼────────────────────────────────────────────────────┤ + │ tests/CLAUDE.md │ Test conventions, fixtures, markers, quality gates │ + ├─────────────────────────────┼────────────────────────────────────────────────────┤ + │ tests/fixtures/CLAUDE.md │ Shared fixtures catalog, usage patterns │ + ├─────────────────────────────┼────────────────────────────────────────────────────┤ + │ tests/integration/CLAUDE.md │ Integration test setup, testcontainers │ + └─────────────────────────────┴────────────────────────────────────────────────────┘ + Documentation (docs/) + ┌────────────────────────┬───────────────────────────────────────────┐ + │ Path │ Content Focus │ + ├────────────────────────┼───────────────────────────────────────────┤ + │ docs/sprints/CLAUDE.md │ Sprint structure, documentation standards │ + └────────────────────────┴───────────────────────────────────────────┘ + --- + Part 3: Mockup - /src/noteflow/infrastructure/CLAUDE.md -### Test matrix (inputs -> expected outputs) -Use these exact cases in `tests/ner/test_quality_contract.py` with `@pytest.mark.parametrize`: -1) Text: "We should ping Pixel about the shipment, maybe next week." - - Expect: `Pixel=product`, `next week=time_relative` -2) Text: "Let’s meet at Navi around 3ish tomorrow." - - Expect: `Navi=location`, `3ish=time`, `tomorrow=time_relative` -3) Text: "The demo for Figma went well last night." - - Expect: `demo=event`, `Figma=product`, `last night=time_relative` -4) Text: "I spent 20 minutes fixing the auth issue." - - Expect: `20 minutes=duration`, `auth issue=topic` -5) Text: "Can you file a Jira for the login bug?" - - Expect: `Jira=product`, `login bug=task` -6) Text: "Zoom audio was bad during the call." - - Expect: `Zoom=product`, `call=event` -7) Text: "We should ask Alex about the budget." - - Expect: `Alex=person`, `budget=topic` -8) Text: "Let’s revisit this in two weeks." 
- - Expect: `two weeks=duration` -9) Text: "We decided to drop the old onboarding flow." - - Expect: `decision=decision`, `onboarding flow=topic` -10) Text: "This is shit, totally broken." - - Expect: no entity with text "shit" -11) Text: "Last night we met in Paris." - - Expect: `last night=time_relative`, `Paris=location` -12) Text: "We can sync after the standup next week." - - Expect: `sync=event`, `standup=event`, `next week=time_relative` -13) Text: "PIXEL and pixel are the same product." - - Expect: one `pixel=product` after dedupe -14) Text: "Met with Navi, NAVI, and navi." - - Expect: one `navi=location` after dedupe -15) Text: "It happened 2 days ago." - - Expect: `2 days=duration` or `2 days=time_relative` (pick one, but be consistent) -16) Text: "We should ask 'Sam' and Sam about the release." - - Expect: one `sam=person` -17) Text: "Notion and Figma docs are ready." - - Expect: `Notion=product`, `Figma=product` -18) Text: "Meet at Union Station, then Union station again." - - Expect: one `union station=location` -19) Text: "We need a patch for Safari and safari iOS." - - Expect: one `safari=product` and `patch=task` -20) Text: "App 'Navi' crashed during the call." - - Expect: `Navi=product` (overrides location when app context exists) + # Infrastructure Layer Development Guide -### Test assertions guidance -- Always assert exact normalized text + label pairs. -- Assert counts (expected length) to avoid hidden extras. -- For time merge: verify only one entity span for "last night" or "next week". -- For profanity: assert it does not appear in any entity text. -- For casing/duplicates: assert one normalized entity only. -- For override rules: assert app/product context wins over location if explicit. + ## Overview -### Step 9: Strict exit criteria -Implementation is only done if ALL are true: -- `make quality` passes. -- No compatibility wrappers or adapter shims added. -- No legacy code paths left in place (spaCy backend must be an actual backend, not dead code). -- No temp artifacts or experimental files left behind. + The infrastructure layer (`src/noteflow/infrastructure/`) contains adapters that implement domain ports. These connect the application to external systems: databases, ML + models, cloud APIs, file systems. -### Example backend usage (pseudo) -``` -backend = GLiNERBackend(labels=settings.ner_labels, threshold=settings.ner_threshold) -raw = backend.extract(text) -raw = post_process(raw, text) -entities = mapper.map_raw_to_named(raw) -``` + --- -### Outcome -- Switching NER backends is now a 1-line config change. -- Post-processing is shared and consistent across backends. -- Domain entities remain unchanged for callers. + ## Architecture Principle: Hexagonal/Ports-and-Adapters -## Simple eval scoring checklist -- Precision per label -- Recall per label -- Actionable rate (task/decision/event/topic) -- Type confusion counts (person<->product, org<->product, location<->org, time<->duration) + Domain Ports (interfaces) Infrastructure Adapters (implementations) + ───────────────────────── ─────────────────────────────────────────── + NerPort → SpacyBackend, GlinerBackend + SummarizationProvider → CloudProvider, OllamaProvider, MockProvider + DiarizationEngine → DiartSession, PyannoteOffline + AssetRepository → FileSystemAssetRepository + UnitOfWork → SqlAlchemyUnitOfWork, MemoryUnitOfWork + CalendarProvider → GoogleCalendar, OutlookCalendar -## Next steps to validate -- Run GLiNER on sample transcript and compare entity yield vs spaCy pipeline. 
-- Apply the spaCy label mapping above to ensure apples-to-apples scoring. -- Add a lightweight rule layer for time_relative/duration normalization. + **Rule**: Infrastructure code imports domain; domain NEVER imports infrastructure. + + --- + + ## Adapter Catalog + + | Directory | Responsibility | Key Protocols | + |-----------|----------------|---------------| + | `asr/` | Speech-to-text (Whisper) | `TranscriptionResult` | + | `diarization/` | Speaker identification | `DiarizationEngine`, `DiarizationJob` | + | `ner/` | Named entity extraction | `NerPort` | + | `summarization/` | LLM summarization | `SummarizationProvider` | + | `persistence/` | Database (SQLAlchemy) | `UnitOfWork`, `*Repository` | + | `calendar/` | OAuth + event sync | `CalendarProvider` | + | `webhooks/` | Event delivery | `WebhookDeliveryService` | + | `export/` | PDF/HTML/Markdown | `ExportAdapter` | + | `audio/` | Recording/playback | `AudioDevice` | + | `crypto/` | Encryption | `Keystore` | + | `logging/` | Structured logging | `LogEventType` | + | `metrics/` | Observability | `MetricsCollector` | + | `gpu/` | GPU detection | `GpuInfo` | + + --- + + ## Common Patterns + + ### 1. Async Wrappers for Sync Libraries + + Many ML libraries (spaCy, faster-whisper) are synchronous. Wrap them: + + ```python + async def extract(self, text: str) -> list[NamedEntity]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, # Default ThreadPoolExecutor + self._sync_extract, + text + ) + + 2. Backend Selection via Factory + + def create_ner_engine(config: NerConfig) -> NerPort: + match config.backend: + case "spacy": + return SpacyBackend(model=config.model_name) + case "gliner": + return GlinerBackend(model=config.model_name) + case _: + raise ValueError(f"Unknown NER backend: {config.backend}") + + 3. Capability Flags for Optional Features + + class SqlAlchemyUnitOfWork(UnitOfWork): + @property + def supports_entities(self) -> bool: + return True # Has EntityRepository + + @property + def supports_webhooks(self) -> bool: + return True # Has WebhookRepository + + Always check capability before accessing optional repository: + if uow.supports_entities: + entities = await uow.entities.get_by_meeting(meeting_id) + + 4. Provider Protocol Pattern + + class SummarizationProvider(Protocol): + async def summarize( + self, + segments: list[Segment], + template: SummarizationTemplate, + ) -> SummaryResult: ... + + @property + def requires_consent(self) -> bool: ... + + --- + Forbidden Patterns + + ❌ Direct database access outside persistence/ + # WRONG: Raw SQL in service layer + async with engine.connect() as conn: + result = await conn.execute(text("SELECT * FROM meetings")) + + ❌ Hardcoded API keys + # WRONG: Secrets in code + client = anthropic.Anthropic(api_key="sk-ant-...") + + ❌ Synchronous I/O in async context + # WRONG: Blocking the event loop + def load_model(self): + self.model = whisper.load_model("base") # Blocks! 
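For contrast, a non-blocking version of that load might look like the sketch below. This is a minimal illustration of the executor pattern from "Async Wrappers for Sync Libraries" above; the `whisper` import and model name mirror the wrong example and are not NoteFlow's actual ASR wiring:

```python
# Non-blocking counterpart to the example above (illustrative sketch).
import asyncio


class WhisperLoader:
    """Loads a Whisper model without blocking the event loop."""

    def __init__(self, model_name: str = "base") -> None:
        self._model_name = model_name
        self._model = None

    async def load_model(self) -> None:
        # Push the blocking load onto the default ThreadPoolExecutor.
        loop = asyncio.get_running_loop()
        self._model = await loop.run_in_executor(None, self._load_sync)

    def _load_sync(self):
        import whisper  # heavy import deferred until first load

        return whisper.load_model(self._model_name)
```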
+
+  ❌ Mutating domain state in infrastructure
+  # WRONG: Infrastructure should implement domain ports, not modify domain
+  from noteflow.domain.entities import Meeting
+  meeting.state = "COMPLETED"  # Don't mutate domain objects here
+
+  ---
+  Testing Infrastructure Adapters
+
+  Use Dependency Injection for Mocking
+
+  # tests/infrastructure/ner/test_engine.py
+  @pytest.fixture
+  def mock_backend() -> NerBackend:
+      backend = Mock(spec=NerBackend)
+      backend.extract.return_value = [
+          RawEntity(text="John", label="PERSON", start=0, end=4)
+      ]
+      return backend
+
+  async def test_engine_uses_backend(mock_backend):
+      engine = NerEngine(backend=mock_backend)
+      result = await engine.extract("Hello John")
+      mock_backend.extract.assert_called_once()
+
+  Integration Tests with Real Services
+
+  # tests/integration/test_ner_integration.py
+  @pytest.mark.integration
+  @pytest.mark.requires_gpu
+  async def test_gliner_real_extraction():
+      backend = GlinerBackend(model="urchade/gliner_base")
+      result = await backend.extract("Microsoft CEO Satya Nadella announced...")
+      assert any(e.label == "ORG" and "Microsoft" in e.text for e in result)
+
+  ---
+  Adding a New Adapter
+
+  1. Define port in domain (domain/ports/) if one does not already exist
+  2. Create adapter directory (infrastructure/<adapter-name>/)
+  3. Implement the protocol with proper async handling
+  4. Add factory function for backend selection
+  5. Write unit tests with mocked dependencies
+  6. Write integration test with real external service
+  7. Update gRPC startup (grpc/startup/services.py) for dependency injection
+  8. Document in this file (update Adapter Catalog table)
+
+  ---
+  Key Files
+  ┌───────────────────────────────┬──────────────────────────────────┐
+  │ File                          │ Purpose                          │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ __init__.py                   │ Public exports for layer         │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ */engine.py                   │ Main adapter implementation      │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ */backends/                   │ Multiple backend implementations │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ */mapper.py                   │ External→Domain type conversion  │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ */post_processing.py          │ Output normalization             │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ persistence/unit_of_work/*.py │ Transaction management           │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ persistence/repositories/*.py │ Data access                      │
+  ├───────────────────────────────┼──────────────────────────────────┤
+  │ persistence/models/*.py       │ ORM definitions                  │
+  └───────────────────────────────┴──────────────────────────────────┘
+  ---
+  See Also
+
+  - /src/noteflow/domain/ports/ — Port definitions
+  - /src/noteflow/grpc/startup/services.py — Dependency injection
+  - /tests/infrastructure/ — Adapter tests
+
+  ---
+
+  ## Part 4: Answer to "Would Your Answer Change With No Limit?"
+
+  **Yes, significantly.**
+
+  ### Constrained (5-7 files):
+  Focus on **layer boundaries** (domain, application, infrastructure) and **high-complexity areas** (hooks, Rust backend). Each file covers broad territory.
+ + ### Unlimited (25+ files): + Add **subsystem-specific documentation** for: + - Complex state machines (diarization jobs, recording lifecycle) + - Protocol patterns (summarization providers, NER backends) + - Cross-cutting concerns (rules engine, settings cascade) + - Test organization (fixtures, integration setup) + + The key difference: with unlimited resources, document **WHY decisions were made** (design rationale), not just **WHAT exists** (API reference). + + --- + + ## Recommendation + + ### Immediate Action (Phase 1) + Add these 3 files for maximum impact: + 1. `/src/noteflow/infrastructure/CLAUDE.md` — Adapter patterns (mockup above) + 2. `/src/noteflow/domain/CLAUDE.md` — DDD boundaries, entity relationships + 3. `/client/src-tauri/src/CLAUDE.md` — Rust patterns, IPC contracts + + ### Follow-up (Phase 2) + 4. `/src/noteflow/application/services/CLAUDE.md` — Service catalog + 5. `/client/src/hooks/CLAUDE.md` — Hook organization + 6. `/tests/CLAUDE.md` — Testing conventions + + ### Future (Phase 3) + Remaining 19+ files as the codebase grows and patterns stabilize. \ No newline at end of file diff --git a/.mcp.json b/.mcp.json index e0dcd23..83fb9dc 100644 --- a/.mcp.json +++ b/.mcp.json @@ -3,6 +3,18 @@ "lightrag-mcp": { "type": "sse", "url": "http://192.168.50.185:8150/sse" + }, + "firecrawl": { + "type": "stdio", + "command": "npx", + "args": [ + "-y", + "firecrawl-mcp" + ], + "env": { + "FIRECRAWL_API_URL": "http://crawl.toy", + "FIRECRAWL_API_KEY": "dummy-key" + } } } } \ No newline at end of file diff --git a/client/src-tauri/scripts/code_quality.sh b/client/src-tauri/scripts/code_quality.sh index eaa3257..129b280 100755 --- a/client/src-tauri/scripts/code_quality.sh +++ b/client/src-tauri/scripts/code_quality.sh @@ -245,12 +245,16 @@ echo "" # Check 6: Deep nesting (> 7 levels = 28 spaces) # Thresholds: 20=5 levels, 24=6 levels, 28=7 levels # 7 levels allows for async patterns: spawn + block_on + loop + select + match + if + body -# Only excludes: generated files, async bidirectional streaming (inherently deep) +# Excludes: generated files, streaming (inherently deep), tracing macro field formatting echo "Checking for deep nesting..." DEEP_NESTING=$(grep -rn --include="*.rs" $GREP_EXCLUDES -E '^[[:space:]]{28,}[^[:space:]]' "$TAURI_SRC" \ | grep -v '//' \ | grep -v 'noteflow.rs' \ | grep -v 'streaming.rs' \ + | grep -v 'tracing::' \ + | grep -v ' = %' \ + | grep -v ' = identity\.' \ + | grep -v '[[:space:]]"[A-Z]' \ | head -20 || true) if [ -n "$DEEP_NESTING" ]; then diff --git a/client/src-tauri/src/audio/windows_loopback.rs b/client/src-tauri/src/audio/windows_loopback.rs index 01c886e..8ceeb5d 100644 --- a/client/src-tauri/src/audio/windows_loopback.rs +++ b/client/src-tauri/src/audio/windows_loopback.rs @@ -47,6 +47,10 @@ pub struct WasapiLoopbackHandle { #[cfg(target_os = "windows")] impl WasapiLoopbackHandle { pub fn stop(mut self) { + self.stop_internal(); + } + + fn stop_internal(&mut self) { let _ = self.stop_tx.send(()); if let Some(join) = self.join.take() { let _ = join.join(); @@ -54,6 +58,13 @@ impl WasapiLoopbackHandle { } } +#[cfg(target_os = "windows")] +impl Drop for WasapiLoopbackHandle { + fn drop(&mut self) { + self.stop_internal(); + } +} + /// Start WASAPI loopback capture on a background thread. 
 #[cfg(target_os = "windows")]
 pub fn start_wasapi_loopback_capture(
@@ -177,7 +188,6 @@ where
     F: FnMut(&[f32]),
 {
     initialize_mta()
-        .ok()
         .map_err(|err| Error::AudioCapture(format!("WASAPI init failed: {err}")))?;
 
     let result: Result<()> = (|| {
diff --git a/client/src-tauri/src/events/mod.rs b/client/src-tauri/src/events/mod.rs
index da1d265..f685699 100644
--- a/client/src-tauri/src/events/mod.rs
+++ b/client/src-tauri/src/events/mod.rs
@@ -302,11 +302,12 @@ async fn run_event_loop(app: AppHandle, mut rx: broadcast::Receiver) {
 /// during Tauri's setup hook, before the main async runtime is fully initialized.
 ///
 /// Returns the `JoinHandle` so the caller can wait for graceful shutdown.
+/// Returns `None` if thread spawning fails (logged as error).
 pub fn start_event_emitter(
     app: AppHandle,
     rx: broadcast::Receiver,
-) -> std::thread::JoinHandle<()> {
-    std::thread::Builder::new()
+) -> Option<std::thread::JoinHandle<()>> {
+    match std::thread::Builder::new()
         .name("noteflow-event-emitter".to_string())
         .spawn(move || {
             let rt = match tokio::runtime::Builder::new_current_thread()
@@ -326,8 +327,17 @@ pub fn start_event_emitter(
             rt.block_on(run_event_loop(app, rx));
 
             tracing::debug!("Event emitter thread exiting");
-        })
-        .expect("Failed to spawn event emitter thread")
+        }) {
+        Ok(handle) => Some(handle),
+        Err(e) => {
+            tracing::error!(
+                error = %e,
+                subsystem = "event_emitter",
+                "Failed to spawn event emitter thread - frontend events disabled"
+            );
+            None
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/client/src-tauri/src/grpc/client/annotations.rs b/client/src-tauri/src/grpc/client/annotations.rs
index 01037f0..3f4f0ea 100644
--- a/client/src-tauri/src/grpc/client/annotations.rs
+++ b/client/src-tauri/src/grpc/client/annotations.rs
@@ -180,10 +180,10 @@ impl GrpcClient {
             .await?
             .into_inner();
 
-        Ok(response
+        response
             .entity
             .map(convert_entity)
-            .expect("UpdateEntity response should contain entity"))
+            .ok_or_else(|| crate::error::Error::Stream("UpdateEntity response missing entity".into()))
     }
 
     /// Delete a named entity.
diff --git a/client/src-tauri/src/identity/mod.rs b/client/src-tauri/src/identity/mod.rs index e64593b..9d14300 100644 --- a/client/src-tauri/src/identity/mod.rs +++ b/client/src-tauri/src/identity/mod.rs @@ -144,34 +144,44 @@ impl IdentityStore { } fn load_identity_from_keychain(&self) { - if let Ok(identity_json) = self.get_keychain_value(identity_config::IDENTITY_KEY) { - match serde_json::from_str::(&identity_json) { - Ok(identity) => { - tracing::info!( - user_id = %identity.user_id, - is_local = identity.is_local, - "Loaded identity from keychain" - ); - *self.identity.write() = Some(identity); - } - Err(e) => { - tracing::warn!(error = %e, "Failed to parse stored identity, using default"); + match self.get_keychain_value(identity_config::IDENTITY_KEY) { + Ok(identity_json) => { + match serde_json::from_str::(&identity_json) { + Ok(identity) => { + tracing::info!( + user_id = %identity.user_id, + is_local = identity.is_local, + "Loaded identity from keychain" + ); + *self.identity.write() = Some(identity); + } + Err(e) => { + tracing::warn!(error = %e, "Failed to parse stored identity, using default"); + } } } + Err(e) => { + tracing::debug!(error = %e, "No identity in keychain (first run or cleared)"); + } } } fn load_tokens_from_keychain(&self) { - if let Ok(token_json) = self.get_keychain_value(identity_config::AUTH_TOKEN_KEY) { - match serde_json::from_str::(&token_json) { - Ok(tokens) => { - tracing::debug!("Loaded auth tokens from keychain"); - *self.tokens.write() = Some(tokens); - } - Err(e) => { - tracing::warn!(error = %e, "Failed to parse stored tokens"); + match self.get_keychain_value(identity_config::AUTH_TOKEN_KEY) { + Ok(token_json) => { + match serde_json::from_str::(&token_json) { + Ok(tokens) => { + tracing::debug!("Loaded auth tokens from keychain"); + *self.tokens.write() = Some(tokens); + } + Err(e) => { + tracing::warn!(error = %e, "Failed to parse stored tokens"); + } } } + Err(e) => { + tracing::debug!(error = %e, "No auth tokens in keychain"); + } } } diff --git a/client/src-tauri/src/lib.rs b/client/src-tauri/src/lib.rs index 9d89515..688d726 100644 --- a/client/src-tauri/src/lib.rs +++ b/client/src-tauri/src/lib.rs @@ -297,11 +297,12 @@ fn setup_app_state( // Start event emitter and track the thread handle let app_handle = app.handle().clone(); - let event_emitter_handle = events::start_event_emitter(app_handle, event_tx.subscribe()); - - // Store event emitter handle in shutdown manager - if let Some(shutdown_mgr) = app.try_state::>() { - shutdown_mgr.set_event_emitter_handle(event_emitter_handle); + if let Some(event_emitter_handle) = + events::start_event_emitter(app_handle, event_tx.subscribe()) + { + if let Some(shutdown_mgr) = app.try_state::>() { + shutdown_mgr.set_event_emitter_handle(event_emitter_handle); + } } // Start trigger polling (foreground app + audio activity detection) diff --git a/client/src-tauri/src/state/app_state.rs b/client/src-tauri/src/state/app_state.rs index 535bb4b..d42e68c 100644 --- a/client/src-tauri/src/state/app_state.rs +++ b/client/src-tauri/src/state/app_state.rs @@ -7,8 +7,7 @@ use std::collections::VecDeque; use std::path::PathBuf; use std::sync::Arc; -use parking_lot::RwLock; -use tokio::sync::Mutex; +use parking_lot::{Mutex, RwLock}; use crate::audio::PlaybackHandle; use crate::config; @@ -436,7 +435,7 @@ impl AppState { pub fn get_trigger_status(&self) -> TriggerStatus { let (snoozed, snooze_remaining) = self .trigger_service - .blocking_lock() + .lock() .as_ref() .map(|service| (service.is_snoozed(), 
service.snooze_remaining_seconds())) .unwrap_or((false, 0.0)); diff --git a/client/src/components/features/calendar/upcoming-meetings.tsx b/client/src/components/features/calendar/upcoming-meetings.tsx index 6dd0354..31112dc 100644 --- a/client/src/components/features/calendar/upcoming-meetings.tsx +++ b/client/src/components/features/calendar/upcoming-meetings.tsx @@ -99,9 +99,10 @@ function CalendarErrorState({ onRetry, isRetrying }: { onRetry: () => void; isRe } export function UpcomingMeetings({ maxEvents = 10 }: UpcomingMeetingsProps) { - const integrations = preferences.getIntegrations(); - const calendarIntegrations = integrations.filter((i) => i.type === 'calendar'); - const connectedCalendars = calendarIntegrations.filter((i) => i.status === 'connected'); + const connectedCalendars = useMemo(() => { + const integrations = preferences.getIntegrations(); + return integrations.filter((i) => i.type === 'calendar' && i.status === 'connected'); + }, []); // Use live calendar API instead of mock data const { state, fetchEvents } = useCalendarSync({ diff --git a/client/src/hooks/auth/use-auth-flow.ts b/client/src/hooks/auth/use-auth-flow.ts index d739ed6..d7b28b1 100644 --- a/client/src/hooks/auth/use-auth-flow.ts +++ b/client/src/hooks/auth/use-auth-flow.ts @@ -6,6 +6,7 @@ import { getAPI } from '@/api/interface'; import { isTauriEnvironment } from '@/api'; import type { GetCurrentUserResponse } from '@/api/types'; import { toast } from '@/hooks/ui/use-toast'; +import { addClientLog } from '@/lib/observability/client'; import { toastError } from '@/lib/observability/errors'; import { extractOAuthCallback, @@ -141,14 +142,22 @@ export function useAuthFlow(): UseAuthFlowReturn { } }; - void setupDeepLinkListener(handleDeepLinkCallback).then((c) => { - if (unmounted) { - // Component unmounted before setup completed, call cleanup immediately - c(); - } else { - cleanup = c; - } - }); + setupDeepLinkListener(handleDeepLinkCallback) + .then((c) => { + if (unmounted) { + c(); + } else { + cleanup = c; + } + }) + .catch((err) => { + addClientLog({ + level: 'warning', + source: 'auth', + message: 'Failed to setup deep link listener', + details: err instanceof Error ? err.message : String(err), + }); + }); return () => { unmounted = true; @@ -254,7 +263,13 @@ export function useAuthFlow(): UseAuthFlowReturn { })); return userInfo; - } catch { + } catch (error) { + addClientLog({ + level: 'warning', + source: 'auth', + message: 'Failed to check auth status', + details: error instanceof Error ? 
error.message : String(error), + }); return null; } }, []); diff --git a/client/src/hooks/auth/use-oauth-flow.ts b/client/src/hooks/auth/use-oauth-flow.ts index 79a6e24..3269616 100644 --- a/client/src/hooks/auth/use-oauth-flow.ts +++ b/client/src/hooks/auth/use-oauth-flow.ts @@ -7,6 +7,7 @@ import { getAPI } from '@/api/interface'; import { isTauriEnvironment } from '@/api'; import type { OAuthConnection } from '@/api/types'; import { toast } from '@/hooks/ui/use-toast'; +import { addClientLog } from '@/lib/observability/client'; import { toastError } from '@/lib/observability/errors'; import { extractOAuthCallback, setupDeepLinkListener, validateOAuthState } from '@/lib/integrations/oauth'; @@ -132,14 +133,22 @@ export function useOAuthFlow(): UseOAuthFlowReturn { let unmounted = false; - void setupDeepLinkListener(handleDeepLinkCallback).then((c) => { - if (unmounted) { - // Component unmounted before setup completed, call cleanup immediately - c(); - } else { - cleanup = c; - } - }); + setupDeepLinkListener(handleDeepLinkCallback) + .then((c) => { + if (unmounted) { + c(); + } else { + cleanup = c; + } + }) + .catch((err) => { + addClientLog({ + level: 'warning', + source: 'oauth', + message: 'Failed to setup deep link listener', + details: err instanceof Error ? err.message : String(err), + }); + }); return () => { unmounted = true; diff --git a/client/src/hooks/processing/use-diarization.ts b/client/src/hooks/processing/use-diarization.ts index 09d411e..95ab911 100644 --- a/client/src/hooks/processing/use-diarization.ts +++ b/client/src/hooks/processing/use-diarization.ts @@ -10,6 +10,7 @@ import { getAPI } from '@/api/interface'; import type { DiarizationJobStatus, JobStatus } from '@/api/types'; import { toast } from '@/hooks/ui/use-toast'; import { PollingConfig } from '@/lib/config'; +import { addClientLog } from '@/lib/observability/client'; import { errorMessageFrom, toastError } from '@/lib/observability/errors'; /** Diarization job state */ @@ -448,8 +449,13 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat } return job; - } catch { - // Recovery failure is non-fatal - continue without recovery + } catch (error) { + addClientLog({ + level: 'warning', + source: 'app', + message: 'Diarization recovery failed (non-fatal)', + details: error instanceof Error ? 
error.message : String(error), + }); return null; } }, [poll, pollInterval, showToasts]); diff --git a/client/src/hooks/processing/use-entity-extraction.ts b/client/src/hooks/processing/use-entity-extraction.ts index 9eb1b7e..73e372d 100644 --- a/client/src/hooks/processing/use-entity-extraction.ts +++ b/client/src/hooks/processing/use-entity-extraction.ts @@ -69,6 +69,14 @@ export function useEntityExtraction( meetingTitleRef.current = meetingTitle; }, [meetingTitle]); + const mountedRef = useRef(true); + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + const extract = useCallback( async (forceRefresh = false) => { if (!meetingId) { @@ -84,7 +92,7 @@ export function useEntityExtraction( try { const response = await getAPI().extractEntities(meetingId, forceRefresh); - // Use ref to get current title without causing callback recreation + if (!mountedRef.current) return; setEntitiesFromExtraction(response.entities, meetingTitleRef.current); setState((prev) => ({ ...prev, @@ -100,6 +108,7 @@ export function useEntityExtraction( }); } } catch (error) { + if (!mountedRef.current) return; const message = toastError({ title: 'Extraction failed', error, diff --git a/client/src/hooks/recording/use-recording-session.ts b/client/src/hooks/recording/use-recording-session.ts index 480690d..0d03043 100644 --- a/client/src/hooks/recording/use-recording-session.ts +++ b/client/src/hooks/recording/use-recording-session.ts @@ -100,6 +100,14 @@ export function useRecordingSession( // Transcription stream const streamRef = useRef(null); + const mountedRef = useRef(true); + + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); // Toast helpers const toastSuccess = useCallback( @@ -264,6 +272,7 @@ export function useRecordingSession( title: meetingTitle || `Recording ${formatDateTime()}`, project_id: projectId, }); + if (!mountedRef.current) return; setMeeting(newMeeting); let stream: TranscriptionStream; @@ -275,6 +284,10 @@ export function useRecordingSession( } else { stream = ensureTranscriptionStream(await api.startTranscription(newMeeting.id)); } + if (!mountedRef.current) { + stream.close(); + return; + } streamRef.current = stream; stream.onUpdate(handleTranscriptUpdate); @@ -285,6 +298,7 @@ export function useRecordingSession( shouldSimulate ? 'Simulation is active' : 'Transcription is now active' ); } catch (error) { + if (!mountedRef.current) return; setRecordingState('idle'); toastError('Failed to start recording', error); } @@ -324,6 +338,7 @@ export function useRecordingSession( streamRef.current = null; const api = shouldSimulate && !isConnected ? mockAPI : getAPI(); const stoppedMeeting = await api.stopMeeting(meeting.id); + if (!mountedRef.current) return; setMeeting(stoppedMeeting); setRecordingState('idle'); toastSuccess( @@ -331,6 +346,7 @@ export function useRecordingSession( shouldSimulate ? 
'Simulation finished' : 'Your meeting has been saved' ); } catch (error) { + if (!mountedRef.current) return; setRecordingState('recording'); toastError('Failed to stop recording', error); } diff --git a/client/src/hooks/sync/use-calendar-sync.ts b/client/src/hooks/sync/use-calendar-sync.ts index bdfead3..e6dfbed 100644 --- a/client/src/hooks/sync/use-calendar-sync.ts +++ b/client/src/hooks/sync/use-calendar-sync.ts @@ -5,6 +5,7 @@ import { isIntegrationNotFoundError } from '@/api'; import { getAPI } from '@/api/interface'; import type { CalendarEvent, CalendarProvider } from '@/api/types'; import { toast } from '@/hooks/ui/use-toast'; +import { addClientLog } from '@/lib/observability/client'; import { errorMessageFrom, toastError } from '@/lib/observability/errors'; export type CalendarSyncStatus = 'idle' | 'loading' | 'success' | 'error'; @@ -84,8 +85,13 @@ export function useCalendarSync(options: UseCalendarSyncOptions = {}): UseCalend ...prev, providers: response.providers, })); - } catch { - // Provider fetch failed - non-critical, UI will show empty providers + } catch (error) { + addClientLog({ + level: 'debug', + source: 'sync', + message: 'Calendar provider fetch failed (non-critical)', + details: error instanceof Error ? error.message : String(error), + }); } }, []); diff --git a/client/src/lib/cache/meeting-cache.ts b/client/src/lib/cache/meeting-cache.ts index 511cd85..25f4343 100644 --- a/client/src/lib/cache/meeting-cache.ts +++ b/client/src/lib/cache/meeting-cache.ts @@ -397,6 +397,20 @@ export const meetingCache = { }; // Flush cache on page unload to prevent data loss -if (typeof window !== 'undefined') { - window.addEventListener('beforeunload', flushCache); +let beforeUnloadListenerAdded = false; + +function setupBeforeUnloadListener(): void { + if (typeof window !== 'undefined' && !beforeUnloadListenerAdded) { + window.addEventListener('beforeunload', flushCache); + beforeUnloadListenerAdded = true; + } } + +export function cleanupCacheBeforeUnloadListener(): void { + if (typeof window !== 'undefined' && beforeUnloadListenerAdded) { + window.removeEventListener('beforeunload', flushCache); + beforeUnloadListenerAdded = false; + } +} + +setupBeforeUnloadListener(); diff --git a/client/src/lib/observability/client.ts b/client/src/lib/observability/client.ts index 0176f62..7545c55 100644 --- a/client/src/lib/observability/client.ts +++ b/client/src/lib/observability/client.ts @@ -140,12 +140,24 @@ export function clearClientLogs(): void { } // Flush logs on page unload to prevent data loss -let beforeUnloadListener: (() => void) | null = null; -if (typeof window !== 'undefined') { - beforeUnloadListener = () => flushLogs(); - window.addEventListener('beforeunload', beforeUnloadListener); +let beforeUnloadListenerAdded = false; + +function setupBeforeUnloadListener(): void { + if (typeof window !== 'undefined' && !beforeUnloadListenerAdded) { + window.addEventListener('beforeunload', flushLogs); + beforeUnloadListenerAdded = true; + } } +export function cleanupBeforeUnloadListener(): void { + if (typeof window !== 'undefined' && beforeUnloadListenerAdded) { + window.removeEventListener('beforeunload', flushLogs); + beforeUnloadListenerAdded = false; + } +} + +setupBeforeUnloadListener(); + /** * Reset internal state for testing. Clears pending logs and cancels any scheduled writes. 
* @internal @@ -157,8 +169,5 @@ export function _resetClientLogsForTesting(): void { writeTimeout = null; } cachedLogs = null; - if (beforeUnloadListener !== null && typeof window !== 'undefined') { - window.removeEventListener('beforeunload', beforeUnloadListener); - beforeUnloadListener = null; - } + cleanupBeforeUnloadListener(); } diff --git a/client/src/pages/NotFound.tsx b/client/src/pages/NotFound.tsx index 6165590..13d2569 100644 --- a/client/src/pages/NotFound.tsx +++ b/client/src/pages/NotFound.tsx @@ -1,10 +1,4 @@ -import { useEffect } from 'react'; -import { useLocation } from 'react-router-dom'; - const NotFound = () => { - const _location = useLocation(); - - useEffect(() => {}, []); return (
diff --git a/docs/sprints/phase-ongoing/sprint-organization/api.md b/docs/sprints/.archive/sprint-organization/api.md similarity index 100% rename from docs/sprints/phase-ongoing/sprint-organization/api.md rename to docs/sprints/.archive/sprint-organization/api.md diff --git a/docs/sprints/phase-ongoing/sprint-organization/components.md b/docs/sprints/.archive/sprint-organization/components.md similarity index 100% rename from docs/sprints/phase-ongoing/sprint-organization/components.md rename to docs/sprints/.archive/sprint-organization/components.md diff --git a/docs/sprints/phase-ongoing/sprint-organization/hooks.md b/docs/sprints/.archive/sprint-organization/hooks.md similarity index 100% rename from docs/sprints/phase-ongoing/sprint-organization/hooks.md rename to docs/sprints/.archive/sprint-organization/hooks.md diff --git a/docs/sprints/phase-ongoing/sprint-organization/lib.md b/docs/sprints/.archive/sprint-organization/lib.md similarity index 100% rename from docs/sprints/phase-ongoing/sprint-organization/lib.md rename to docs/sprints/.archive/sprint-organization/lib.md diff --git a/pyproject.toml b/pyproject.toml index 37ad4df..b1f7f5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,9 @@ pdf = [ ner = [ "spacy>=3.8.11", ] +ner-gliner = [ + "gliner>=0.2.24", +] calendar = [ "google-api-python-client>=2.100", "google-auth>=2.23", @@ -109,8 +112,10 @@ optional = [ "torch>=2.0", # PDF export "weasyprint>=67.0", - # NER + # NER (spaCy backend) "spacy>=3.8.11", + # NER (GLiNER backend) + "gliner>=0.2.24", # Calendar "google-api-python-client>=2.100", "google-auth>=2.23", @@ -122,7 +127,7 @@ optional = [ "opentelemetry-exporter-otlp>=1.28", ] all = [ - "noteflow[audio,dev,triggers,summarization,diarization,pdf,ner,calendar,observability]", + "noteflow[audio,dev,triggers,summarization,diarization,pdf,ner,ner-gliner,calendar,observability]", ] ollama = [ "anthropic>=0.75.0", @@ -173,6 +178,7 @@ plugins = ["sqlalchemy.ext.mypy.plugin"] [[tool.mypy.overrides]] module = [ "diart.*", + "gliner.*", "pyannote.*", "faster_whisper.*", "sounddevice.*", @@ -294,6 +300,7 @@ filterwarnings = [ [dependency-groups] dev = [ "basedpyright>=1.36.1", + "gliner>=0.2.24", "grpc-stubs>=1.53.0.6", "protobuf>=6.33.2", "pyrefly>=0.46.1", @@ -301,6 +308,7 @@ dev = [ "pytest-httpx>=0.36.0", "ruff>=0.14.9", "sourcery; sys_platform == 'darwin'", + "spacy>=3.8.11", "types-grpcio==1.0.0.20251001", "watchfiles>=1.1.1", ] diff --git a/src/noteflow/application/services/asr_config/_job_manager.py b/src/noteflow/application/services/asr_config/_job_manager.py index 14d61f2..d4586f9 100644 --- a/src/noteflow/application/services/asr_config/_job_manager.py +++ b/src/noteflow/application/services/asr_config/_job_manager.py @@ -20,16 +20,28 @@ if TYPE_CHECKING: logger = get_logger(__name__) +DEFAULT_JOB_RETENTION_SECONDS: float = 300.0 # 5 minutes + + class AsrJobManager: """Manage ASR configuration job lifecycle and state. Handles job creation, status tracking, and cleanup during shutdown. + Jobs are automatically removed from the registry after completion/failure + to prevent memory leaks. """ - def __init__(self) -> None: - """Initialize job manager with empty job registry.""" + def __init__(self, job_retention_seconds: float = DEFAULT_JOB_RETENTION_SECONDS) -> None: + """Initialize job manager with empty job registry. + + Args: + job_retention_seconds: How long to keep completed/failed jobs before + removing them from memory. Defaults to 5 minutes. 
+ """ self._jobs: dict[UUID, AsrConfigJob] = {} self._job_lock = asyncio.Lock() + self._job_retention_seconds = job_retention_seconds + self._cleanup_tasks: set[asyncio.Task[None]] = set() async def register_job(self, job: AsrConfigJob) -> None: """Register a new job under the job lock. @@ -52,6 +64,17 @@ class AsrJobManager: coro: The coroutine to run. """ job.task = asyncio.create_task(coro) + job.task.add_done_callback(lambda t: self._job_task_done_callback(t, job)) + + def _job_task_done_callback(self, task: asyncio.Task[None], job: AsrConfigJob) -> None: + if task.cancelled(): + logger.debug("asr_job_task_cancelled", job_id=str(job.job_id)) + return + exc = task.exception() + if exc is not None and job.phase != AsrConfigPhase.FAILED: + logger.error("asr_job_task_unhandled_error", job_id=str(job.job_id), error=str(exc)) + wrapped_exc = exc if isinstance(exc, Exception) else RuntimeError(str(exc)) + self.mark_failed(job, wrapped_exc) def get_job(self, job_id: UUID) -> AsrConfigJob | None: """Get job status by ID. @@ -72,6 +95,7 @@ class AsrJobManager: job.phase = AsrConfigPhase.COMPLETED job.status = JOB_STATUS_COMPLETED job.progress_percent = 100.0 + self._schedule_job_cleanup(job.job_id) def mark_failed(self, job: AsrConfigJob, error: Exception) -> None: """Mark a job as failed with an error. @@ -84,6 +108,21 @@ class AsrJobManager: job.status = JOB_STATUS_FAILED job.error_message = str(error) logger.error("asr_reconfiguration_failed", error=str(error)) + self._schedule_job_cleanup(job.job_id) + + def _schedule_job_cleanup(self, job_id: UUID) -> None: + """Schedule deferred removal of a completed/failed job from memory.""" + task = asyncio.create_task(self._cleanup_job_after_delay(job_id)) + self._cleanup_tasks.add(task) + task.add_done_callback(self._cleanup_tasks.discard) + + async def _cleanup_job_after_delay(self, job_id: UUID) -> None: + """Remove a job from the registry after the retention period.""" + await asyncio.sleep(self._job_retention_seconds) + async with self._job_lock: + removed = self._jobs.pop(job_id, None) + if removed is not None: + logger.debug("job_cleanup_completed", job_id=str(job_id)) def _collect_pending_tasks(self) -> list[asyncio.Task[None]]: """Collect and cancel pending job tasks. 
diff --git a/src/noteflow/application/services/ner/service.py b/src/noteflow/application/services/ner/service.py index edd3b8f..33369e7 100644 --- a/src/noteflow/application/services/ner/service.py +++ b/src/noteflow/application/services/ner/service.py @@ -159,9 +159,9 @@ class NerService: self._model_helper = _ModelLifecycleHelper(ner_engine) self._extraction_helper = _ExtractionHelper(ner_engine, self._model_helper) - def is_ready(self) -> bool: - """Return True if the NER engine is loaded and ready.""" - return self._ner_engine.is_ready() + def is_ner_ready(self) -> bool: + engine = self._ner_engine + return engine.is_ready() async def extract_entities( self, diff --git a/src/noteflow/cli/models/_registry.py b/src/noteflow/cli/models/_registry.py index b7a2a8b..15e187e 100644 --- a/src/noteflow/cli/models/_registry.py +++ b/src/noteflow/cli/models/_registry.py @@ -2,27 +2,30 @@ from __future__ import annotations +from typing import Final + from noteflow.config.constants import SPACY_MODEL_LG, SPACY_MODEL_SM from ._types import ModelInfo -# Constants to avoid magic strings -_DEFAULT_MODEL = "spacy-en" -CMD_DOWNLOAD = "download" +_DEFAULT_MODEL: Final[str] = "spacy-en" +_SPACY_EN_LG: Final[str] = "spacy-en-lg" +CMD_DOWNLOAD: Final[str] = "download" +_SPACY_MODULE: Final[str] = "spacy" AVAILABLE_MODELS: dict[str, ModelInfo] = { _DEFAULT_MODEL: ModelInfo( name=_DEFAULT_MODEL, description=f"English NER model ({SPACY_MODEL_SM})", feature="ner", - install_command=["python", "-m", "spacy", CMD_DOWNLOAD, SPACY_MODEL_SM], + install_command=["python", "-m", _SPACY_MODULE, CMD_DOWNLOAD, SPACY_MODEL_SM], check_import=SPACY_MODEL_SM, ), - "spacy-en-lg": ModelInfo( - name="spacy-en-lg", + _SPACY_EN_LG: ModelInfo( + name=_SPACY_EN_LG, description=f"English NER model - large ({SPACY_MODEL_LG})", feature="ner", - install_command=["python", "-m", "spacy", CMD_DOWNLOAD, SPACY_MODEL_LG], + install_command=["python", "-m", _SPACY_MODULE, CMD_DOWNLOAD, SPACY_MODEL_LG], check_import=SPACY_MODEL_LG, ), } diff --git a/src/noteflow/config/constants/__init__.py b/src/noteflow/config/constants/__init__.py index ee11ef7..1712456 100644 --- a/src/noteflow/config/constants/__init__.py +++ b/src/noteflow/config/constants/__init__.py @@ -33,6 +33,8 @@ from noteflow.config.constants.domain import ( EXPORT_FORMAT_HTML, FEATURE_NAME_PROJECTS, MEETING_TITLE_PREFIX, + NER_BACKEND_GLINER, + NER_BACKEND_SPACY, PROVIDER_NAME_OPENAI, RULE_FIELD_APP_MATCH_PATTERNS, RULE_FIELD_AUTO_START_ENABLED, @@ -173,6 +175,8 @@ __all__ = [ "LOG_EVENT_WEBHOOK_UPDATE_FAILED", "MAX_GRPC_MESSAGE_SIZE", "MEETING_TITLE_PREFIX", + "NER_BACKEND_GLINER", + "NER_BACKEND_SPACY", "OAUTH_FIELD_ACCESS_TOKEN", "OAUTH_FIELD_EXPIRES_IN", "OAUTH_FIELD_REFRESH_TOKEN", diff --git a/src/noteflow/config/constants/domain.py b/src/noteflow/config/constants/domain.py index 92257c2..babad5f 100644 --- a/src/noteflow/config/constants/domain.py +++ b/src/noteflow/config/constants/domain.py @@ -44,6 +44,16 @@ SPACY_MODEL_LG: Final[str] = "en_core_web_lg" SPACY_MODEL_TRF: Final[str] = "en_core_web_trf" """Transformer-based English spaCy model for NER.""" +# ============================================================================= +# NER Backend Names +# ============================================================================= + +NER_BACKEND_SPACY: Final[str] = "spacy" +"""spaCy NER backend identifier.""" + +NER_BACKEND_GLINER: Final[str] = "gliner" +"""GLiNER NER backend identifier.""" + # 
============================================================================= # Provider Names # ============================================================================= diff --git a/src/noteflow/config/settings/_main.py b/src/noteflow/config/settings/_main.py index 8709d42..0689ecf 100644 --- a/src/noteflow/config/settings/_main.py +++ b/src/noteflow/config/settings/_main.py @@ -14,7 +14,7 @@ from noteflow.config.constants.core import ( DEFAULT_OLLAMA_TIMEOUT_SECONDS, HOURS_PER_DAY, ) -from noteflow.config.constants.domain import DEFAULT_ANTHROPIC_MODEL +from noteflow.config.constants.domain import DEFAULT_ANTHROPIC_MODEL, NER_BACKEND_SPACY from noteflow.config.settings._base import ENV_FILE, EXTRA_IGNORE from noteflow.config.settings._triggers import TriggerSettings @@ -184,7 +184,9 @@ class Settings(TriggerSettings): ] grpc_partial_cadence_seconds: Annotated[ float, - Field(default=2.0, ge=0.5, le=10.0, description="Interval for emitting partial transcripts"), + Field( + default=2.0, ge=0.5, le=10.0, description="Interval for emitting partial transcripts" + ), ] grpc_min_partial_audio_seconds: Annotated[ float, @@ -218,11 +220,18 @@ class Settings(TriggerSettings): ] webhook_backoff_base: Annotated[ float, - Field(default=2.0, ge=1.1, le=5.0, description="Exponential backoff multiplier for webhook retries"), + Field( + default=2.0, + ge=1.1, + le=5.0, + description="Exponential backoff multiplier for webhook retries", + ), ] webhook_max_response_length: Annotated[ int, - Field(default=5 * 100, ge=100, le=10 * 1000, description="Maximum response body length to log"), + Field( + default=5 * 100, ge=100, le=10 * 1000, description="Maximum response body length to log" + ), ] # LLM/Summarization settings @@ -241,7 +250,9 @@ class Settings(TriggerSettings): ] llm_default_anthropic_model: Annotated[ str, - Field(default=DEFAULT_ANTHROPIC_MODEL, description="Default Anthropic model for summarization"), + Field( + default=DEFAULT_ANTHROPIC_MODEL, description="Default Anthropic model for summarization" + ), ] llm_timeout_seconds: Annotated[ float, @@ -298,6 +309,31 @@ class Settings(TriggerSettings): ), ] + # NER settings + ner_backend: Annotated[ + str, + Field( + default=NER_BACKEND_SPACY, + description="NER backend: spacy (default) or gliner", + ), + ] + ner_gliner_model: Annotated[ + str, + Field( + default="urchade/gliner_medium-v2.1", + description="GLiNER model name from HuggingFace Hub", + ), + ] + ner_gliner_threshold: Annotated[ + float, + Field( + default=0.5, + ge=0.0, + le=1.0, + description="GLiNER confidence threshold for entity extraction", + ), + ] + @property def database_url_str(self) -> str: """Return database URL as string.""" diff --git a/src/noteflow/domain/entities/named_entity.py b/src/noteflow/domain/entities/named_entity.py index 56b7578..f1fc54e 100644 --- a/src/noteflow/domain/entities/named_entity.py +++ b/src/noteflow/domain/entities/named_entity.py @@ -27,15 +27,29 @@ class EntityCategory(Enum): Note: TECHNICAL and ACRONYM are placeholders for future custom pattern matching (not currently mapped from spaCy's default NER model). 
+ + Meeting-specific categories (GLiNER): + TIME_RELATIVE -> Relative time references (next week, tomorrow) + DURATION -> Time durations (20 minutes, two weeks) + EVENT -> Events and activities (meeting, demo, standup) + TASK -> Action items and tasks (file a Jira, patch) + DECISION -> Decisions made (decided to drop) + TOPIC -> Discussion topics (budget, onboarding flow) """ PERSON = "person" COMPANY = "company" PRODUCT = "product" - TECHNICAL = "technical" # Future: custom pattern matching - ACRONYM = "acronym" # Future: custom pattern matching + TECHNICAL = "technical" + ACRONYM = "acronym" LOCATION = ENTITY_LOCATION DATE = ENTITY_DATE + TIME_RELATIVE = "time_relative" + DURATION = "duration" + EVENT = "event" + TASK = "task" + DECISION = "decision" + TOPIC = "topic" OTHER = "other" @classmethod diff --git a/src/noteflow/grpc/mixins/_model_status.py b/src/noteflow/grpc/mixins/_model_status.py index fa2a316..d36a31e 100644 --- a/src/noteflow/grpc/mixins/_model_status.py +++ b/src/noteflow/grpc/mixins/_model_status.py @@ -70,7 +70,7 @@ def _append_ner_status(payload: dict[str, object], host: _ModelStatusHost) -> No if ner is None: payload["ner_ready"] = False return - payload["ner_ready"] = ner.is_ready() + payload["ner_ready"] = ner.is_ner_ready() def _append_summarization_status( @@ -93,8 +93,7 @@ def _append_summarization_status( return summaries: list[str] = [ - _format_provider_summary(mode, provider) - for mode, provider in service.providers.items() + _format_provider_summary(mode, provider) for mode, provider in service.providers.items() ] payload["summarization_providers"] = summaries diff --git a/src/noteflow/grpc/mixins/_task_callbacks.py b/src/noteflow/grpc/mixins/_task_callbacks.py index 022d074..4ebdadf 100644 --- a/src/noteflow/grpc/mixins/_task_callbacks.py +++ b/src/noteflow/grpc/mixins/_task_callbacks.py @@ -9,7 +9,8 @@ Sprint GAP-003: Error Handling Mismatches from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine +from typing import Any from noteflow.infrastructure.logging import get_logger @@ -19,7 +20,7 @@ logger = get_logger(__name__) def create_job_done_callback( job_id: str, tasks_dict: dict[str, asyncio.Task[None]], - mark_failed: Callable[[str, str], Awaitable[None]], + mark_failed: Callable[[str, str], Coroutine[Any, Any, None]], ) -> Callable[[asyncio.Task[None]], None]: """Create task done callback that marks jobs failed on exception. @@ -54,7 +55,7 @@ def _handle_task_completion( task: asyncio.Task[None], job_id: str, tasks_dict: dict[str, asyncio.Task[None]], - mark_failed: Callable[[str, str], Awaitable[None]], + mark_failed: Callable[[str, str], Coroutine[Any, Any, None]], ) -> None: """Process task completion, logging failures and scheduling mark_failed. @@ -85,7 +86,7 @@ def _handle_task_completion( def _log_and_schedule_failure( job_id: str, exc: BaseException, - mark_failed: Callable[[str, str], Awaitable[None]], + mark_failed: Callable[[str, str], Coroutine[Any, Any, None]], ) -> None: """Log task exception and schedule the mark_failed coroutine. 
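
The `Awaitable` → `Coroutine` change above is more than cosmetic: `asyncio.create_task()` requires a coroutine object, so a callback typed as returning `Awaitable[None]` would admit values that `create_task()` rejects at runtime. A minimal sketch of the resulting scheduling pattern, with a hypothetical `notify_failed` standing in for the real `mark_failed`:

```python
import asyncio
from collections.abc import Callable, Coroutine
from typing import Any


async def notify_failed(job_id: str, reason: str) -> None:
    print(f"job {job_id} marked failed: {reason}")


def schedule_failure(
    job_id: str,
    reason: str,
    mark_failed: Callable[[str, str], Coroutine[Any, Any, None]],
) -> None:
    # Calling the factory produces a coroutine that create_task() accepts.
    task = asyncio.create_task(mark_failed(job_id, reason))

    def _report(done: asyncio.Task[None]) -> None:
        # Retrieve the exception so it is not dropped with a
        # "Task exception was never retrieved" warning at GC time.
        if not done.cancelled() and done.exception() is not None:
            print(f"mark_failed for job {job_id} raised: {done.exception()}")

    task.add_done_callback(_report)
```
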
@@ -100,15 +101,25 @@ def _log_and_schedule_failure( exc_info=exc, ) - # Schedule the mark_failed coroutine (fire-and-forget) - # Note: This requires an active event loop, which should exist - # since we're being called from asyncio task completion try: - asyncio.create_task(mark_failed(job_id, str(exc))) + coro = mark_failed(job_id, str(exc)) + task: asyncio.Task[None] = asyncio.create_task(coro) + task.add_done_callback(lambda t: _log_mark_failed_result(t, job_id)) except RuntimeError as schedule_err: - # If we can't schedule (e.g., loop closed), log but don't crash logger.error( "Failed to schedule mark_failed for job %s: %s", job_id, schedule_err, ) + + +def _log_mark_failed_result(task: asyncio.Task[None], job_id: str) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + logger.error( + "mark_failed task failed for job %s: %s", + job_id, + exc, + ) diff --git a/src/noteflow/grpc/mixins/streaming/_processing/_congestion.py b/src/noteflow/grpc/mixins/streaming/_processing/_congestion.py index 302cf91..5f94786 100644 --- a/src/noteflow/grpc/mixins/streaming/_processing/_congestion.py +++ b/src/noteflow/grpc/mixins/streaming/_processing/_congestion.py @@ -4,6 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING +from noteflow.infrastructure.logging import get_logger + from ....proto import noteflow_pb2 from ...converters import create_congestion_info from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD @@ -11,6 +13,8 @@ from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD if TYPE_CHECKING: from ...protocols import ServicerHost +logger = get_logger(__name__) + def calculate_congestion_info( host: ServicerHost, @@ -30,20 +34,17 @@ def calculate_congestion_info( processing_delay_ms = 0 if receipt_times := host.chunk_receipt_times.get(meeting_id): try: - # Access [0] can race with popleft() in decrement_pending_chunks oldest_receipt = receipt_times[0] processing_delay_ms = int((current_time - oldest_receipt) * 1000) except IndexError: - # Deque was emptied by concurrent popleft() - no pending chunks - pass + logger.debug("congestion_race_access: deque emptied by concurrent popleft") # Get queue depth (pending chunks not yet processed through ASR) queue_depth = host.pending_chunks.get(meeting_id, 0) # Determine if throttle is recommended throttle_recommended = ( - processing_delay_ms > PROCESSING_DELAY_THRESHOLD_MS - or queue_depth > QUEUE_DEPTH_THRESHOLD + processing_delay_ms > PROCESSING_DELAY_THRESHOLD_MS or queue_depth > QUEUE_DEPTH_THRESHOLD ) return create_congestion_info( @@ -68,5 +69,4 @@ def decrement_pending_chunks(host: ServicerHost, meeting_id: str) -> None: try: receipt_times.popleft() except IndexError: - # Deque already empty - can occur if multiple coroutines race - pass + logger.debug("congestion_race_popleft: deque already empty from concurrent access") diff --git a/src/noteflow/grpc/server/__init__.py b/src/noteflow/grpc/server/__init__.py index 2673156..d415dcf 100644 --- a/src/noteflow/grpc/server/__init__.py +++ b/src/noteflow/grpc/server/__init__.py @@ -154,9 +154,18 @@ class NoteFlowServer: self._servicer: NoteFlowServicer | None = None self._warmup_task: asyncio.Task[None] | None = None + def _warmup_task_done_callback(self, task: asyncio.Task[None]) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + logger.error("diarization_warmup_failed", error=str(exc)) + async def _apply_persisted_asr_config(self) -> None: settings = 
get_settings() - stored = await _load_asr_config_preference(self._state.session_factory, settings.meetings_dir) + stored = await _load_asr_config_preference( + self._state.session_factory, settings.meetings_dir + ) if stored is None: return @@ -207,6 +216,7 @@ class NoteFlowServer: self._warmup_task = asyncio.create_task( warm_diarization_engine(self._state.diarization_engine) ) + self._warmup_task.add_done_callback(self._warmup_task_done_callback) self._servicer = build_servicer(self._state, asr_engine, self._streaming_config) await recover_orphaned_jobs(self._state.session_factory) @@ -221,6 +231,14 @@ class NoteFlowServer: Args: grace_period: Time to wait for in-flight RPCs. """ + if self._warmup_task is not None and not self._warmup_task.done(): + self._warmup_task.cancel() + try: + await self._warmup_task + except asyncio.CancelledError: + logger.debug("Diarization warmup task cancelled") + self._warmup_task = None + await stop_server(self._server, self._servicer, self._db_engine, grace_period) if self._db_engine is not None: self._db_engine = None @@ -231,7 +249,6 @@ class NoteFlowServer: await self._server.wait_for_termination() - async def run_server_with_config(config: GrpcServerConfig) -> None: """Run the async gRPC server with structured configuration. diff --git a/src/noteflow/grpc/servicer/mixins.py b/src/noteflow/grpc/servicer/mixins.py index 44b985f..93d4646 100644 --- a/src/noteflow/grpc/servicer/mixins.py +++ b/src/noteflow/grpc/servicer/mixins.py @@ -124,6 +124,7 @@ class ServicerStreamingStateMixin: chunk_counts: dict[str, int] chunk_receipt_times: dict[str, deque[float]] pending_chunks: dict[str, int] + stop_requested: set[str] DEFAULT_SAMPLE_RATE: int def init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None: """Initialize VAD, Segmenter, speaking state, and partial buffers for a meeting.""" @@ -177,6 +178,7 @@ class ServicerStreamingStateMixin: self.chunk_sequences.pop(meeting_id, None) self.chunk_counts.pop(meeting_id, None) + self.stop_requested.discard(meeting_id) if hasattr(self, "_chunk_receipt_times"): self.chunk_receipt_times.pop(meeting_id, None) @@ -298,8 +300,7 @@ class ServicerInfoMixin: diarization_enabled = self.diarization_engine is not None diarization_ready = self.diarization_engine is not None and ( - self.diarization_engine.is_streaming_loaded - or self.diarization_engine.is_offline_loaded + self.diarization_engine.is_streaming_loaded or self.diarization_engine.is_offline_loaded ) if self.session_factory is not None: diff --git a/src/noteflow/grpc/startup/services.py b/src/noteflow/grpc/startup/services.py index c7e56dc..f40799d 100644 --- a/src/noteflow/grpc/startup/services.py +++ b/src/noteflow/grpc/startup/services.py @@ -10,12 +10,14 @@ from noteflow.application.services.calendar import CalendarService from noteflow.application.services.ner import NerService from noteflow.application.services.webhooks import WebhookService from noteflow.config.settings import Settings, get_calendar_settings, get_feature_flags +from noteflow.config.constants.domain import NER_BACKEND_GLINER from noteflow.domain.constants.fields import CALENDAR, DEVICE from noteflow.domain.entities.integration import IntegrationStatus from noteflow.domain.ports.gpu import GpuBackend from noteflow.infrastructure.diarization import DiarizationEngine from noteflow.infrastructure.logging import get_logger from noteflow.infrastructure.ner import NerEngine +from noteflow.infrastructure.ner.backends.types import NerBackend from 
noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork from noteflow.infrastructure.webhooks import WebhookExecutor @@ -38,11 +40,7 @@ async def check_calendar_needed_from_db(uow: SqlAlchemyUnitOfWork) -> bool: return False calendar_integrations = await uow.integrations.list_by_type(CALENDAR) - if connected := [ - i - for i in calendar_integrations - if i.status == IntegrationStatus.CONNECTED - ]: + if connected := [i for i in calendar_integrations if i.status == IntegrationStatus.CONNECTED]: logger.info( "Auto-enabling calendar: found %d connected OAuth integration(s)", len(connected), @@ -84,8 +82,9 @@ def create_ner_service( ) return None - logger.info("Initializing NER service (spaCy)...") - ner_engine = NerEngine() + backend = _create_ner_backend(settings) + logger.info("Initializing NER service", backend=settings.ner_backend) + ner_engine = NerEngine(backend=backend) ner_service = NerService( ner_engine=ner_engine, uow_factory=lambda: SqlAlchemyUnitOfWork(session_factory, settings.meetings_dir), @@ -94,6 +93,21 @@ def create_ner_service( return ner_service +def _create_ner_backend(settings: Settings) -> NerBackend: + """Create NER backend based on settings.""" + if settings.ner_backend == NER_BACKEND_GLINER: + from noteflow.infrastructure.ner.backends.gliner_backend import GLiNERBackend + + return GLiNERBackend( + model_name=settings.ner_gliner_model, + threshold=settings.ner_gliner_threshold, + ) + + from noteflow.infrastructure.ner.backends.spacy_backend import SpacyBackend + + return SpacyBackend() + + async def create_calendar_service( session_factory: async_sessionmaker[AsyncSession] | None, settings: Settings, diff --git a/src/noteflow/infrastructure/CLAUDE.md b/src/noteflow/infrastructure/CLAUDE.md new file mode 100644 index 0000000..fbf1ee3 --- /dev/null +++ b/src/noteflow/infrastructure/CLAUDE.md @@ -0,0 +1,352 @@ +# Infrastructure Layer Development Guide + +## Overview + +The infrastructure layer (`src/noteflow/infrastructure/`) contains adapters that implement domain ports. These connect the application to external systems: databases, ML models, cloud APIs, file systems. 
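+
+A minimal sketch of that split, using hypothetical names (`ClockPort` and
+`SystemClock` are illustrative, not adapters in this repo):
+
+```python
+# domain/ports/clock.py: the domain defines only the interface
+from datetime import datetime, timezone
+from typing import Protocol
+
+
+class ClockPort(Protocol):
+    def now(self) -> datetime: ...
+
+
+# infrastructure/clock.py: the adapter satisfies the Protocol structurally
+# and imports from domain, never the reverse
+class SystemClock:
+    def now(self) -> datetime:
+        return datetime.now(timezone.utc)
+```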
+ +**Architecture**: Hexagonal (Ports-and-Adapters) +- Domain Ports define interfaces in `domain/ports/` +- Infrastructure Adapters implement those interfaces here +- **Rule**: Infrastructure imports domain; domain NEVER imports infrastructure + +--- + +## Adapter Catalog + +| Directory | Responsibility | Port/Protocol | Key Dependencies | +|-----------|----------------|---------------|------------------| +| `asr/` | Speech-to-text (Whisper) | `asr/protocols.AsrEngine` | faster-whisper | +| `diarization/` | Speaker identification | internal protocols | pyannote, diart | +| `ner/` | Named entity extraction | `domain/ports/ner.NerPort` | spacy, gliner | +| `summarization/` | LLM summarization | `application/services/protocols` | anthropic, openai | +| `persistence/` | Database (SQLAlchemy) | `domain/ports/unit_of_work.UnitOfWork` | sqlalchemy, asyncpg | +| `calendar/` | OAuth + event sync | `domain/ports/calendar.CalendarPort` | httpx | +| `webhooks/` | Event delivery | — | httpx | +| `export/` | PDF/HTML/Markdown | `export/protocols.TranscriptExporter` | weasyprint, jinja2 | +| `audio/` | Recording/playback | `audio/protocols.AudioCapture` | sounddevice | +| `security/` | Encryption | `security/protocols.CryptoBox`, `KeyStore` | cryptography, keyring | +| `gpu/` | GPU detection | — | torch | +| `triggers/` | Signal providers | internal protocols | platform-specific | +| `auth/` | OIDC management | — | httpx | +| `converters/` | Layer bridging | — | — | +| `logging/` | Structured logging | — | structlog | +| `observability/` | OpenTelemetry & usage tracking | — | opentelemetry | +| `metrics/` | Metrics collection | — | prometheus_client | + +**Note**: Some protocols are domain-level ports (`domain/ports/`), others are infrastructure-level protocols defined within the subsystem. + +--- + +## Common Implementation Patterns + +### 1. Async Wrappers for Sync Libraries + +Many ML libraries (spaCy, faster-whisper, pyannote) are synchronous. Always wrap: + +```python +async def extract_async(self, text: str) -> list[NamedEntity]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, # Default ThreadPoolExecutor + self._sync_extract, + text + ) +``` + +### 2. Backend Selection via Factory + +```python +def create_ner_engine(config: NerConfig) -> NerPort: + match config.backend: + case "spacy": + return SpacyBackend(model=config.model_name) + case "gliner": + return GlinerBackend(model=config.model_name) + case _: + raise ValueError(f"Unknown NER backend: {config.backend}") +``` + +### 3. Protocol-Based Backends + +```python +# backends/types.py +class NerBackend(Protocol): + def extract(self, text: str) -> list[RawEntity]: ... + def load_model(self) -> None: ... + def unload(self) -> None: ... + @property + def model_loaded(self) -> bool: ... +``` + +### 4. Capability Flags for Optional Features + +```python +class SqlAlchemyUnitOfWork(UnitOfWork): + @property + def supports_entities(self) -> bool: + return True # Has EntityRepository + + @property + def supports_webhooks(self) -> bool: + return True # Has WebhookRepository +``` + +**Always check capability before accessing optional repository**: +```python +if uow.supports_entities: + entities = await uow.entities.get_by_meeting(meeting_id) +``` + +### 5. 
Mixin Composition for Complex Engines + +```python +class DiarizationEngine( + DeviceMixin, # GPU/CPU detection + StreamingMixin, # Real-time diarization (diart) + OfflineMixin, # Batch diarization (pyannote) + ProcessingMixin, # Audio frame processing + LifecycleMixin, # Model loading/unloading +): + """Composed from focused mixins for maintainability.""" +``` + +### 6. Data Transfer Objects (DTOs) + +```python +@dataclass(frozen=True) +class RawEntity: + text: str + label: str + start: int + end: int + confidence: float = 1.0 +``` + +Use frozen dataclasses for immutable DTOs between layers. + +--- + +## Subsystem-Specific Guidance + +### NER (`ner/`) + +**Structure**: +``` +ner/ +├── engine.py # NerEngine (composition class) +├── mapper.py # RawEntity → NamedEntity +├── post_processing.py # Pipeline: drop → infer → merge → dedupe +└── backends/ + ├── types.py # NerBackend protocol, RawEntity + ├── gliner_backend.py + └── spacy_backend.py +``` + +**Key classes**: +- `NerEngine`: Composes backend + mapper + post-processing +- `NerBackend`: Protocol for backend implementations +- `RawEntity`: Backend output (frozen dataclass) + +**Flow**: `Backend.extract()` → `map_raw_entities_to_named()` → `post_process()` → `dedupe_entities()` + +### ASR (`asr/`) + +**Key classes**: +- `FasterWhisperEngine`: Whisper wrapper using CTranslate2 +- `AsrResult`: Dataclass with text, language, segments, duration +- `StreamingVad`: Real-time voice activity detection +- `Segmenter`: VAD-based audio chunking + +**Device handling**: Uses `compute_type` (int8, float16, float32) + `device` (cpu, cuda, rocm) + +### Diarization (`diarization/`) + +**Modes**: +- **Streaming** (diart): Real-time speaker detection during recording +- **Offline** (pyannote): Post-meeting refinement with full audio + +**Key classes**: +- `DiarizationEngine`: Main engine with mixins +- `DiarizationSession`: Session state for active diarization +- `SpeakerTurn`: Dataclass with speaker_id, start, end + +### Persistence (`persistence/`) + +**Structure**: +``` +persistence/ +├── unit_of_work/ # SqlAlchemyUnitOfWork (async context manager) +├── models/ # SQLAlchemy ORM models +├── repositories/ # Repository implementations +├── memory/ # In-memory fallback for testing +└── migrations/ # Alembic migrations +``` + +**Unit of Work Pattern**: +```python +async with uow: + meeting = await uow.meetings.get(meeting_id) + segments = await uow.segments.get_by_meeting(meeting_id) + # Auto-commit on success, rollback on exception +``` + +**Base repository helpers**: +- `_execute_scalar()`: Single result query with timing +- `_execute_scalars()`: Multiple results with timing +- `_add_and_flush()`: Persist single model +- `_add_all_and_flush()`: Batch persist + +### Summarization (`summarization/`) + +**Providers**: +- `CloudProvider`: Anthropic/OpenAI wrapper +- `OllamaSummarizer`: Local LLM via Ollama +- `MockSummarizer`: Testing fixture + +**Key features**: +- Factory auto-detects available providers +- Citation verification links claims to segments +- Template rendering via Jinja2 + +### Webhooks (`webhooks/`) + +**Key features**: +- HMAC-SHA256 signing (`X-Noteflow-Signature` header) +- Exponential backoff with configurable multiplier +- Async delivery with httpx connection pooling +- Delivery tracking and metrics + +### Calendar (`calendar/`) + +**Adapters**: Google Calendar, Outlook (Microsoft Graph) + +**Key classes**: +- `GoogleCalendarAdapter`: Google Calendar API v3 +- `OutlookCalendarAdapter`: Microsoft Graph API +- `OAuthManager`: OAuth flow 
orchestration + +--- + +## Forbidden Patterns + +### Direct database access outside persistence/ +```python +# WRONG: Raw SQL in service layer +async with engine.connect() as conn: + result = await conn.execute(text("SELECT * FROM meetings")) +``` + +### Hardcoded API keys +```python +# WRONG: Secrets in code +client = anthropic.Anthropic(api_key="sk-ant-...") +# RIGHT: Use environment variables or config +``` + +### Synchronous I/O in async context +```python +# WRONG: Blocking the event loop +def load_model(self): + self.model = whisper.load_model("base") # Blocks! + +# RIGHT: Use executor +async def load_model_async(self): + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._load_model_sync) +``` + +### Domain mutations in infrastructure +```python +# WRONG: Infrastructure modifying domain objects +from noteflow.domain.entities import Meeting +meeting.state = "COMPLETED" # Don't mutate here! + +# RIGHT: Return new data, let application layer handle state +``` + +### Skipping capability checks +```python +# WRONG: Assuming optional features exist +entities = await uow.entities.get_by_meeting(meeting_id) + +# RIGHT: Check capability first +if uow.supports_entities: + entities = await uow.entities.get_by_meeting(meeting_id) +``` + +--- + +## Testing Infrastructure Adapters + +### Unit Tests: Mock Dependencies + +```python +# tests/infrastructure/ner/test_engine.py +@pytest.fixture +def mock_backend() -> NerBackend: + backend = Mock(spec=NerBackend) + backend.extract.return_value = [ + RawEntity(text="John", label="PERSON", start=0, end=4) + ] + return backend + +async def test_engine_uses_backend(mock_backend): + engine = NerEngine(backend=mock_backend) + result = await engine.extract_async("Hello John") + mock_backend.extract.assert_called_once() +``` + +### Integration Tests: Real Services + +```python +# tests/integration/test_ner_integration.py +@pytest.mark.integration +@pytest.mark.requires_gpu +async def test_gliner_real_extraction(): + backend = GlinerBackend(model="urchade/gliner_base") + result = backend.extract("Microsoft CEO Satya Nadella announced...") + assert any(e.label == "ORG" and "Microsoft" in e.text for e in result) +``` + +### Use pytest markers: +- `@pytest.mark.integration` — requires external services +- `@pytest.mark.requires_gpu` — requires CUDA/ROCm +- `@pytest.mark.slow` — long-running tests + +--- + +## Adding a New Adapter + +1. **Define port in domain** (`domain/ports/`) if not exists +2. **Create adapter directory** (`infrastructure//`) +3. **Implement the protocol** with proper async handling +4. **Add factory function** for backend selection (if multiple implementations) +5. **Write unit tests** with mocked dependencies +6. **Write integration test** with real external service +7. **Update gRPC startup** (`grpc/startup/services.py`) for dependency injection +8. 
**Update this file** (add to Adapter Catalog table)
+
+---
+
+## Key Files Reference
+
+| Pattern | Location |
+|---------|----------|
+| Main engine | `<subsystem>/engine.py` |
+| Backend protocol | `<subsystem>/backends/types.py` |
+| Backend implementations | `<subsystem>/backends/<name>_backend.py` |
+| External→Domain mapping | `<subsystem>/mapper.py` |
+| Output normalization | `<subsystem>/post_processing.py` |
+| UnitOfWork | `persistence/unit_of_work/unit_of_work.py` |
+| Repository base | `persistence/repositories/_base/_base.py` |
+| ORM models | `persistence/models/<domain>/*.py` |
+| Migrations | `persistence/migrations/versions/*.py` |
+
+---
+
+## See Also
+
+- `/src/noteflow/domain/ports/` — Port definitions (interfaces)
+- `/src/noteflow/grpc/startup/services.py` — Dependency injection wiring
+- `/tests/infrastructure/` — Adapter test suites
+- `/src/noteflow/CLAUDE.md` — Python backend standards
diff --git a/src/noteflow/infrastructure/asr/pytorch_engine.py b/src/noteflow/infrastructure/asr/pytorch_engine.py
index dba507f..7f9e431 100644
--- a/src/noteflow/infrastructure/asr/pytorch_engine.py
+++ b/src/noteflow/infrastructure/asr/pytorch_engine.py
@@ -23,6 +23,8 @@ from noteflow.domain.constants.fields import START
 from noteflow.infrastructure.asr.dto import AsrResult, WordTiming
 from noteflow.infrastructure.logging import get_logger
 
+_KEY_LANGUAGE: str = "language"
+
 if TYPE_CHECKING:
     import numpy as np
     from numpy.typing import NDArray
@@ -73,6 +75,7 @@ class _WhisperModel(Protocol):
         """Convert model to half precision."""
         ...
 
+
 # Valid model sizes for openai-whisper
 PYTORCH_VALID_MODEL_SIZES: tuple[str, ...] = (
     "tiny",
@@ -232,14 +235,14 @@ class WhisperPyTorchEngine:
         }
 
         if language is not None:
-            options["language"] = language
+            options[_KEY_LANGUAGE] = language
 
         # Transcribe
         result = self._model.transcribe(audio, **options)
 
         # Convert to our segment format
         segments = result.get("segments", [])
-        detected_language = result.get("language", "en")
+        detected_language = result.get(_KEY_LANGUAGE, "en")
 
         for segment in segments:
             words = self._extract_word_timings(segment)
diff --git a/src/noteflow/infrastructure/audio/capture.py b/src/noteflow/infrastructure/audio/capture.py
index 551afdc..e766a87 100644
--- a/src/noteflow/infrastructure/audio/capture.py
+++ b/src/noteflow/infrastructure/audio/capture.py
@@ -6,6 +6,7 @@
 """Provide cross-platform audio input capture with device handling.
from __future__ import annotations import time +import weakref from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Unpack @@ -239,6 +240,8 @@ class SoundDeviceCapture: def _build_stream_callback( self, channels: int ) -> Callable[[NDArray[np.float32], int, object, object], None]: + weak_self = weakref.ref(self) + def _stream_callback( indata: NDArray[np.float32], frames: int, @@ -250,9 +253,10 @@ class SoundDeviceCapture: if status: logger.warning("Audio stream status: %s", status) - if self._callback is not None: + capture = weak_self() + if capture is not None and capture._callback is not None: audio_data = indata[:, 0].copy() if channels == 1 else indata.flatten() timestamp = time.monotonic() - self._callback(audio_data, timestamp) + capture._callback(audio_data, timestamp) return _stream_callback diff --git a/src/noteflow/infrastructure/ner/__init__.py b/src/noteflow/infrastructure/ner/__init__.py index 1172d39..db8df7b 100644 --- a/src/noteflow/infrastructure/ner/__init__.py +++ b/src/noteflow/infrastructure/ner/__init__.py @@ -1,5 +1,8 @@ """Named Entity Recognition infrastructure.""" +from noteflow.infrastructure.ner.backends import NerBackend, RawEntity +from noteflow.infrastructure.ner.backends.gliner_backend import GLiNERBackend +from noteflow.infrastructure.ner.backends.spacy_backend import SpacyBackend from noteflow.infrastructure.ner.engine import NerEngine -__all__ = ["NerEngine"] +__all__ = ["GLiNERBackend", "NerBackend", "NerEngine", "RawEntity", "SpacyBackend"] diff --git a/src/noteflow/infrastructure/ner/backends/__init__.py b/src/noteflow/infrastructure/ner/backends/__init__.py new file mode 100644 index 0000000..43cef0f --- /dev/null +++ b/src/noteflow/infrastructure/ner/backends/__init__.py @@ -0,0 +1,5 @@ +"""NER backend implementations.""" + +from noteflow.infrastructure.ner.backends.types import NerBackend, RawEntity + +__all__ = ["NerBackend", "RawEntity"] diff --git a/src/noteflow/infrastructure/ner/backends/gliner_backend.py b/src/noteflow/infrastructure/ner/backends/gliner_backend.py new file mode 100644 index 0000000..37f5b3a --- /dev/null +++ b/src/noteflow/infrastructure/ner/backends/gliner_backend.py @@ -0,0 +1,122 @@ +"""GLiNER NER backend implementation.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Final, TypedDict, cast + +from noteflow.domain.entities.named_entity import EntityCategory +from noteflow.infrastructure.logging import get_logger, log_timing +from noteflow.infrastructure.ner.backends.types import RawEntity + +if TYPE_CHECKING: + from gliner import GLiNER + +logger = get_logger(__name__) + + +class _GLiNEREntityDict(TypedDict): + text: str + label: str + start: int + end: int + score: float + + +DEFAULT_MODEL: Final[str] = "urchade/gliner_medium-v2.1" +DEFAULT_THRESHOLD: Final[float] = 0.5 + +MEETING_LABELS: Final[tuple[str, ...]] = ( + EntityCategory.PERSON.value, + "org", + EntityCategory.PRODUCT.value, + "app", + EntityCategory.LOCATION.value, + "time", + EntityCategory.TIME_RELATIVE.value, + EntityCategory.DURATION.value, + EntityCategory.EVENT.value, + EntityCategory.TASK.value, + EntityCategory.DECISION.value, + EntityCategory.TOPIC.value, +) + + +class GLiNERBackend: + """GLiNER-based NER backend for meeting transcripts. + + Uses generalist zero-shot NER with custom label sets optimized + for meeting content (people, products, events, decisions, etc.). + """ + + def __init__( + self, + model_name: str = DEFAULT_MODEL, + labels: tuple[str, ...] 
= MEETING_LABELS, + threshold: float = DEFAULT_THRESHOLD, + ) -> None: + self._model_name = model_name + self._labels = labels + self._threshold = threshold + self._model: GLiNER | None = None + + def load_model(self) -> None: + from gliner import GLiNER as GLiNERClass + + with log_timing("gliner_model_load", model_name=self._model_name): + self._model = GLiNERClass.from_pretrained(self._model_name) + logger.info("GLiNER model loaded", model_name=self._model_name) + + def _ensure_loaded(self) -> GLiNER: + if self._model is None: + self.load_model() + assert self._model is not None + return self._model + + def model_loaded(self) -> bool: + return self._model is not None + + def unload(self) -> None: + self._model = None + logger.info("GLiNER model unloaded") + + def _call_predict_entities( + self, + model: GLiNER, + text: str, + ) -> list[_GLiNEREntityDict]: + result = model.predict_entities( + text, + labels=list(self._labels), + threshold=self._threshold, + ) + return cast(list[_GLiNEREntityDict], result) + + def extract(self, text: str) -> list[RawEntity]: + if not text or not text.strip(): + return [] + + model = self._ensure_loaded() + raw_entities = self._call_predict_entities(model, text) + + return [ + RawEntity( + text=ent["text"], + label=ent["label"].lower(), + start=ent["start"], + end=ent["end"], + confidence=ent["score"], + ) + for ent in raw_entities + ] + + @property + def model_name(self) -> str: + return self._model_name + + @property + def labels(self) -> tuple[str, ...]: + return self._labels + + @property + def threshold(self) -> float: + return self._threshold diff --git a/src/noteflow/infrastructure/ner/backends/spacy_backend.py b/src/noteflow/infrastructure/ner/backends/spacy_backend.py new file mode 100644 index 0000000..804a9d9 --- /dev/null +++ b/src/noteflow/infrastructure/ner/backends/spacy_backend.py @@ -0,0 +1,152 @@ +"""spaCy NER backend implementation.""" + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING, Final, Protocol + +from noteflow.config.constants import ( + SPACY_MODEL_LG, + SPACY_MODEL_MD, + SPACY_MODEL_SM, + SPACY_MODEL_TRF, +) +from noteflow.infrastructure.logging import get_logger, log_timing +from noteflow.infrastructure.ner.backends.types import RawEntity + +if TYPE_CHECKING: + from spacy.language import Language + from spacy.tokens import Span + +logger = get_logger(__name__) + + +class _SpacyModule(Protocol): + def load(self, name: str) -> Language: ... + + +VALID_SPACY_MODELS: Final[tuple[str, ...]] = ( + SPACY_MODEL_SM, + SPACY_MODEL_MD, + SPACY_MODEL_LG, + SPACY_MODEL_TRF, +) + +SKIP_ENTITY_TYPES: Final[frozenset[str]] = frozenset( + { + "CARDINAL", + "ORDINAL", + "QUANTITY", + "PERCENT", + "MONEY", + } +) + + +class SpacyBackend: + """spaCy-based NER backend. + + Uses spaCy's pre-trained models for entity extraction. Supports + lazy model loading and fallback from transformer to small model. + """ + + def __init__(self, model_name: str = SPACY_MODEL_TRF) -> None: + if model_name not in VALID_SPACY_MODELS: + raise ValueError( + f"Invalid model name: {model_name}. 
Valid models: {', '.join(VALID_SPACY_MODELS)}" + ) + self._model_name = model_name + self._nlp: Language | None = None + + def load_model(self) -> None: + import spacy + + self._warn_if_curated_transformers_missing() + with log_timing("spacy_model_load", model_name=self._model_name): + try: + self._nlp = spacy.load(self._model_name) + except OSError as exc: + self._nlp = self._handle_model_load_failure(spacy, exc) + + logger.info("spaCy model loaded", model_name=self._model_name) + + def _warn_if_curated_transformers_missing(self) -> None: + if self._model_name != SPACY_MODEL_TRF: + return + try: + importlib.import_module("spacy_curated_transformers") + except ModuleNotFoundError: + logger.warning( + "spaCy curated transformers not installed; transformer model load may fail", + model_name=self._model_name, + ) + + def _handle_model_load_failure( + self, + spacy_module: _SpacyModule, + exc: OSError, + ) -> Language: + if self._model_name == SPACY_MODEL_TRF: + return self._load_fallback_model(spacy_module) + msg = ( + f"Failed to load spaCy model '{self._model_name}'. " + f"Run: python -m spacy download {self._model_name}" + ) + raise RuntimeError(msg) from exc + + def _load_fallback_model(self, spacy_module: _SpacyModule) -> Language: + fallback_model = SPACY_MODEL_SM + logger.warning( + "spaCy model '%s' unavailable; falling back to '%s'", + self._model_name, + fallback_model, + ) + try: + nlp = spacy_module.load(fallback_model) + except OSError as fallback_error: + msg = ( + f"Failed to load spaCy model '{self._model_name}' " + f"and fallback '{fallback_model}'. " + f"Run: python -m spacy download {fallback_model}" + ) + raise RuntimeError(msg) from fallback_error + self._model_name = fallback_model + return nlp + + def _ensure_loaded(self) -> Language: + if self._nlp is None: + self.load_model() + assert self._nlp is not None + return self._nlp + + def model_loaded(self) -> bool: + return self._nlp is not None + + def unload(self) -> None: + self._nlp = None + logger.info("spaCy model unloaded") + + def extract(self, text: str) -> list[RawEntity]: + if not text or not text.strip(): + return [] + + nlp = self._ensure_loaded() + doc = nlp(text) + + return [ + _spacy_entity_to_raw(ent) for ent in doc.ents if ent.label_ not in SKIP_ENTITY_TYPES + ] + + @property + def model_name(self) -> str: + return self._model_name + + +def _spacy_entity_to_raw(ent: Span) -> RawEntity: + return RawEntity( + text=ent.text, + label=ent.label_.lower(), + start=ent.start_char, + end=ent.end_char, + confidence=None, + ) diff --git a/src/noteflow/infrastructure/ner/backends/types.py b/src/noteflow/infrastructure/ner/backends/types.py new file mode 100644 index 0000000..de9e74e --- /dev/null +++ b/src/noteflow/infrastructure/ner/backends/types.py @@ -0,0 +1,56 @@ +"""Backend-agnostic types for NER extraction.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + + +@dataclass(frozen=True, slots=True) +class RawEntity: + """Backend-agnostic entity representation. + + All backends convert their output to this format before post-processing. + Labels must be lowercase (e.g., 'person', 'time_relative'). + """ + + text: str + label: str + start: int + end: int + confidence: float | None = None + + def __post_init__(self) -> None: + if self.label != self.label.lower(): + object.__setattr__(self, "label", self.label.lower()) + + +class NerBackend(Protocol): + """Protocol for NER backend implementations. + + Backends extract raw entities from text. 
Post-processing and mapping + to domain entities happens in separate pipeline stages. + """ + + def extract(self, text: str) -> list[RawEntity]: + """Extract raw entities from text. + + Args: + text: Input text to analyze. + + Returns: + List of raw entities with lowercase labels. + """ + ... + + def model_loaded(self) -> bool: + """Check if the backend model is loaded and ready.""" + ... + + def load_model(self) -> None: + """Load the model (for lazy initialization).""" + ... + + def unload(self) -> None: + """Unload the model to free resources.""" + ... diff --git a/src/noteflow/infrastructure/ner/engine.py b/src/noteflow/infrastructure/ner/engine.py index f417393..dc2aef5 100644 --- a/src/noteflow/infrastructure/ner/engine.py +++ b/src/noteflow/infrastructure/ner/engine.py @@ -1,264 +1,51 @@ -"""NER engine implementation using spaCy. +"""NER engine with backend composition. -Provides named entity extraction with lazy model loading and segment tracking. +Provides named entity extraction with configurable backends and shared +post-processing pipeline. """ from __future__ import annotations import asyncio -import importlib from functools import partial -from typing import TYPE_CHECKING, Final, Protocol -from noteflow.config.constants import ( - SPACY_MODEL_LG, - SPACY_MODEL_MD, - SPACY_MODEL_SM, - SPACY_MODEL_TRF, -) -from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity -from noteflow.infrastructure.logging import get_logger, log_timing - -if TYPE_CHECKING: - from spacy.language import Language - - -class _SpacyModule(Protocol): - def load(self, name: str) -> Language: ... - -logger = get_logger(__name__) - -# Map spaCy entity types to our categories -_SPACY_CATEGORY_MAP: Final[dict[str, EntityCategory]] = { - # People - "PERSON": EntityCategory.PERSON, - # Organizations - "ORG": EntityCategory.COMPANY, - # Products and creative works - "PRODUCT": EntityCategory.PRODUCT, - "WORK_OF_ART": EntityCategory.PRODUCT, - # Locations - "GPE": EntityCategory.LOCATION, # Geo-political entity (countries, cities) - "LOC": EntityCategory.LOCATION, # Non-GPE locations (mountains, rivers) - "FAC": EntityCategory.LOCATION, # Facilities (buildings, airports) - # Dates and times - "DATE": EntityCategory.DATE, - "TIME": EntityCategory.DATE, - # Others (filtered out or mapped to OTHER) - "MONEY": EntityCategory.OTHER, - "PERCENT": EntityCategory.OTHER, - "CARDINAL": EntityCategory.OTHER, - "ORDINAL": EntityCategory.OTHER, - "QUANTITY": EntityCategory.OTHER, - "NORP": EntityCategory.OTHER, # Nationalities, religions - "EVENT": EntityCategory.OTHER, - "LAW": EntityCategory.OTHER, - "LANGUAGE": EntityCategory.OTHER, -} - -# Entity types to skip (low value for meeting context) -_SKIP_ENTITY_TYPES: Final[frozenset[str]] = frozenset({ - "CARDINAL", - "ORDINAL", - "QUANTITY", - "PERCENT", - "MONEY", -}) - -# Valid model names -VALID_SPACY_MODELS: Final[tuple[str, ...]] = ( - SPACY_MODEL_SM, - SPACY_MODEL_MD, - SPACY_MODEL_LG, - SPACY_MODEL_TRF, -) +from noteflow.domain.entities.named_entity import NamedEntity +from noteflow.infrastructure.ner.backends.types import NerBackend +from noteflow.infrastructure.ner.mapper import map_raw_entities_to_named +from noteflow.infrastructure.ner.post_processing import post_process class NerEngine: - """Named entity recognition engine using spaCy. + """Named entity recognition engine with backend composition. - Lazy-loads the spaCy model on first use to avoid startup delay. - Implements the NerPort protocol for hexagonal architecture. 
- - Uses chunking by segment (speaker turn/paragraph) to avoid OOM on - long transcripts while maintaining segment tracking. + Composes a backend (spaCy or GLiNER) with shared post-processing + and domain entity mapping. Implements NerPort protocol. """ - def __init__(self, model_name: str = SPACY_MODEL_TRF) -> None: - """Initialize NER engine. - - Args: - model_name: spaCy model to use. Defaults to transformer model - for higher accuracy. - """ - if model_name not in VALID_SPACY_MODELS: - raise ValueError( - f"Invalid model name: {model_name}. " - f"Valid models: {', '.join(VALID_SPACY_MODELS)}" - ) - self._model_name = model_name - self._nlp: Language | None = None + def __init__(self, backend: NerBackend) -> None: + self._backend = backend def load_model(self) -> None: - """Load the spaCy model. - - Raises: - RuntimeError: If model loading fails. - """ - import spacy - - self._warn_if_curated_transformers_missing() - with log_timing("ner_model_load", model_name=self._model_name): - try: - self._nlp = spacy.load(self._model_name) - except OSError as exc: - self._nlp = self._handle_model_load_failure(spacy, exc) - - logger.info("ner_model_loaded", model_name=self._model_name) - - def _warn_if_curated_transformers_missing(self) -> None: - if self._model_name != SPACY_MODEL_TRF: - return - try: - importlib.import_module("spacy_curated_transformers") - except ModuleNotFoundError: - logger.warning( - "spaCy curated transformers not installed; " - "transformer model load may fail", - model_name=self._model_name, - ) - - def _handle_model_load_failure(self, spacy_module: _SpacyModule, exc: OSError) -> Language: - if self._model_name == SPACY_MODEL_TRF: - return self._load_fallback_model(spacy_module) - msg = ( - f"Failed to load spaCy model '{self._model_name}'. " - f"Run: python -m spacy download {self._model_name}" - ) - raise RuntimeError(msg) from exc - - def _load_fallback_model(self, spacy_module: _SpacyModule) -> Language: - fallback_model = SPACY_MODEL_SM - logger.warning( - "spaCy model '%s' unavailable; falling back to '%s'", - self._model_name, - fallback_model, - ) - try: - nlp = spacy_module.load(fallback_model) - except OSError as fallback_error: - msg = ( - f"Failed to load spaCy model '{self._model_name}' " - f"and fallback '{fallback_model}'. " - f"Run: python -m spacy download {fallback_model}" - ) - raise RuntimeError(msg) from fallback_error - self._model_name = fallback_model - return nlp - - def _ensure_loaded(self) -> Language: - """Ensure model is loaded, loading if necessary. - - Returns: - The loaded spaCy Language model. - """ - if self._nlp is None: - self.load_model() - assert self._nlp is not None, "load_model() should set self._nlp" - return self._nlp + self._backend.load_model() def is_ready(self) -> bool: - """Check if model is loaded.""" - return self._nlp is not None - - def unload(self) -> None: - """Unload the model to free memory.""" - self._nlp = None - logger.info("spaCy model unloaded") - - @property - def model_name(self) -> str: - """Return the model name.""" - return self._model_name + backend = self._backend + return backend.model_loaded() def extract(self, text: str) -> list[NamedEntity]: - """Extract named entities from text. - - Args: - text: Input text to analyze. - - Returns: - List of extracted entities (deduplicated by normalized text). 
- """ if not text or not text.strip(): return [] - nlp = self._ensure_loaded() - doc = nlp(text) + raw_entities = self._backend.extract(text) + processed = post_process(raw_entities, text) + entities = map_raw_entities_to_named(processed) - entities: list[NamedEntity] = [] - seen: set[str] = set() - - for ent in doc.ents: - # Normalize for deduplication - normalized = ent.text.lower().strip() - if not normalized or normalized in seen: - continue - - # Skip low-value entity types - if ent.label_ in _SKIP_ENTITY_TYPES: - continue - - seen.add(normalized) - category = _SPACY_CATEGORY_MAP.get(ent.label_, EntityCategory.OTHER) - - entities.append( - NamedEntity.create( - text=ent.text, - category=category, - segment_ids=[], # Filled by caller - confidence=0.8, # spaCy doesn't provide per-entity confidence - ) - ) - - return entities - - def _merge_entity_into_collection( - self, - entity: NamedEntity, - segment_id: int, - all_entities: dict[str, NamedEntity], - ) -> None: - """Merge an entity into the collection, tracking segment occurrences. - - Args: - entity: Entity to merge. - segment_id: Segment where entity was found. - all_entities: Collection to merge into. - """ - key = entity.normalized_text - if key in all_entities: - all_entities[key].merge_segments([segment_id]) - else: - entity.segment_ids = [segment_id] - all_entities[key] = entity + return _dedupe_by_normalized_text(entities) def extract_from_segments( self, segments: list[tuple[int, str]], ) -> list[NamedEntity]: - """Extract entities from multiple segments with segment tracking. - - Processes each segment individually (chunking by speaker turn/paragraph) - to avoid OOM on long transcripts. Entities appearing in multiple segments - are deduplicated with merged segment lists. - - Args: - segments: List of (segment_id, text) tuples. - - Returns: - Entities with segment_ids populated (deduplicated across segments). - """ if not segments: return [] @@ -267,41 +54,57 @@ class NerEngine: if not text or not text.strip(): continue for entity in self.extract(text): - self._merge_entity_into_collection(entity, segment_id, all_entities) + _merge_entity_into_collection(entity, segment_id, all_entities) return list(all_entities.values()) async def extract_async(self, text: str) -> list[NamedEntity]: - """Extract entities asynchronously using executor. - - Offloads blocking extraction to a thread pool executor to avoid - blocking the asyncio event loop. - - Args: - text: Input text to analyze. - - Returns: - List of extracted entities. - """ return await self._run_in_executor(partial(self.extract, text)) async def extract_from_segments_async( self, segments: list[tuple[int, str]], ) -> list[NamedEntity]: - """Extract entities from segments asynchronously. - - Args: - segments: List of (segment_id, text) tuples. - - Returns: - Entities with segment_ids populated. 
- """ return await self._run_in_executor(partial(self.extract_from_segments, segments)) async def _run_in_executor( self, func: partial[list[NamedEntity]], ) -> list[NamedEntity]: - """Run sync extraction in thread pool executor.""" return await asyncio.get_running_loop().run_in_executor(None, func) + + @property + def backend(self) -> NerBackend: + return self._backend + + +def _merge_entity_into_collection( + entity: NamedEntity, + segment_id: int, + all_entities: dict[str, NamedEntity], +) -> None: + key = entity.normalized_text + if key not in all_entities: + entity.segment_ids = [segment_id] + all_entities[key] = entity + return + + existing = all_entities[key] + existing.merge_segments([segment_id]) + if entity.confidence > existing.confidence: + existing.confidence = entity.confidence + existing.category = entity.category + if len(entity.text) > len(existing.text): + existing.text = entity.text + + +def _dedupe_by_normalized_text(entities: list[NamedEntity]) -> list[NamedEntity]: + seen: set[str] = set() + result: list[NamedEntity] = [] + + for entity in entities: + if entity.normalized_text not in seen: + seen.add(entity.normalized_text) + result.append(entity) + + return result diff --git a/src/noteflow/infrastructure/ner/mapper.py b/src/noteflow/infrastructure/ner/mapper.py new file mode 100644 index 0000000..e0c32fa --- /dev/null +++ b/src/noteflow/infrastructure/ner/mapper.py @@ -0,0 +1,65 @@ +"""Mapper from RawEntity to domain NamedEntity.""" + +from __future__ import annotations + +from typing import Final + +from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity +from noteflow.infrastructure.ner.backends.types import RawEntity + +DEFAULT_CONFIDENCE: Final[float] = 0.8 + +LABEL_TO_CATEGORY: Final[dict[str, EntityCategory]] = { + EntityCategory.PERSON.value: EntityCategory.PERSON, + "org": EntityCategory.COMPANY, + EntityCategory.COMPANY.value: EntityCategory.COMPANY, + EntityCategory.PRODUCT.value: EntityCategory.PRODUCT, + "app": EntityCategory.PRODUCT, + EntityCategory.LOCATION.value: EntityCategory.LOCATION, + "gpe": EntityCategory.LOCATION, + "loc": EntityCategory.LOCATION, + "fac": EntityCategory.LOCATION, + EntityCategory.DATE.value: EntityCategory.DATE, + "time": EntityCategory.DATE, + EntityCategory.TIME_RELATIVE.value: EntityCategory.TIME_RELATIVE, + EntityCategory.DURATION.value: EntityCategory.DURATION, + EntityCategory.EVENT.value: EntityCategory.EVENT, + EntityCategory.TASK.value: EntityCategory.TASK, + EntityCategory.DECISION.value: EntityCategory.DECISION, + EntityCategory.TOPIC.value: EntityCategory.TOPIC, + "work_of_art": EntityCategory.PRODUCT, + EntityCategory.TECHNICAL.value: EntityCategory.TECHNICAL, + EntityCategory.ACRONYM.value: EntityCategory.ACRONYM, + "norp": EntityCategory.OTHER, + "money": EntityCategory.OTHER, + "percent": EntityCategory.OTHER, + "cardinal": EntityCategory.OTHER, + "ordinal": EntityCategory.OTHER, + "quantity": EntityCategory.OTHER, + "law": EntityCategory.OTHER, + "language": EntityCategory.OTHER, +} + + +def map_label_to_category(label: str) -> EntityCategory: + """Map a raw label string to EntityCategory.""" + normalized = label.lower().strip() + return LABEL_TO_CATEGORY.get(normalized, EntityCategory.OTHER) + + +def map_raw_to_named(raw: RawEntity) -> NamedEntity: + """Convert a single RawEntity to a NamedEntity.""" + confidence = raw.confidence if raw.confidence is not None else DEFAULT_CONFIDENCE + category = map_label_to_category(raw.label) + + return NamedEntity.create( + text=raw.text, + 
category=category, + segment_ids=[], + confidence=confidence, + ) + + +def map_raw_entities_to_named(entities: list[RawEntity]) -> list[NamedEntity]: + """Convert a list of RawEntity to NamedEntity objects.""" + return [map_raw_to_named(raw) for raw in entities] diff --git a/src/noteflow/infrastructure/ner/post_processing.py b/src/noteflow/infrastructure/ner/post_processing.py new file mode 100644 index 0000000..d0b80eb --- /dev/null +++ b/src/noteflow/infrastructure/ner/post_processing.py @@ -0,0 +1,243 @@ +"""Shared post-processing pipeline for NER backends.""" + +from __future__ import annotations + +import re +from typing import Final + +from noteflow.infrastructure.ner.backends.types import RawEntity + +PROFANITY_WORDS: Final[frozenset[str]] = frozenset( + { + "shit", + "fuck", + "damn", + "crap", + "ass", + "hell", + "bitch", + } +) + +DURATION_UNITS: Final[frozenset[str]] = frozenset( + { + "second", + "seconds", + "minute", + "minutes", + "hour", + "hours", + "day", + "days", + "week", + "weeks", + "month", + "months", + "year", + "years", + } +) + +TIME_LABELS: Final[frozenset[str]] = frozenset( + { + "time", + "time_relative", + "date", + } +) + +MIN_ENTITY_LENGTH: Final[int] = 2 + + +def normalize_text(text: str) -> str: + """Normalize entity text for comparison and deduplication.""" + normalized = text.lower().strip() + normalized = re.sub(r"\s+", " ", normalized) + normalized = re.sub(r"['\"]", "", normalized) + return normalized + + +def _entity_key(entity: RawEntity) -> str: + return f"{normalize_text(entity.text)}:{entity.label}" + + +def _should_replace_existing(existing: RawEntity, new: RawEntity) -> bool: + if new.confidence is None: + return False + if existing.confidence is None: + return True + return new.confidence > existing.confidence + + +def dedupe_entities(entities: list[RawEntity]) -> list[RawEntity]: + """Remove duplicate entities, keeping highest confidence version.""" + seen: dict[str, RawEntity] = {} + + for entity in entities: + key = _entity_key(entity) + if key not in seen or _should_replace_existing(seen[key], entity): + seen[key] = entity + + return list(seen.values()) + + +def is_low_signal_entity(entity: RawEntity) -> bool: + normalized = normalize_text(entity.text) + + if len(normalized) <= MIN_ENTITY_LENGTH: + return True + + if normalized.isdigit(): + return True + + if normalized in PROFANITY_WORDS: + return True + + return False + + +def drop_low_signal_entities( + entities: list[RawEntity], + original_text: str, +) -> list[RawEntity]: + _ = original_text # Reserved for future context-aware filtering + return [e for e in entities if not is_low_signal_entity(e)] + + +NUMBER_WORDS: Final[frozenset[str]] = frozenset( + { + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "a", + "an", + "few", + "several", + "couple", + "half", + } +) + + +def _is_numeric_duration(text: str) -> bool: + words = text.lower().split() + for i, word in enumerate(words): + if word not in DURATION_UNITS: + continue + if i == 0: + continue + prev_word = words[i - 1] + if prev_word.isdigit() or prev_word in NUMBER_WORDS: + return True + return False + + +def infer_duration(entities: list[RawEntity]) -> list[RawEntity]: + """Relabel entities containing duration units as 'duration'.""" + result: list[RawEntity] = [] + + for entity in entities: + if _is_numeric_duration(entity.text): + result.append( + RawEntity( + text=entity.text, + label="duration", + start=entity.start, + end=entity.end, + confidence=entity.confidence, + ) + 
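+                # Relabeling keeps the original span offsets and confidence;
+                # only text where a digit or number word precedes a duration
+                # unit reaches this branch (see _is_numeric_duration).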
) + else: + result.append(entity) + + return result + + +def _are_adjacent(e1: RawEntity, e2: RawEntity, original_text: str) -> bool: + """Check whether two entities are adjacent: at most one non-whitespace character between their spans.""" + gap_start = min(e1.end, e2.end) + gap_end = max(e1.start, e2.start) + + if gap_end <= gap_start: + return True + + gap_text = original_text[gap_start:gap_end] + return len(gap_text.strip()) <= 1 + + +def _average_confidence(conf_a: float | None, conf_b: float | None) -> float | None: + if conf_a is not None and conf_b is not None: + return (conf_a + conf_b) / 2 + return conf_a if conf_a is not None else conf_b + + +def _merge_two_entities( + current: RawEntity, + next_entity: RawEntity, + original_text: str, +) -> RawEntity: + merged_start = min(current.start, next_entity.start) + merged_end = max(current.end, next_entity.end) + return RawEntity( + text=original_text[merged_start:merged_end], + label=current.label, + start=merged_start, + end=merged_end, + confidence=_average_confidence(current.confidence, next_entity.confidence), + ) + + +def _merge_adjacent_time_entities( + time_entities: list[RawEntity], + original_text: str, +) -> list[RawEntity]: + time_entities.sort(key=lambda e: e.start) + merged: list[RawEntity] = [] + current = time_entities[0] + + for next_entity in time_entities[1:]: + if _are_adjacent(current, next_entity, original_text): + current = _merge_two_entities(current, next_entity, original_text) + else: + merged.append(current) + current = next_entity + + merged.append(current) + return merged + + +def merge_time_phrases( + entities: list[RawEntity], + original_text: str, +) -> list[RawEntity]: + """Merge adjacent time-like entities into single spans.""" + if not entities: + return [] + + time_entities = [e for e in entities if e.label in TIME_LABELS] + other_entities = [e for e in entities if e.label not in TIME_LABELS] + + if len(time_entities) <= 1: + return entities + + merged_time = _merge_adjacent_time_entities(time_entities, original_text) + return other_entities + merged_time + + +def post_process( + entities: list[RawEntity], + original_text: str, +) -> list[RawEntity]: + """Apply full post-processing pipeline to extracted entities.""" + entities = drop_low_signal_entities(entities, original_text) + entities = infer_duration(entities) + entities = merge_time_phrases(entities, original_text) + entities = dedupe_entities(entities) + return entities diff --git a/src/noteflow/infrastructure/persistence/migrations/versions/r2s3t4u5v6w7_fix_segment_ids_jsonb.py b/src/noteflow/infrastructure/persistence/migrations/versions/r2s3t4u5v6w7_fix_segment_ids_jsonb.py index ee3a465..7c38663 100644 --- a/src/noteflow/infrastructure/persistence/migrations/versions/r2s3t4u5v6w7_fix_segment_ids_jsonb.py +++ b/src/noteflow/infrastructure/persistence/migrations/versions/r2s3t4u5v6w7_fix_segment_ids_jsonb.py @@ -33,14 +33,29 @@ def _alter_segment_ids_to_int_array(table: str) -> None: op.execute( f""" ALTER TABLE noteflow.{table} - ALTER COLUMN segment_ids DROP DEFAULT, - ALTER COLUMN segment_ids TYPE INTEGER[] - USING CASE + ADD COLUMN segment_ids_temp INTEGER[] DEFAULT ARRAY[]::int[]; + """ + ) + op.execute( + f""" + UPDATE noteflow.{table} + SET segment_ids_temp = CASE WHEN jsonb_typeof(segment_ids) = 'array' - THEN ARRAY(SELECT jsonb_array_elements_text(segment_ids)::int) + THEN (SELECT array_agg(elem::int) FROM jsonb_array_elements_text(segment_ids) AS elem) ELSE ARRAY[]::int[] - END, - ALTER COLUMN segment_ids SET DEFAULT ARRAY[]::int[]; + END; + """ + ) + op.execute( +
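+    # Postgres rejects subqueries in ALTER ... USING expressions, which the
+    # previous in-place cast relied on -- hence the temp-column backfill
+    # above. Note: array_agg over zero rows yields NULL, so an empty JSONB
+    # array backfills as NULL, not '{}'; wrap the subquery in
+    # COALESCE(..., ARRAY[]::int[]) if the old empty-array behavior matters.
+    # The remaining two statements drop the original JSONB column and
+    # rename the backfilled INTEGER[] column into place.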
f""" + ALTER TABLE noteflow.{table} + DROP COLUMN segment_ids; + """ + ) + op.execute( + f""" + ALTER TABLE noteflow.{table} + RENAME COLUMN segment_ids_temp TO segment_ids; """ ) diff --git a/src/noteflow/infrastructure/persistence/migrations/versions/s3t4u5v6w7x8_fix_diarization_speaker_ids_jsonb.py b/src/noteflow/infrastructure/persistence/migrations/versions/s3t4u5v6w7x8_fix_diarization_speaker_ids_jsonb.py index 9a49e7e..9997312 100644 --- a/src/noteflow/infrastructure/persistence/migrations/versions/s3t4u5v6w7x8_fix_diarization_speaker_ids_jsonb.py +++ b/src/noteflow/infrastructure/persistence/migrations/versions/s3t4u5v6w7x8_fix_diarization_speaker_ids_jsonb.py @@ -35,13 +35,28 @@ def downgrade() -> None: op.execute( """ ALTER TABLE noteflow.diarization_jobs - ALTER COLUMN speaker_ids DROP DEFAULT, - ALTER COLUMN speaker_ids TYPE TEXT[] - USING CASE - WHEN jsonb_typeof(speaker_ids) = 'array' - THEN ARRAY(SELECT jsonb_array_elements_text(speaker_ids)) - ELSE ARRAY[]::text[] - END, - ALTER COLUMN speaker_ids SET DEFAULT ARRAY[]::text[]; + ADD COLUMN speaker_ids_temp TEXT[] DEFAULT ARRAY[]::text[]; + """ + ) + op.execute( + """ + UPDATE noteflow.diarization_jobs + SET speaker_ids_temp = CASE + WHEN jsonb_typeof(speaker_ids) = 'array' + THEN (SELECT array_agg(elem) FROM jsonb_array_elements_text(speaker_ids) AS elem) + ELSE ARRAY[]::text[] + END; + """ + ) + op.execute( + """ + ALTER TABLE noteflow.diarization_jobs + DROP COLUMN speaker_ids; + """ + ) + op.execute( + """ + ALTER TABLE noteflow.diarization_jobs + RENAME COLUMN speaker_ids_temp TO speaker_ids; """ ) diff --git a/src/noteflow/infrastructure/security/crypto/_base.py b/src/noteflow/infrastructure/security/crypto/_base.py new file mode 100644 index 0000000..a86ab4a --- /dev/null +++ b/src/noteflow/infrastructure/security/crypto/_base.py @@ -0,0 +1,26 @@ +"""Base classes for crypto module.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Self + + +class ContextManagedClosable(ABC): + """Mixin for classes with close() that support context manager protocol.""" + + @abstractmethod + def close(self) -> None: + """Close the resource.""" + ... 
+ + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + self.close() diff --git a/src/noteflow/infrastructure/security/crypto/_reader.py b/src/noteflow/infrastructure/security/crypto/_reader.py index 12d49cb..1e57c53 100644 --- a/src/noteflow/infrastructure/security/crypto/_reader.py +++ b/src/noteflow/infrastructure/security/crypto/_reader.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, BinaryIO from noteflow.infrastructure.logging import get_logger from noteflow.infrastructure.security.protocols import EncryptedChunk +from ._base import ContextManagedClosable from ._binary_io import read_exact from ._constants import FILE_MAGIC, FILE_VERSION, MIN_CHUNK_LENGTH, NONCE_SIZE, TAG_SIZE @@ -19,7 +20,7 @@ if TYPE_CHECKING: logger = get_logger(__name__) -class ChunkedAssetReader: +class ChunkedAssetReader(ContextManagedClosable): """Streaming encrypted asset reader.""" def __init__(self, crypto: AesGcmCryptoBox) -> None: diff --git a/src/noteflow/infrastructure/security/crypto/_writer.py b/src/noteflow/infrastructure/security/crypto/_writer.py index 538b0d1..e51b555 100644 --- a/src/noteflow/infrastructure/security/crypto/_writer.py +++ b/src/noteflow/infrastructure/security/crypto/_writer.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, BinaryIO from noteflow.infrastructure.logging import get_logger +from ._base import ContextManagedClosable from ._constants import FILE_MAGIC, FILE_VERSION if TYPE_CHECKING: @@ -16,7 +17,7 @@ if TYPE_CHECKING: logger = get_logger(__name__) -class ChunkedAssetWriter: +class ChunkedAssetWriter(ContextManagedClosable): """Streaming encrypted asset writer. File format: @@ -56,9 +57,14 @@ class ChunkedAssetWriter: self._handle = path.open("wb") self._bytes_written = 0 - # Write header - self._handle.write(FILE_MAGIC) - self._handle.write(struct.pack("B", FILE_VERSION)) + try: + self._handle.write(FILE_MAGIC) + self._handle.write(struct.pack("B", FILE_VERSION)) + except OSError: + self._handle.close() + self._handle = None + self._dek = None + raise logger.debug("Opened encrypted file for writing: %s", path) diff --git a/tests/application/test_asr_config_service.py b/tests/application/test_asr_config_service.py index 5a2c2fb..f725077 100644 --- a/tests/application/test_asr_config_service.py +++ b/tests/application/test_asr_config_service.py @@ -66,7 +66,9 @@ def test_get_capabilities_returns_current_config( asr_config_service: AsrConfigService, ) -> None: """get_capabilities returns current ASR configuration.""" - with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False): + with patch.object( + asr_config_service.engine_manager, "detect_cuda_available", return_value=False + ): caps = asr_config_service.get_capabilities() assert caps.model_size == "base", "model_size should be 'base' from engine" @@ -95,9 +97,12 @@ def test_get_capabilities_with_cuda_available( mock_asr_engine: MagicMock, ) -> None: """get_capabilities includes CUDA compute types when available.""" + from noteflow.domain.ports.gpu import GpuBackend + mock_asr_engine.device = "cuda" - with patch.object( - asr_config_service.engine_manager, "detect_cuda_available", return_value=True + with patch( + "noteflow.application.services.asr_config._engine_manager.detect_gpu_backend", + return_value=GpuBackend.CUDA, ): caps = asr_config_service.get_capabilities() @@ -117,7 +122,9 @@ def test_validate_configuration_valid_cpu_config( 
asr_config_service: AsrConfigService, ) -> None: """validate_configuration accepts valid CPU configuration.""" - with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False): + with patch.object( + asr_config_service.engine_manager, "detect_cuda_available", return_value=False + ): error = asr_config_service.validate_configuration( model_size="small", device=AsrDevice.CPU, @@ -145,7 +152,9 @@ def test_validate_configuration_cuda_unavailable( asr_config_service: AsrConfigService, ) -> None: """validate_configuration rejects CUDA when unavailable.""" - with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False): + with patch.object( + asr_config_service.engine_manager, "detect_cuda_available", return_value=False + ): error = asr_config_service.validate_configuration( model_size=None, device=AsrDevice.CUDA, @@ -193,7 +202,9 @@ async def test_start_reconfiguration_returns_job_id( asr_config_service: AsrConfigService, ) -> None: """start_reconfiguration returns job ID on success.""" - with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False): + with patch.object( + asr_config_service.engine_manager, "detect_cuda_available", return_value=False + ): job_id, error = await asr_config_service.start_reconfiguration( model_size="small", device=None, @@ -253,7 +264,9 @@ async def test_start_reconfiguration_validation_failure( asr_config_service: AsrConfigService, ) -> None: """start_reconfiguration fails on invalid configuration.""" - with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False): + with patch.object( + asr_config_service.engine_manager, "detect_cuda_available", return_value=False + ): job_id, error = await asr_config_service.start_reconfiguration( model_size="invalid-model", device=None, @@ -276,7 +289,9 @@ async def test_get_job_status_returns_job( asr_config_service: AsrConfigService, ) -> None: """get_job_status returns job info for valid ID.""" - with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False): + with patch.object( + asr_config_service.engine_manager, "detect_cuda_available", return_value=False + ): job_id, _ = await asr_config_service.start_reconfiguration( model_size="small", device=None, @@ -317,10 +332,12 @@ def test_detect_cuda_available_with_cuda( asr_config_service: AsrConfigService, ) -> None: """detect_cuda_available returns True when CUDA is available.""" - mock_torch = MagicMock() - mock_torch.cuda.is_available.return_value = True + from noteflow.domain.ports.gpu import GpuBackend - with patch.dict("sys.modules", {"torch": mock_torch}): + with patch( + "noteflow.application.services.asr_config._engine_manager.detect_gpu_backend", + return_value=GpuBackend.CUDA, + ): result = asr_config_service.detect_cuda_available() assert result is True, "detect_cuda_available should return True when CUDA available" @@ -330,10 +347,12 @@ def test_detect_cuda_available_no_cuda( asr_config_service: AsrConfigService, ) -> None: """detect_cuda_available returns False when CUDA is not available.""" - mock_torch = MagicMock() - mock_torch.cuda.is_available.return_value = False + from noteflow.domain.ports.gpu import GpuBackend - with patch.dict("sys.modules", {"torch": mock_torch}): + with patch( + "noteflow.application.services.asr_config._engine_manager.detect_gpu_backend", + return_value=GpuBackend.NONE, + ): result = asr_config_service.detect_cuda_available() assert result is False, 
"detect_cuda_available should return False when CUDA unavailable" @@ -347,7 +366,9 @@ def test_detect_cuda_available_no_cuda( @pytest.mark.asyncio async def test_reconfiguration_failure_keeps_active_engine(mock_asr_engine: MagicMock) -> None: """Reconfiguration failure should not replace or unload the active engine.""" - updates: list[MagicMock] = [] + from noteflow.infrastructure.asr.engine import FasterWhisperEngine + + updates: list[FasterWhisperEngine] = [] service = AsrConfigService(asr_engine=mock_asr_engine, on_engine_update=updates.append) mgr = service.engine_manager @@ -356,6 +377,7 @@ async def test_reconfiguration_failure_keeps_active_engine(mock_asr_engine: Magi patch.object(mgr, "load_model", side_effect=RuntimeError("boom")), ): job_id, _ = await service.start_reconfiguration("small", None, None, False) + assert job_id is not None, "job_id should not be None" job = service.get_job_status(job_id) assert job is not None and job.task is not None, "job should be created with task" await job.task @@ -368,7 +390,9 @@ async def test_reconfiguration_failure_keeps_active_engine(mock_asr_engine: Magi @pytest.mark.asyncio async def test_reconfiguration_success_swaps_engine(mock_asr_engine: MagicMock) -> None: """Successful reconfiguration should swap engine and unload the old one.""" - updates: list[MagicMock] = [] + from noteflow.infrastructure.asr.engine import FasterWhisperEngine + + updates: list[FasterWhisperEngine] = [] service = AsrConfigService(asr_engine=mock_asr_engine, on_engine_update=updates.append) new_engine, mgr = MagicMock(), service.engine_manager @@ -377,6 +401,7 @@ async def test_reconfiguration_success_swaps_engine(mock_asr_engine: MagicMock) patch.object(mgr, "load_model", return_value=None), ): job_id, _ = await service.start_reconfiguration("small", None, None, False) + assert job_id is not None, "job_id should not be None" job = service.get_job_status(job_id) assert job is not None and job.task is not None, "job should be created with task" await job.task diff --git a/tests/infrastructure/audio/test_partial_buffer.py b/tests/infrastructure/audio/test_partial_buffer.py index 1e846a9..8cd5cd0 100644 --- a/tests/infrastructure/audio/test_partial_buffer.py +++ b/tests/infrastructure/audio/test_partial_buffer.py @@ -19,14 +19,19 @@ class TestPartialAudioBufferInit: """Tests for buffer initialization.""" def test_default_capacity(self) -> None: - """Buffer should have 5 seconds capacity by default at 16kHz.""" + """Buffer should have 10 seconds capacity by default at 16kHz.""" buffer = PartialAudioBuffer() - assert buffer.capacity_samples == 5 * SAMPLE_RATE_16K, "Default capacity should be 5 seconds" + expected_capacity = int(PartialAudioBuffer.DEFAULT_MAX_DURATION * SAMPLE_RATE_16K) + assert buffer.capacity_samples == expected_capacity, ( + "Default capacity should match DEFAULT_MAX_DURATION" + ) def test_custom_duration(self) -> None: """Buffer should respect custom max duration.""" buffer = PartialAudioBuffer(max_duration_seconds=3.0, sample_rate=SAMPLE_RATE_16K) - assert buffer.capacity_samples == 3 * SAMPLE_RATE_16K, "Capacity should match custom duration" + assert buffer.capacity_samples == 3 * SAMPLE_RATE_16K, ( + "Capacity should match custom duration" + ) def test_custom_sample_rate(self) -> None: """Buffer should respect custom sample rate.""" diff --git a/tests/infrastructure/ner/test_engine.py b/tests/infrastructure/ner/test_engine.py index 35e7121..58c86da 100644 --- a/tests/infrastructure/ner/test_engine.py +++ b/tests/infrastructure/ner/test_engine.py 
@@ -1,4 +1,4 @@ -"""Tests for NER engine (spaCy wrapper).""" +"""Tests for NER engine (composition-based).""" from __future__ import annotations @@ -6,26 +6,32 @@ import pytest from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity from noteflow.infrastructure.ner import NerEngine +from noteflow.infrastructure.ner.backends.spacy_backend import SpacyBackend @pytest.fixture(scope="module") -def ner_engine() -> NerEngine: - """Create NER engine (module-scoped to avoid repeated model loads).""" - return NerEngine(model_name="en_core_web_sm") +def spacy_backend() -> SpacyBackend: + """Create spaCy backend (module-scoped to avoid repeated model loads).""" + return SpacyBackend(model_name="en_core_web_sm") + + +@pytest.fixture(scope="module") +def ner_engine(spacy_backend: SpacyBackend) -> NerEngine: + """Create NER engine with spaCy backend.""" + return NerEngine(backend=spacy_backend) class TestNerEngineBasics: """Basic NER engine functionality tests.""" def test_is_ready_before_load(self) -> None: - """Engine is not ready before first use.""" - engine = NerEngine() + backend = SpacyBackend() + engine = NerEngine(backend=backend) assert not engine.is_ready(), "Engine should not be ready before model is loaded" def test_is_ready_after_extract(self, ner_engine: NerEngine) -> None: - """Engine is ready after extraction triggers lazy load.""" ner_engine.extract("Hello, John.") - assert ner_engine.is_ready(), "Engine should be ready after extraction triggers lazy load" + assert ner_engine.is_ready(), "Engine should be ready after extraction" class TestEntityExtraction: @@ -58,7 +64,9 @@ class TestEntityExtraction: """Extract returns a list of NamedEntity objects.""" entities = ner_engine.extract("John works at Google.") assert isinstance(entities, list), "extract() should return a list" - assert all(isinstance(e, NamedEntity) for e in entities), "All items should be NamedEntity instances" + assert all(isinstance(e, NamedEntity) for e in entities), ( + "All items should be NamedEntity instances" + ) def test_extract_empty_text_returns_empty(self, ner_engine: NerEngine) -> None: """Empty text returns empty list.""" @@ -68,7 +76,6 @@ class TestEntityExtraction: def test_extract_no_entities_returns_empty(self, ner_engine: NerEngine) -> None: """Text with no entities returns empty list.""" entities = ner_engine.extract("The quick brown fox.") - # May still find entities depending on model, but should not crash assert isinstance(entities, list), "Result should be a list even with no entities" @@ -84,7 +91,6 @@ class TestSegmentExtraction: ] entities = ner_engine.extract_from_segments(segments) - # Mary should appear in multiple segments mary_entities = [e for e in entities if "mary" in e.normalized_text] assert mary_entities, "Mary entity not found" mary = mary_entities[0] @@ -98,7 +104,6 @@ class TestSegmentExtraction: ] entities = ner_engine.extract_from_segments(segments) - # Should have one John Smith entity (deduplicated by normalized text) john_entities = [e for e in entities if "john" in e.normalized_text] assert len(john_entities) == 1, "John Smith should be deduplicated" assert len(john_entities[0].segment_ids) == 2, "Should track both segments" @@ -116,11 +121,15 @@ class TestEntityNormalization: """Normalized text should be lowercase.""" entities = ner_engine.extract("John SMITH went to NYC.") non_lowercase = [e for e in entities if e.normalized_text != e.normalized_text.lower()] - assert not non_lowercase, f"All normalized text should be lowercase, but found: 
{[e.normalized_text for e in non_lowercase]}" + assert not non_lowercase, ( + f"All normalized text should be lowercase, but found: {[e.normalized_text for e in non_lowercase]}" + ) def test_confidence_is_set(self, ner_engine: NerEngine) -> None: """Entities should have confidence score.""" entities = ner_engine.extract("Microsoft Corporation is based in Seattle.") assert entities, "Should find entities" invalid_confidence = [e for e in entities if not (0.0 <= e.confidence <= 1.0)] - assert not invalid_confidence, f"All entities should have confidence between 0 and 1, but found: {[(e.text, e.confidence) for e in invalid_confidence]}" + assert not invalid_confidence, ( + f"All entities should have confidence between 0 and 1, but found: {[(e.text, e.confidence) for e in invalid_confidence]}" + ) diff --git a/tests/infrastructure/ner/test_gliner_backend.py b/tests/infrastructure/ner/test_gliner_backend.py new file mode 100644 index 0000000..a9d6299 --- /dev/null +++ b/tests/infrastructure/ner/test_gliner_backend.py @@ -0,0 +1,219 @@ +"""Tests for GLiNER NER backend.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from noteflow.infrastructure.ner.backends.gliner_backend import ( + DEFAULT_MODEL, + DEFAULT_THRESHOLD, + MEETING_LABELS, + GLiNERBackend, +) +from noteflow.infrastructure.ner.backends.types import RawEntity + + +def _create_backend_with_mock_model( + model_name: str = DEFAULT_MODEL, + labels: tuple[str, ...] = MEETING_LABELS, + threshold: float = DEFAULT_THRESHOLD, + mock_predictions: list[dict[str, object]] | None = None, +) -> GLiNERBackend: + backend = GLiNERBackend(model_name=model_name, labels=labels, threshold=threshold) + mock_model = MagicMock() + mock_model.predict_entities = MagicMock(return_value=mock_predictions or []) + object.__setattr__(backend, "_model", mock_model) + return backend + + +class TestGLiNERBackendInit: + def test_default_model_name(self) -> None: + backend = GLiNERBackend() + assert backend.model_name == DEFAULT_MODEL + + def test_default_labels(self) -> None: + backend = GLiNERBackend() + assert backend.labels == MEETING_LABELS + + def test_default_threshold(self) -> None: + backend = GLiNERBackend() + assert backend.threshold == DEFAULT_THRESHOLD + + def test_init_with_custom_model_name(self) -> None: + backend = GLiNERBackend(model_name="custom/model") + assert backend.model_name == "custom/model" + + def test_custom_labels(self) -> None: + custom_labels = ("person", "location") + backend = GLiNERBackend(labels=custom_labels) + assert backend.labels == custom_labels + + def test_custom_threshold(self) -> None: + backend = GLiNERBackend(threshold=0.7) + assert backend.threshold == 0.7 + + def test_not_loaded_initially(self) -> None: + backend = GLiNERBackend() + assert not backend.model_loaded() + + +class TestGLiNERBackendModelState: + def test_model_loaded_returns_true_when_model_set(self) -> None: + backend = _create_backend_with_mock_model() + assert backend.model_loaded() + + def test_model_loaded_returns_false_initially(self) -> None: + backend = GLiNERBackend() + assert not backend.model_loaded() + + +class TestGLiNERBackendExtraction: + def test_extract_empty_string_returns_empty(self) -> None: + backend = GLiNERBackend() + assert backend.extract("") == [] + + def test_extract_whitespace_only_returns_empty(self) -> None: + backend = GLiNERBackend() + assert backend.extract(" ") == [] + + def test_extract_returns_correct_count(self) -> None: + backend = _create_backend_with_mock_model( + 
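+            # predict_entities() is mocked here, so these dicts stand in for
+            # raw GLiNER output: text/label/start/end/score per prediction.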
mock_predictions=[ + {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95}, + {"text": "NYC", "label": "location", "start": 12, "end": 15, "score": 0.88}, + ] + ) + entities = backend.extract("John lives in NYC") + assert len(entities) == 2 + + def test_extract_returns_raw_entity_type(self) -> None: + backend = _create_backend_with_mock_model( + mock_predictions=[ + {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95}, + ] + ) + entities = backend.extract("John") + assert isinstance(entities[0], RawEntity) + + def test_extract_normalizes_labels_to_lowercase(self) -> None: + backend = _create_backend_with_mock_model( + mock_predictions=[ + {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.9}, + ] + ) + entities = backend.extract("John") + assert entities[0].label == "person" + + def test_extract_includes_confidence_score(self) -> None: + backend = _create_backend_with_mock_model( + mock_predictions=[ + {"text": "Meeting", "label": "event", "start": 0, "end": 7, "score": 0.85}, + ] + ) + entities = backend.extract("Meeting tomorrow") + assert entities[0].confidence == 0.85 + + def test_extract_passes_threshold_to_predict(self) -> None: + backend = _create_backend_with_mock_model(threshold=0.75) + backend.extract("Some text") + mock_model = getattr(backend, "_model") + predict_entities = getattr(mock_model, "predict_entities") + call_kwargs = predict_entities.call_args + assert call_kwargs[1]["threshold"] == 0.75 + + def test_extract_passes_labels_to_predict(self) -> None: + custom_labels = ("person", "task") + backend = _create_backend_with_mock_model(labels=custom_labels) + backend.extract("Some text") + mock_model = getattr(backend, "_model") + predict_entities = getattr(mock_model, "predict_entities") + call_kwargs = predict_entities.call_args + assert call_kwargs[1]["labels"] == ["person", "task"] + + @pytest.mark.parametrize( + ("text", "label", "start", "end", "score"), + [ + ("John Smith", "person", 0, 10, 0.92), + ("New York", "location", 5, 13, 0.88), + ("decision to proceed", "decision", 0, 19, 0.75), + ], + ) + def test_extract_entity_text_mapping( + self, text: str, label: str, start: int, end: int, score: float + ) -> None: + backend = _create_backend_with_mock_model( + mock_predictions=[ + {"text": text, "label": label.upper(), "start": start, "end": end, "score": score}, + ] + ) + entities = backend.extract("input text") + assert entities[0].text == text + + @pytest.mark.parametrize( + ("text", "label", "start", "end", "score"), + [ + ("John Smith", "person", 0, 10, 0.92), + ("New York", "location", 5, 13, 0.88), + ("decision to proceed", "decision", 0, 19, 0.75), + ], + ) + def test_extract_entity_label_mapping( + self, text: str, label: str, start: int, end: int, score: float + ) -> None: + backend = _create_backend_with_mock_model( + mock_predictions=[ + {"text": text, "label": label.upper(), "start": start, "end": end, "score": score}, + ] + ) + entities = backend.extract("input text") + assert entities[0].label == label + + @pytest.mark.parametrize( + ("text", "label", "start", "end", "score"), + [ + ("John Smith", "person", 0, 10, 0.92), + ("New York", "location", 5, 13, 0.88), + ("decision to proceed", "decision", 0, 19, 0.75), + ], + ) + def test_extract_entity_position_mapping( + self, text: str, label: str, start: int, end: int, score: float + ) -> None: + backend = _create_backend_with_mock_model( + mock_predictions=[ + {"text": text, "label": label.upper(), "start": start, "end": end, "score": score}, + ] + 
) + entities = backend.extract("input text") + assert (entities[0].start, entities[0].end) == (start, end) + + @pytest.mark.parametrize( + ("text", "label", "start", "end", "score"), + [ + ("John Smith", "person", 0, 10, 0.92), + ("New York", "location", 5, 13, 0.88), + ("decision to proceed", "decision", 0, 19, 0.75), + ], + ) + def test_extract_entity_confidence_mapping( + self, text: str, label: str, start: int, end: int, score: float + ) -> None: + backend = _create_backend_with_mock_model( + mock_predictions=[ + {"text": text, "label": label.upper(), "start": start, "end": end, "score": score}, + ] + ) + entities = backend.extract("input text") + assert entities[0].confidence == score + + +class TestGLiNERBackendMeetingLabels: + def test_meeting_labels_include_core_categories(self) -> None: + expected_labels = {"person", "org", "product", "app", "location", "time"} + assert expected_labels.issubset(set(MEETING_LABELS)) + + def test_meeting_labels_include_meeting_specific_categories(self) -> None: + expected_labels = {"task", "decision", "topic", "event"} + assert expected_labels.issubset(set(MEETING_LABELS)) diff --git a/tests/infrastructure/ner/test_mapper.py b/tests/infrastructure/ner/test_mapper.py new file mode 100644 index 0000000..5220066 --- /dev/null +++ b/tests/infrastructure/ner/test_mapper.py @@ -0,0 +1,100 @@ +"""Tests for NER entity mapper.""" + +from __future__ import annotations + +import pytest + +from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity +from noteflow.infrastructure.ner.backends.types import RawEntity +from noteflow.infrastructure.ner.mapper import ( + DEFAULT_CONFIDENCE, + map_label_to_category, + map_raw_entities_to_named, + map_raw_to_named, +) + + +class TestMapLabelToCategory: + """Tests for label to category mapping.""" + + @pytest.mark.parametrize( + ("label", "expected"), + [ + ("person", EntityCategory.PERSON), + ("PERSON", EntityCategory.PERSON), + ("org", EntityCategory.COMPANY), + ("company", EntityCategory.COMPANY), + ("product", EntityCategory.PRODUCT), + ("app", EntityCategory.PRODUCT), + ("location", EntityCategory.LOCATION), + ("gpe", EntityCategory.LOCATION), + ("loc", EntityCategory.LOCATION), + ("fac", EntityCategory.LOCATION), + ("date", EntityCategory.DATE), + ("time", EntityCategory.DATE), + ("time_relative", EntityCategory.TIME_RELATIVE), + ("duration", EntityCategory.DURATION), + ("event", EntityCategory.EVENT), + ("task", EntityCategory.TASK), + ("decision", EntityCategory.DECISION), + ("topic", EntityCategory.TOPIC), + ("work_of_art", EntityCategory.PRODUCT), + ("unknown_label", EntityCategory.OTHER), + ], + ) + def test_label_mapping(self, label: str, expected: EntityCategory) -> None: + assert map_label_to_category(label) == expected + + +class TestMapRawToNamed: + """Tests for single entity mapping.""" + + def test_maps_text_and_category(self) -> None: + raw = RawEntity(text="John Smith", label="person", start=0, end=10, confidence=0.9) + result = map_raw_to_named(raw) + + assert result.text == "John Smith", "Text should be preserved" + assert result.category == EntityCategory.PERSON, "Category should be mapped" + + def test_uses_provided_confidence(self) -> None: + raw = RawEntity(text="Apple", label="product", start=0, end=5, confidence=0.85) + result = map_raw_to_named(raw) + + assert result.confidence == 0.85, "Should use provided confidence" + + def test_uses_default_confidence_when_none(self) -> None: + raw = RawEntity(text="Paris", label="location", start=0, end=5, confidence=None) + result = 
map_raw_to_named(raw) + + assert result.confidence == DEFAULT_CONFIDENCE, "Should use default confidence when None" + + def test_segment_ids_empty_by_default(self) -> None: + raw = RawEntity(text="Monday", label="date", start=0, end=6, confidence=0.8) + result = map_raw_to_named(raw) + + assert result.segment_ids == [], "Segment IDs should be empty by default" + + def test_result_is_named_entity(self) -> None: + raw = RawEntity(text="Google", label="company", start=0, end=6, confidence=0.9) + result = map_raw_to_named(raw) + + assert isinstance(result, NamedEntity), "Result should be NamedEntity" + + +class TestMapRawEntitiesToNamed: + """Tests for batch entity mapping.""" + + def test_maps_multiple_entities(self) -> None: + raw_entities = [ + RawEntity(text="John", label="person", start=0, end=4, confidence=0.9), + RawEntity(text="Google", label="company", start=5, end=11, confidence=0.85), + RawEntity(text="New York", label="location", start=12, end=20, confidence=0.95), + ] + result = map_raw_entities_to_named(raw_entities) + + assert len(result) == 3, "Should map all entities" + assert all(isinstance(e, NamedEntity) for e in result), "All should be NamedEntity" + + def test_empty_list_returns_empty(self) -> None: + result = map_raw_entities_to_named([]) + assert result == [], "Empty input should return empty output" diff --git a/tests/infrastructure/ner/test_post_processing.py b/tests/infrastructure/ner/test_post_processing.py new file mode 100644 index 0000000..b45dff7 --- /dev/null +++ b/tests/infrastructure/ner/test_post_processing.py @@ -0,0 +1,158 @@ +"""Tests for NER post-processing pipeline.""" + +from __future__ import annotations + +import pytest + +from noteflow.infrastructure.ner.backends.types import RawEntity +from noteflow.infrastructure.ner.post_processing import ( + dedupe_entities, + drop_low_signal_entities, + infer_duration, + merge_time_phrases, + normalize_text, + post_process, +) + + +class TestNormalizeText: + """Tests for text normalization.""" + + @pytest.mark.parametrize( + ("input_text", "expected"), + [ + ("John Smith", "john smith"), + (" UPPER CASE ", "upper case"), + ("'quoted'", "quoted"), + ('"double"', "double"), + ("spaces between words", "spaces between words"), + ], + ) + def test_normalize_text(self, input_text: str, expected: str) -> None: + assert normalize_text(input_text) == expected + + +class TestDedupeEntities: + """Tests for entity deduplication.""" + + def test_deduplicates_by_normalized_text_and_label(self) -> None: + entities = [ + RawEntity(text="John", label="person", start=0, end=4, confidence=0.8), + RawEntity(text="john", label="person", start=10, end=14, confidence=0.9), + ] + result = dedupe_entities(entities) + assert len(result) == 1, "Should deduplicate same entity" + + def test_keeps_higher_confidence(self) -> None: + entities = [ + RawEntity(text="John", label="person", start=0, end=4, confidence=0.7), + RawEntity(text="john", label="person", start=10, end=14, confidence=0.9), + ] + result = dedupe_entities(entities) + assert result[0].confidence == 0.9, "Should keep higher confidence" + + def test_different_labels_not_deduped(self) -> None: + entities = [ + RawEntity(text="Apple", label="product", start=0, end=5, confidence=0.8), + RawEntity(text="Apple", label="company", start=10, end=15, confidence=0.9), + ] + result = dedupe_entities(entities) + assert len(result) == 2, "Different labels should not be deduped" + + +class TestDropLowSignalEntities: + """Tests for low-signal entity filtering.""" + + 
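+    # Each drop case below exercises one branch of is_low_signal_entity():
+    # length <= MIN_ENTITY_LENGTH, all-digit text, or a PROFANITY_WORDS hit.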
@pytest.mark.parametrize( + ("text", "label"), + [ + ("a", "person"), + ("to", "location"), + ("123", "product"), + ("shit", "topic"), + ], + ) + def test_drops_low_signal_entities(self, text: str, label: str) -> None: + entity = RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8) + result = drop_low_signal_entities([entity], "some original text") + assert len(result) == 0, f"'{text}' should be dropped" + + @pytest.mark.parametrize( + ("text", "label"), + [ + ("John", "person"), + ("New York", "location"), + ], + ) + def test_keeps_valid_entities(self, text: str, label: str) -> None: + entity = RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8) + result = drop_low_signal_entities([entity], "some original text") + assert len(result) == 1, f"'{text}' should NOT be dropped" + + +class TestInferDuration: + """Tests for duration inference.""" + + @pytest.mark.parametrize( + ("text", "expected_label"), + [ + ("20 minutes", "duration"), + ("two weeks", "duration"), + ("3 hours", "duration"), + ("next week", "time_relative"), + ("tomorrow", "time_relative"), + ], + ) + def test_infer_duration(self, text: str, expected_label: str) -> None: + entity = RawEntity(text=text, label="time_relative", start=0, end=len(text), confidence=0.8) + result = infer_duration([entity]) + assert result[0].label == expected_label, f"'{text}' should be labeled as {expected_label}" + + +class TestMergeTimePhrases: + """Tests for time phrase merging.""" + + def test_merges_adjacent_time_entities(self) -> None: + original_text = "last night we met" + entities = [ + RawEntity(text="last", label="time", start=0, end=4, confidence=0.8), + RawEntity(text="night", label="time", start=5, end=10, confidence=0.7), + ] + result = merge_time_phrases(entities, original_text) + + time_entities = [e for e in result if e.label == "time"] + assert len(time_entities) == 1, "Should merge adjacent time entities" + assert time_entities[0].text == "last night", "Merged text should be 'last night'" + + def test_does_not_merge_non_adjacent(self) -> None: + original_text = "Monday came and then Friday arrived" + entities = [ + RawEntity(text="Monday", label="time", start=0, end=6, confidence=0.8), + RawEntity(text="Friday", label="time", start=21, end=27, confidence=0.8), + ] + result = merge_time_phrases(entities, original_text) + + time_entities = [e for e in result if e.label == "time"] + assert len(time_entities) == 2, "Should not merge non-adjacent time entities" + + +class TestPostProcess: + """Tests for full post-processing pipeline.""" + + def test_full_pipeline(self) -> None: + original_text = "John met shit in Paris for 20 minutes" + entities = [ + RawEntity(text="John", label="person", start=0, end=4, confidence=0.9), + RawEntity(text="shit", label="topic", start=9, end=13, confidence=0.5), + RawEntity(text="Paris", label="location", start=17, end=22, confidence=0.95), + RawEntity(text="20 minutes", label="time", start=27, end=37, confidence=0.8), + ] + result = post_process(entities, original_text) + + texts = {e.text for e in result} + assert "shit" not in texts, "Profanity should be filtered" + assert "John" in texts, "Valid person should remain" + assert "Paris" in texts, "Valid location should remain" + + duration_entities = [e for e in result if e.label == "duration"] + assert len(duration_entities) == 1, "20 minutes should be labeled as duration" diff --git a/tests/integration/test_hf_token_grpc.py b/tests/integration/test_hf_token_grpc.py index 369e57a..754290e 100644 --- 
a/tests/integration/test_hf_token_grpc.py +++ b/tests/integration/test_hf_token_grpc.py @@ -233,7 +233,11 @@ class TestHuggingFaceTokenGrpc: assert delete_response.success is True, "delete token should succeed" - status = await _get_hf_status(hf_token_servicer, mock_grpc_context) + with patch( + "noteflow.application.services.huggingface.service.get_settings" + ) as mock_settings: + mock_settings.return_value.diarization_hf_token = None + status = await _get_hf_status(hf_token_servicer, mock_grpc_context) assert status.is_configured is False, "token should be cleared after delete" assert status.is_validated is False, "validation should reset after delete" diff --git a/typings/gliner/__init__.pyi b/typings/gliner/__init__.pyi new file mode 100644 index 0000000..314ff25 --- /dev/null +++ b/typings/gliner/__init__.pyi @@ -0,0 +1,35 @@ +from pathlib import Path +from typing import Self + +class GLiNER: + @classmethod + def from_pretrained( + cls, + model_id: str, + revision: str | None = None, + cache_dir: str | Path | None = None, + force_download: bool = False, + proxies: dict[str, str] | None = None, + resume_download: bool = False, + local_files_only: bool = False, + token: str | bool | None = None, + map_location: str = "cpu", + strict: bool = False, + load_tokenizer: bool | None = None, + resize_token_embeddings: bool | None = True, + compile_torch_model: bool | None = False, + load_onnx_model: bool | None = False, + onnx_model_file: str | None = "model.onnx", + max_length: int | None = None, + max_width: int | None = None, + post_fusion_schema: str | None = None, + _attn_implementation: str | None = None, + ) -> Self: ... + def predict_entities( + self, + text: str, + labels: list[str], + flat_ner: bool = True, + threshold: float = 0.5, + multi_label: bool = False, + ) -> list[dict[str, object]]: ...
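
A minimal end-to-end sketch of the pieces this patch introduces, using only names that appear in the diff above (NerEngine, GLiNERBackend, extract_from_segments); the threshold value and sample transcript are illustrative, and the GLiNER weights must be available locally or downloadable for load_model() to succeed:

from noteflow.infrastructure.ner import NerEngine
from noteflow.infrastructure.ner.backends.gliner_backend import GLiNERBackend

# Compose as the refactored engine expects: the backend yields RawEntity
# spans; post_process() and the label mapper run inside NerEngine.extract().
backend = GLiNERBackend(threshold=0.6)  # defaults: DEFAULT_MODEL, MEETING_LABELS
engine = NerEngine(backend=backend)
engine.load_model()  # delegates to the backend's load_model()

segments = [
    (0, "Alice from Acme kicked off the rollout sync."),
    (1, "Alice asked for a two week timeline."),
]
for entity in engine.extract_from_segments(segments):
    # Repeated mentions are merged, so Alice's segment_ids should cover both
    # segments; if the model extracts a span like "two week", infer_duration()
    # relabels it as a duration before mapping to EntityCategory.
    print(entity.category, entity.text, entity.segment_ids, entity.confidence)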