feat: introduce GLiNER and spaCy backends for NER, refactor NER infrastructure, and update client-side entity extraction.
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
---
|
||||
active: true
|
||||
iteration: 3
|
||||
iteration: 1
|
||||
max_iterations: 0
|
||||
completion_promise: null
|
||||
started_at: "2026-01-20T02:31:55Z"
|
||||
started_at: "2026-01-20T02:31:55Z"
|
||||
---
|
||||
|
||||
proceed with the plan, i have also documented a copy in @.claudectx/codefixes.md. please use your agents iteratively to manage context and speed, however you must review the accuracy and value of each doc before moving to the next
|
||||
proceed with the plan, i have also documented a copy in @.claudectx/codefixes.md. please use your agents iteratively to manage context and speed, however you must review the accuracy and value of each doc before moving to the next
|
||||
|
||||
@@ -1,287 +1,377 @@
|
||||
# Entity extraction robustness notes
|
||||
╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌
|
||||
Strategic CLAUDE.md Placement Analysis for NoteFlow
|
||||
|
||||
## Library options that can do heavy lifting
|
||||
- GLiNER (Apache-2.0): generalist NER model that supports arbitrary labels without retraining; API via `GLiNER.from_pretrained(...).predict_entities(text, labels, threshold=0.5)`.
|
||||
- Source: https://github.com/urchade/GLiNER
|
||||
- Presidio (MIT): pipeline framework for detection + redaction focused on PII; includes NLP + regex recognizers, and extensible recognizer registry; useful for sensitive entity categories rather than general “meeting entities”.
|
||||
- Source: https://github.com/microsoft/presidio
|
||||
- Stanza (Apache-2.0): full NLP pipeline with NER component; supports many languages; requires model download and can add CoreNLP features.
|
||||
- Source: https://stanfordnlp.github.io/stanza/
|
||||
- spaCy (open-source): production-oriented NLP toolkit with NER + entity ruler; offers fast pipelines and rule-based patterns; easy to combine ML + rules for domain entities.
|
||||
- Source: https://spacy.io/usage/spacy-101
|
||||
Executive Summary
|
||||
|
||||
## Recommendation for NoteFlow
|
||||
- Best “heavy lifting” candidate: GLiNER for flexible, domain-specific entity types without retraining. It can be fed label sets like [person, company, product, app, location, time, duration, event] and will return per-label spans.
|
||||
- If PII detection is a requirement: layer Presidio on top of GLiNER outputs (or run Presidio first for sensitive categories). Presidio is not optimized for general entity relevance outside PII.
|
||||
- If you prefer classic pipelines or need dependency parsing: spaCy or Stanza; spaCy is easier to extend with rule-based entity ruler and custom merge policies.
|
||||
This document analyzes optimal placement of CLAUDE.md files throughout the NoteFlow codebase to provide meaningful context for AI assistants. The analysis considers both
|
||||
constrained (strategic) and unlimited scenarios.
|
||||
|
||||
## Quick wins for robustness (model-agnostic)
|
||||
- Normalize text: capitalization restoration, punctuation cleanup, ASR filler removal before NER.
|
||||
- Post-process: merge adjacent time phrases ("last night" + "20 minutes"), dedupe casing, drop single-token low-signal entities.
|
||||
- Context scoring: down-rank profanity/generic nouns unless in quoted/explicit phrases.
|
||||
- Speaker-aware aggregation: extract per speaker, then merge and rank by frequency and context windows.
|
||||
---
|
||||
Current State: Existing Documentation Files
|
||||
|
||||
## Meeting entity schema (proposed)
|
||||
- person
|
||||
- org
|
||||
- product
|
||||
- app
|
||||
- location
|
||||
- time
|
||||
- time_relative
|
||||
- duration
|
||||
- event
|
||||
- task
|
||||
- decision
|
||||
- topic
|
||||
10 CLAUDE.md/AGENTS.md Files Already Present
|
||||
┌─────────────────────┬───────────┬──────────────────────────────────────────────────────────┐
|
||||
│ Location │ File │ Focus │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ / │ CLAUDE.md │ Root orchestration, parallel execution, project overview │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ / │ AGENTS.md │ Architecture for non-Claude AI assistants │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /src/ │ CLAUDE.md │ Python backend entry point │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /src/ │ AGENTS.md │ Python backend for other AIs │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /src/noteflow/ │ CLAUDE.md │ Detailed Python standards (line limits, typing, modules) │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /src/noteflow/grpc/ │ CLAUDE.md │ gRPC security patterns │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /src/noteflow/grpc/ │ AGENTS.md │ gRPC security (duplicate) │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /client/ │ CLAUDE.md │ Tauri + React development │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /client/src/ │ CLAUDE.md │ TypeScript security rules │
|
||||
├─────────────────────┼───────────┼──────────────────────────────────────────────────────────┤
|
||||
│ /docker/ │ CLAUDE.md │ Docker security and build patterns │
|
||||
└─────────────────────┴───────────┴──────────────────────────────────────────────────────────┘
|
||||
---
|
||||
Part 1: Strategic Placement (Constrained Resources)
|
||||
|
||||
## spaCy label mapping (to meeting schema)
|
||||
- PERSON -> person
|
||||
- ORG -> org
|
||||
- GPE/LOC/FAC -> location
|
||||
- PRODUCT -> product
|
||||
- EVENT -> event
|
||||
- DATE/TIME -> time
|
||||
- CARDINAL/QUANTITY -> duration (only if units present: min/hour/day/week/month)
|
||||
- WORK_OF_ART/LAW -> topic (fallback)
|
||||
- NORP -> topic (fallback)
|
||||
If limited to 5-7 additional files, prioritize these high-impact locations:
|
||||
|
||||
## GLiNER label set (meeting tuned)
|
||||
- person, org, product, app, location, time, time_relative, duration, event, task, decision, topic
|
||||
Tier 1: Critical Gaps (Add These First)
|
||||
|
||||
## Eval content (meeting-like snippets)
|
||||
1) "We should ping Pixel about the shipment, maybe next week."
|
||||
- gold: Pixel=product, next week=time_relative
|
||||
2) "Let’s meet at Navi around 3ish tomorrow."
|
||||
- gold: Navi=location, 3ish=time, tomorrow=time_relative
|
||||
3) "Paris is on the agenda, but it’s just a placeholder."
|
||||
- gold: Paris=location, agenda=topic
|
||||
4) "The demo for Figma went well last night."
|
||||
- gold: demo=event, Figma=product, last night=time_relative
|
||||
5) "I spent 20 minutes fixing the auth issue."
|
||||
- gold: 20 minutes=duration, auth issue=topic
|
||||
6) "Can you file a Jira for the login bug?"
|
||||
- gold: Jira=product, login bug=task
|
||||
7) "Sam and Priya agreed to ship v2 by Friday."
|
||||
- gold: Sam=person, Priya=person, Friday=time
|
||||
8) "We decided to drop the old onboarding flow."
|
||||
- gold: decision=decision, onboarding flow=topic
|
||||
9) "Let’s sync after the standup."
|
||||
- gold: sync=event, standup=event
|
||||
10) "I heard Maya mention Notion for notes."
|
||||
- gold: Maya=person, Notion=product, notes=topic
|
||||
11) "Zoom audio was bad during the call."
|
||||
- gold: Zoom=product, call=event
|
||||
12) "We should ask Alex about the budget."
|
||||
- gold: Alex=person, budget=topic
|
||||
13) "Let’s revisit this in two weeks."
|
||||
- gold: two weeks=duration
|
||||
14) "The patch for Safari is done."
|
||||
- gold: Safari=product, patch=task
|
||||
15) "We’ll meet at the cafe near Union Station."
|
||||
- gold: cafe=location, Union Station=location
|
||||
1. /src/noteflow/infrastructure/CLAUDE.md
|
||||
|
||||
## Data-backed comparison (local eval)
|
||||
- Setup: spaCy `en_core_web_sm` with mapping rules above vs GLiNER `urchade/gliner_medium-v2.1` with label set above.
|
||||
- Dataset: 15 meeting-like snippets (see eval content section).
|
||||
- Results:
|
||||
- spaCy: precision 0.34, recall 0.19 (tp=7, fp=10, fn=26)
|
||||
- GLiNER: precision 0.36, recall 0.36 (tp=11, fp=17, fn=22)
|
||||
Why: Infrastructure layer has 15+ adapters with distinct patterns (ASR, diarization, NER, summarization, calendar, webhooks, persistence). No unified guidance exists.
|
||||
|
||||
## Recommendation based on results
|
||||
- GLiNER outperformed spaCy on recall (0.36 vs 0.19) while maintaining slightly higher precision, which is more aligned with meeting transcripts where relevance is tied to custom labels and informal mentions.
|
||||
- Use GLiNER as the primary extractor, then apply a post-processing filter for low-signal entities and a normalization layer for time/duration to improve precision.
|
||||
2. /src/noteflow/domain/CLAUDE.md
|
||||
|
||||
## Implementation plan: GLiNER + post-processing with swap-friendly architecture (detailed + references)
|
||||
Why: Domain layer defines entities, ports, rules, and value objects. Understanding DDD boundaries prevents architectural violations.
|
||||
|
||||
### Overview
|
||||
Goal: Add GLiNER while keeping a clean boundary so switching back to spaCy or another backend is a config-only change. This uses three layers:
|
||||
1) Backend (GLiNER/spaCy) returns `RawEntity` only.
|
||||
2) Post-processor normalizes and filters `RawEntity`.
|
||||
3) Mapper converts `RawEntity` -> domain `NamedEntity`.
|
||||
3. /src/noteflow/application/services/CLAUDE.md
|
||||
|
||||
### Step 0: Identify the existing entry points
|
||||
- Current NER engine: `src/noteflow/infrastructure/ner/engine.py`.
|
||||
- Domain entity type: `src/noteflow/domain/entities/named_entity.py` (for `EntityCategory`, `NamedEntity`).
|
||||
- Service creation: `src/noteflow/grpc/startup/services.py` (selects NER engine).
|
||||
- Feature flag: `src/noteflow/config/settings/_features.py` (`ner_enabled`).
|
||||
Why: 12+ services with distinct responsibilities. Service-level guidance prevents duplication and clarifies orchestration patterns.
|
||||
|
||||
### Step 1: Add backend-agnostic models
|
||||
Create new file: `src/noteflow/infrastructure/ner/backends/types.py`
|
||||
- Add `RawEntity` dataclass:
|
||||
- `text: str`, `label: str`, `start: int`, `end: int`, `confidence: float | None`
|
||||
- `label` must be lowercase (e.g., `person`, `time_relative`).
|
||||
- Add `NerBackend` protocol:
|
||||
- `def extract(self, text: str) -> list[RawEntity]: ...`
|
||||
4. /client/src/hooks/CLAUDE.md
|
||||
|
||||
### Step 2: Implement GLiNER backend
|
||||
Create `src/noteflow/infrastructure/ner/backends/gliner_backend.py`
|
||||
- Lazy-load model to avoid startup delay (pattern copied from `src/noteflow/infrastructure/ner/engine.py`).
|
||||
- Import: `from gliner import GLiNER`.
|
||||
- Model: `GLiNER.from_pretrained("urchade/gliner_medium-v2.1")`.
|
||||
- Labels come from settings; if missing, default to the meeting label list in this doc.
|
||||
- Convert GLiNER entity dicts into `RawEntity`:
|
||||
- `text`: `entity["text"]`
|
||||
- `label`: `entity["label"].lower()`
|
||||
- `start`: `entity["start"]`
|
||||
- `end`: `entity["end"]`
|
||||
- `confidence`: `entity.get("score")`
|
||||
- Keep thresholds configurable (`NOTEFLOW_NER_GLINER_THRESHOLD`), default `0.5`.
|
||||
Why: 7 hook directories (audio, auth, data, processing, recording, sync, ui) with complex interdependencies. Prevents reinventing existing hooks.
|
||||
|
||||
### Step 3: Move spaCy logic into a backend
|
||||
Create `src/noteflow/infrastructure/ner/backends/spacy_backend.py`
|
||||
- Move current spaCy loading logic from `src/noteflow/infrastructure/ner/engine.py`.
|
||||
- Keep the model constants in `src/noteflow/config/constants.py`.
|
||||
- Use the existing spaCy mapping and skip rules; move them into this backend or a shared module.
|
||||
- Convert `Doc.ents` into `RawEntity` with lowercase labels.
|
||||
5. /client/src-tauri/src/CLAUDE.md
|
||||
|
||||
### Step 4: Create shared post-processing pipeline
|
||||
Create `src/noteflow/infrastructure/ner/post_processing.py` (pure functions only, no class):
|
||||
- `normalize_text(text: str) -> str`
|
||||
- `lower()`, collapse whitespace, trim.
|
||||
- `dedupe_entities(entities: list[RawEntity]) -> list[RawEntity]`
|
||||
- key: normalized text; keep highest confidence if both present.
|
||||
- `drop_low_signal_entities(entities: list[RawEntity], text: str) -> list[RawEntity]`
|
||||
- drop profanity, short tokens (`len <= 2`), numeric-only entities.
|
||||
- list should live in `src/noteflow/domain/constants/` or `src/noteflow/config/constants.py`.
|
||||
- `merge_time_phrases(entities: list[RawEntity], text: str) -> list[RawEntity]`
|
||||
- If two time-like entities are adjacent in the original text, merge into one span.
|
||||
- `infer_duration(entities: list[RawEntity]) -> list[RawEntity]`
|
||||
- If entity contains duration units (`minute`, `hour`, `day`, `week`, `month`), set label to `duration`.
|
||||
Why: Rust backend has commands, gRPC client, audio processing, state management. No Rust-specific guidance currently exists.
|
||||
|
||||
### Step 5: Mapping to domain entities
|
||||
Create `src/noteflow/infrastructure/ner/mapper.py`
|
||||
- `map_raw_to_named(entities: list[RawEntity]) -> list[NamedEntity]`.
|
||||
- Convert labels to `EntityCategory` from `src/noteflow/domain/entities/named_entity.py`.
|
||||
- Confidence rules:
|
||||
- If `RawEntity.confidence` exists, use it.
|
||||
- Else use default 0.8 (same as current spaCy fallback in `engine.py`).
|
||||
Tier 2: High Value (Add Next)
|
||||
|
||||
### Step 6: Update NerEngine to use composition
|
||||
Edit `src/noteflow/infrastructure/ner/engine.py`:
|
||||
- Remove direct spaCy dependency from `NerEngine`.
|
||||
- Add constructor params:
|
||||
- `backend: NerBackend`, `post_processor: Callable`, `mapper: NerMapper`.
|
||||
- Flow in `extract`:
|
||||
1) `raw = backend.extract(text)`
|
||||
2) `raw = post_process(raw, text)`
|
||||
3) `entities = mapper.map_raw_to_named(raw)`
|
||||
4) Deduplicate by normalized text (use current logic, or move to post-processing).
|
||||
6. /tests/CLAUDE.md
|
||||
|
||||
### Step 7: Configurable backend selection
|
||||
Edit `src/noteflow/grpc/startup/services.py`:
|
||||
- Add config: `NOTEFLOW_NER_BACKEND` (default `spacy`).
|
||||
- If value is `gliner`, instantiate `GLiNERBackend`; else `SpacyBackend`.
|
||||
- Continue to respect `get_feature_flags().ner_enabled` in `create_ner_service`.
|
||||
Why: Testing conventions (fixtures, markers, quality gates) are scattered. Centralized guidance improves test quality.
|
||||
|
||||
### Step 8: Tests (behavior + user flow + response quality)
|
||||
Create the following tests and use global fixtures from `tests/conftest.py` (do not redefine them):
|
||||
- Use `mock_uow`, `sample_meeting`, `meeting_id`, `sample_datetime`, `approx_sequence` where helpful.
|
||||
- Avoid loops/conditionals in tests; use `@pytest.mark.parametrize`.
|
||||
7. /src/noteflow/infrastructure/persistence/CLAUDE.md
|
||||
|
||||
**Unit tests**
|
||||
- `tests/ner/test_post_processing.py`
|
||||
- Validate normalization, dedupe, time merge, duration inference, and profanity drop.
|
||||
- Use small, deterministic inputs and assert exact `RawEntity` outputs.
|
||||
- `tests/ner/test_mapper.py`
|
||||
- Map every label to `EntityCategory` explicitly.
|
||||
- Assert confidence fallback when `RawEntity.confidence=None`.
|
||||
Why: UnitOfWork pattern, repository hierarchy, capability flags, migrations are complex. Prevents incorrect persistence patterns.
|
||||
|
||||
**Behavior tests**
|
||||
- `tests/ner/test_engine_behavior.py`
|
||||
- Use a stub backend to simulate GLiNER outputs with overlapping spans.
|
||||
- Verify `NerEngine.extract` returns stable, deduped, correctly categorized entities.
|
||||
---
|
||||
Part 2: Unlimited Placement (Comprehensive Coverage)
|
||||
|
||||
**User-flow tests**
|
||||
- `tests/ner/test_service_flow.py`
|
||||
- Use `mock_uow` and `sample_meeting` to simulate service flow that calls NER after a meeting completes.
|
||||
- Assert entities are persisted with expected categories and segment IDs.
|
||||
With no constraints, here's the complete list of 25+ locations where CLAUDE.md would add value:
|
||||
|
||||
**Response quality tests**
|
||||
- `tests/ner/test_quality_contract.py`
|
||||
- Feed short, meeting-like snippets and assert:
|
||||
- No profanity entities survive.
|
||||
- Time phrases are merged.
|
||||
- Product vs person is not mis-typed for known examples (Pixel, Notion, Zoom).
|
||||
Python Backend (src/noteflow/)
|
||||
┌────────────────────────────────────────┬─────────────────────────────────────────────────────────┐
|
||||
│ Path │ Content Focus │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ domain/CLAUDE.md │ DDD entities, ports, value objects, rules engine │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ domain/entities/CLAUDE.md │ Entity relationships, state machines, invariants │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ domain/ports/CLAUDE.md │ Repository protocols, capability contracts │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ domain/rules/CLAUDE.md │ Rule modes (SIMPLE→EXPRESSION), registry, evaluation │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ application/CLAUDE.md │ Use case organization, service boundaries │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ application/services/CLAUDE.md │ Service catalog, dependency patterns │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/CLAUDE.md │ Adapter patterns, external integrations │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/asr/CLAUDE.md │ Whisper, VAD, segmentation, streaming │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/diarization/CLAUDE.md │ Job lifecycle, streaming vs offline, speaker assignment │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/ner/CLAUDE.md │ Backend abstraction, mapper, post-processing │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/summarization/CLAUDE.md │ Provider protocols, consent workflow, citation linking │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/persistence/CLAUDE.md │ UnitOfWork, repositories, migrations │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/calendar/CLAUDE.md │ OAuth flow, sync patterns, trigger detection │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ infrastructure/webhooks/CLAUDE.md │ Delivery, signing, retry logic │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ grpc/mixins/CLAUDE.md │ Mixin composition, streaming handlers │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ grpc/startup/CLAUDE.md │ Service initialization, dependency injection │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ config/CLAUDE.md │ Settings cascade, feature flags, environment loading │
|
||||
├────────────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ cli/CLAUDE.md │ Command patterns, model management │
|
||||
└────────────────────────────────────────┴─────────────────────────────────────────────────────────┘
|
||||
Client (client/)
|
||||
┌──────────────────────────────────┬─────────────────────────────────────────────────────────┐
|
||||
│ Path │ Content Focus │
|
||||
├──────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ src/api/CLAUDE.md │ Adapter pattern, transport abstraction, type generation │
|
||||
├──────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ src/components/CLAUDE.md │ Component hierarchy, feature organization │
|
||||
├──────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ src/hooks/CLAUDE.md │ Hook catalog, composition patterns, state management │
|
||||
├──────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ src/lib/CLAUDE.md │ Utility catalog, AI providers, audio processing │
|
||||
├──────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ src-tauri/src/CLAUDE.md │ Rust patterns, command handlers, state │
|
||||
├──────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ src-tauri/src/commands/CLAUDE.md │ IPC contract, audio commands, recording session │
|
||||
├──────────────────────────────────┼─────────────────────────────────────────────────────────┤
|
||||
│ src-tauri/src/grpc/CLAUDE.md │ gRPC client wrapper, type conversions │
|
||||
└──────────────────────────────────┴─────────────────────────────────────────────────────────┘
|
||||
Testing (tests/)
|
||||
┌─────────────────────────────┬────────────────────────────────────────────────────┐
|
||||
│ Path │ Content Focus │
|
||||
├─────────────────────────────┼────────────────────────────────────────────────────┤
|
||||
│ tests/CLAUDE.md │ Test conventions, fixtures, markers, quality gates │
|
||||
├─────────────────────────────┼────────────────────────────────────────────────────┤
|
||||
│ tests/fixtures/CLAUDE.md │ Shared fixtures catalog, usage patterns │
|
||||
├─────────────────────────────┼────────────────────────────────────────────────────┤
|
||||
│ tests/integration/CLAUDE.md │ Integration test setup, testcontainers │
|
||||
└─────────────────────────────┴────────────────────────────────────────────────────┘
|
||||
Documentation (docs/)
|
||||
┌────────────────────────┬───────────────────────────────────────────┐
|
||||
│ Path │ Content Focus │
|
||||
├────────────────────────┼───────────────────────────────────────────┤
|
||||
│ docs/sprints/CLAUDE.md │ Sprint structure, documentation standards │
|
||||
└────────────────────────┴───────────────────────────────────────────┘
|
||||
---
|
||||
Part 3: Mockup - /src/noteflow/infrastructure/CLAUDE.md
|
||||
|
||||
### Test matrix (inputs -> expected outputs)
|
||||
Use these exact cases in `tests/ner/test_quality_contract.py` with `@pytest.mark.parametrize`:
|
||||
1) Text: "We should ping Pixel about the shipment, maybe next week."
|
||||
- Expect: `Pixel=product`, `next week=time_relative`
|
||||
2) Text: "Let’s meet at Navi around 3ish tomorrow."
|
||||
- Expect: `Navi=location`, `3ish=time`, `tomorrow=time_relative`
|
||||
3) Text: "The demo for Figma went well last night."
|
||||
- Expect: `demo=event`, `Figma=product`, `last night=time_relative`
|
||||
4) Text: "I spent 20 minutes fixing the auth issue."
|
||||
- Expect: `20 minutes=duration`, `auth issue=topic`
|
||||
5) Text: "Can you file a Jira for the login bug?"
|
||||
- Expect: `Jira=product`, `login bug=task`
|
||||
6) Text: "Zoom audio was bad during the call."
|
||||
- Expect: `Zoom=product`, `call=event`
|
||||
7) Text: "We should ask Alex about the budget."
|
||||
- Expect: `Alex=person`, `budget=topic`
|
||||
8) Text: "Let’s revisit this in two weeks."
|
||||
- Expect: `two weeks=duration`
|
||||
9) Text: "We decided to drop the old onboarding flow."
|
||||
- Expect: `decision=decision`, `onboarding flow=topic`
|
||||
10) Text: "This is shit, totally broken."
|
||||
- Expect: no entity with text "shit"
|
||||
11) Text: "Last night we met in Paris."
|
||||
- Expect: `last night=time_relative`, `Paris=location`
|
||||
12) Text: "We can sync after the standup next week."
|
||||
- Expect: `sync=event`, `standup=event`, `next week=time_relative`
|
||||
13) Text: "PIXEL and pixel are the same product."
|
||||
- Expect: one `pixel=product` after dedupe
|
||||
14) Text: "Met with Navi, NAVI, and navi."
|
||||
- Expect: one `navi=location` after dedupe
|
||||
15) Text: "It happened 2 days ago."
|
||||
- Expect: `2 days=duration` or `2 days=time_relative` (pick one, but be consistent)
|
||||
16) Text: "We should ask 'Sam' and Sam about the release."
|
||||
- Expect: one `sam=person`
|
||||
17) Text: "Notion and Figma docs are ready."
|
||||
- Expect: `Notion=product`, `Figma=product`
|
||||
18) Text: "Meet at Union Station, then Union station again."
|
||||
- Expect: one `union station=location`
|
||||
19) Text: "We need a patch for Safari and safari iOS."
|
||||
- Expect: one `safari=product` and `patch=task`
|
||||
20) Text: "App 'Navi' crashed during the call."
|
||||
- Expect: `Navi=product` (overrides location when app context exists)
|
||||
# Infrastructure Layer Development Guide
|
||||
|
||||
### Test assertions guidance
|
||||
- Always assert exact normalized text + label pairs.
|
||||
- Assert counts (expected length) to avoid hidden extras.
|
||||
- For time merge: verify only one entity span for "last night" or "next week".
|
||||
- For profanity: assert it does not appear in any entity text.
|
||||
- For casing/duplicates: assert one normalized entity only.
|
||||
- For override rules: assert app/product context wins over location if explicit.
|
||||
## Overview
|
||||
|
||||
### Step 9: Strict exit criteria
|
||||
Implementation is only done if ALL are true:
|
||||
- `make quality` passes.
|
||||
- No compatibility wrappers or adapter shims added.
|
||||
- No legacy code paths left in place (spaCy backend must be an actual backend, not dead code).
|
||||
- No temp artifacts or experimental files left behind.
|
||||
The infrastructure layer (`src/noteflow/infrastructure/`) contains adapters that implement domain ports. These connect the application to external systems: databases, ML
|
||||
models, cloud APIs, file systems.
|
||||
|
||||
### Example backend usage (pseudo)
|
||||
```
|
||||
backend = GLiNERBackend(labels=settings.ner_labels, threshold=settings.ner_threshold)
|
||||
raw = backend.extract(text)
|
||||
raw = post_process(raw, text)
|
||||
entities = mapper.map_raw_to_named(raw)
|
||||
```
|
||||
---
|
||||
|
||||
### Outcome
|
||||
- Switching NER backends is now a 1-line config change.
|
||||
- Post-processing is shared and consistent across backends.
|
||||
- Domain entities remain unchanged for callers.
|
||||
## Architecture Principle: Hexagonal/Ports-and-Adapters
|
||||
|
||||
## Simple eval scoring checklist
|
||||
- Precision per label
|
||||
- Recall per label
|
||||
- Actionable rate (task/decision/event/topic)
|
||||
- Type confusion counts (person<->product, org<->product, location<->org, time<->duration)
|
||||
Domain Ports (interfaces) Infrastructure Adapters (implementations)
|
||||
───────────────────────── ───────────────────────────────────────────
|
||||
NerPort → SpacyBackend, GlinerBackend
|
||||
SummarizationProvider → CloudProvider, OllamaProvider, MockProvider
|
||||
DiarizationEngine → DiartSession, PyannoteOffline
|
||||
AssetRepository → FileSystemAssetRepository
|
||||
UnitOfWork → SqlAlchemyUnitOfWork, MemoryUnitOfWork
|
||||
CalendarProvider → GoogleCalendar, OutlookCalendar
|
||||
|
||||
## Next steps to validate
|
||||
- Run GLiNER on sample transcript and compare entity yield vs spaCy pipeline.
|
||||
- Apply the spaCy label mapping above to ensure apples-to-apples scoring.
|
||||
- Add a lightweight rule layer for time_relative/duration normalization.
|
||||
**Rule**: Infrastructure code imports domain; domain NEVER imports infrastructure.
|
||||
|
||||
---
|
||||
|
||||
## Adapter Catalog
|
||||
|
||||
| Directory | Responsibility | Key Protocols |
|
||||
|-----------|----------------|---------------|
|
||||
| `asr/` | Speech-to-text (Whisper) | `TranscriptionResult` |
|
||||
| `diarization/` | Speaker identification | `DiarizationEngine`, `DiarizationJob` |
|
||||
| `ner/` | Named entity extraction | `NerPort` |
|
||||
| `summarization/` | LLM summarization | `SummarizationProvider` |
|
||||
| `persistence/` | Database (SQLAlchemy) | `UnitOfWork`, `*Repository` |
|
||||
| `calendar/` | OAuth + event sync | `CalendarProvider` |
|
||||
| `webhooks/` | Event delivery | `WebhookDeliveryService` |
|
||||
| `export/` | PDF/HTML/Markdown | `ExportAdapter` |
|
||||
| `audio/` | Recording/playback | `AudioDevice` |
|
||||
| `crypto/` | Encryption | `Keystore` |
|
||||
| `logging/` | Structured logging | `LogEventType` |
|
||||
| `metrics/` | Observability | `MetricsCollector` |
|
||||
| `gpu/` | GPU detection | `GpuInfo` |
|
||||
|
||||
---
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### 1. Async Wrappers for Sync Libraries
|
||||
|
||||
Many ML libraries (spaCy, faster-whisper) are synchronous. Wrap them:
|
||||
|
||||
```python
|
||||
async def extract(self, text: str) -> list[NamedEntity]:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, # Default ThreadPoolExecutor
|
||||
self._sync_extract,
|
||||
text
|
||||
)
|
||||
|
||||
2. Backend Selection via Factory
|
||||
|
||||
def create_ner_engine(config: NerConfig) -> NerPort:
|
||||
match config.backend:
|
||||
case "spacy":
|
||||
return SpacyBackend(model=config.model_name)
|
||||
case "gliner":
|
||||
return GlinerBackend(model=config.model_name)
|
||||
case _:
|
||||
raise ValueError(f"Unknown NER backend: {config.backend}")
|
||||
|
||||
3. Capability Flags for Optional Features
|
||||
|
||||
class SqlAlchemyUnitOfWork(UnitOfWork):
|
||||
@property
|
||||
def supports_entities(self) -> bool:
|
||||
return True # Has EntityRepository
|
||||
|
||||
@property
|
||||
def supports_webhooks(self) -> bool:
|
||||
return True # Has WebhookRepository
|
||||
|
||||
Always check capability before accessing optional repository:
|
||||
if uow.supports_entities:
|
||||
entities = await uow.entities.get_by_meeting(meeting_id)
|
||||
|
||||
4. Provider Protocol Pattern
|
||||
|
||||
class SummarizationProvider(Protocol):
|
||||
async def summarize(
|
||||
self,
|
||||
segments: list[Segment],
|
||||
template: SummarizationTemplate,
|
||||
) -> SummaryResult: ...
|
||||
|
||||
@property
|
||||
def requires_consent(self) -> bool: ...
|
||||
|
||||
---
|
||||
Forbidden Patterns
|
||||
|
||||
❌ Direct database access outside persistence/
|
||||
# WRONG: Raw SQL in service layer
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(text("SELECT * FROM meetings"))
|
||||
|
||||
❌ Hardcoded API keys
|
||||
# WRONG: Secrets in code
|
||||
client = anthropic.Anthropic(api_key="sk-ant-...")
|
||||
|
||||
❌ Synchronous I/O in async context
|
||||
# WRONG: Blocking the event loop
|
||||
def load_model(self):
|
||||
self.model = whisper.load_model("base") # Blocks!
|
||||
|
||||
❌ Domain imports in infrastructure
|
||||
# WRONG: Infrastructure should implement domain ports, not modify domain
|
||||
from noteflow.domain.entities import Meeting
|
||||
meeting.state = "COMPLETED" # Don't mutate domain objects here
|
||||
|
||||
---
|
||||
Testing Infrastructure Adapters
|
||||
|
||||
Use Dependency Injection for Mocking
|
||||
|
||||
# tests/infrastructure/ner/test_engine.py
|
||||
@pytest.fixture
|
||||
def mock_backend() -> NerBackend:
|
||||
backend = Mock(spec=NerBackend)
|
||||
backend.extract.return_value = [
|
||||
RawEntity(text="John", label="PERSON", start=0, end=4)
|
||||
]
|
||||
return backend
|
||||
|
||||
async def test_engine_uses_backend(mock_backend):
|
||||
engine = NerEngine(backend=mock_backend)
|
||||
result = await engine.extract("Hello John")
|
||||
mock_backend.extract.assert_called_once()
|
||||
|
||||
Integration Tests with Real Services
|
||||
|
||||
# tests/integration/test_ner_integration.py
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.requires_gpu
|
||||
async def test_gliner_real_extraction():
|
||||
backend = GlinerBackend(model="urchade/gliner_base")
|
||||
result = await backend.extract("Microsoft CEO Satya Nadella announced...")
|
||||
assert any(e.label == "ORG" and "Microsoft" in e.text for e in result)
|
||||
|
||||
---
|
||||
Adding a New Adapter
|
||||
|
||||
1. Define port in domain (domain/ports/) if not exists
|
||||
2. Create adapter directory (infrastructure/<adapter_name>/)
|
||||
3. Implement the protocol with proper async handling
|
||||
4. Add factory function for backend selection
|
||||
5. Write unit tests with mocked dependencies
|
||||
6. Write integration test with real external service
|
||||
7. Update gRPC startup (grpc/startup/services.py) for dependency injection
|
||||
8. Document in this file (update Adapter Catalog table)
|
||||
|
||||
---
|
||||
Key Files
|
||||
┌───────────────────────────────┬──────────────────────────────────┐
|
||||
│ File │ Purpose │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ __init__.py │ Public exports for layer │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ */engine.py │ Main adapter implementation │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ */backends/ │ Multiple backend implementations │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ */mapper.py │ External→Domain type conversion │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ */post_processing.py │ Output normalization │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ persistence/unit_of_work/*.py │ Transaction management │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ persistence/repositories/*.py │ Data access │
|
||||
├───────────────────────────────┼──────────────────────────────────┤
|
||||
│ persistence/models/*.py │ ORM definitions │
|
||||
└───────────────────────────────┴──────────────────────────────────┘
|
||||
---
|
||||
See Also
|
||||
|
||||
- /src/noteflow/domain/ports/ — Port definitions
|
||||
- /src/noteflow/grpc/startup/services.py — Dependency injection
|
||||
- /tests/infrastructure/ — Adapter tests
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Answer to "Would Your Answer Change With No Limit?"
|
||||
|
||||
**Yes, significantly.**
|
||||
|
||||
### Constrained (5-7 files):
|
||||
Focus on **layer boundaries** (domain, application, infrastructure) and **high-complexity areas** (hooks, Rust backend). Each file covers broad territory.
|
||||
|
||||
### Unlimited (25+ files):
|
||||
Add **subsystem-specific documentation** for:
|
||||
- Complex state machines (diarization jobs, recording lifecycle)
|
||||
- Protocol patterns (summarization providers, NER backends)
|
||||
- Cross-cutting concerns (rules engine, settings cascade)
|
||||
- Test organization (fixtures, integration setup)
|
||||
|
||||
The key difference: with unlimited resources, document **WHY decisions were made** (design rationale), not just **WHAT exists** (API reference).
|
||||
|
||||
---
|
||||
|
||||
## Recommendation
|
||||
|
||||
### Immediate Action (Phase 1)
|
||||
Add these 3 files for maximum impact:
|
||||
1. `/src/noteflow/infrastructure/CLAUDE.md` — Adapter patterns (mockup above)
|
||||
2. `/src/noteflow/domain/CLAUDE.md` — DDD boundaries, entity relationships
|
||||
3. `/client/src-tauri/src/CLAUDE.md` — Rust patterns, IPC contracts
|
||||
|
||||
### Follow-up (Phase 2)
|
||||
4. `/src/noteflow/application/services/CLAUDE.md` — Service catalog
|
||||
5. `/client/src/hooks/CLAUDE.md` — Hook organization
|
||||
6. `/tests/CLAUDE.md` — Testing conventions
|
||||
|
||||
### Future (Phase 3)
|
||||
Remaining 19+ files as the codebase grows and patterns stabilize.
|
||||
12
.mcp.json
12
.mcp.json
@@ -3,6 +3,18 @@
|
||||
"lightrag-mcp": {
|
||||
"type": "sse",
|
||||
"url": "http://192.168.50.185:8150/sse"
|
||||
},
|
||||
"firecrawl": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"firecrawl-mcp"
|
||||
],
|
||||
"env": {
|
||||
"FIRECRAWL_API_URL": "http://crawl.toy",
|
||||
"FIRECRAWL_API_KEY": "dummy-key"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,12 +245,16 @@ echo ""
|
||||
# Check 6: Deep nesting (> 7 levels = 28 spaces)
|
||||
# Thresholds: 20=5 levels, 24=6 levels, 28=7 levels
|
||||
# 7 levels allows for async patterns: spawn + block_on + loop + select + match + if + body
|
||||
# Only excludes: generated files, async bidirectional streaming (inherently deep)
|
||||
# Excludes: generated files, streaming (inherently deep), tracing macro field formatting
|
||||
echo "Checking for deep nesting..."
|
||||
DEEP_NESTING=$(grep -rn --include="*.rs" $GREP_EXCLUDES -E '^[[:space:]]{28,}[^[:space:]]' "$TAURI_SRC" \
|
||||
| grep -v '//' \
|
||||
| grep -v 'noteflow.rs' \
|
||||
| grep -v 'streaming.rs' \
|
||||
| grep -v 'tracing::' \
|
||||
| grep -v ' = %' \
|
||||
| grep -v ' = identity\.' \
|
||||
| grep -v '[[:space:]]"[A-Z]' \
|
||||
| head -20 || true)
|
||||
|
||||
if [ -n "$DEEP_NESTING" ]; then
|
||||
|
||||
@@ -47,6 +47,10 @@ pub struct WasapiLoopbackHandle {
|
||||
#[cfg(target_os = "windows")]
|
||||
impl WasapiLoopbackHandle {
|
||||
pub fn stop(mut self) {
|
||||
self.stop_internal();
|
||||
}
|
||||
|
||||
fn stop_internal(&mut self) {
|
||||
let _ = self.stop_tx.send(());
|
||||
if let Some(join) = self.join.take() {
|
||||
let _ = join.join();
|
||||
@@ -54,6 +58,13 @@ impl WasapiLoopbackHandle {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
impl Drop for WasapiLoopbackHandle {
|
||||
fn drop(&mut self) {
|
||||
self.stop_internal();
|
||||
}
|
||||
}
|
||||
|
||||
/// Start WASAPI loopback capture on a background thread.
|
||||
#[cfg(target_os = "windows")]
|
||||
pub fn start_wasapi_loopback_capture<F>(
|
||||
@@ -177,7 +188,6 @@ where
|
||||
F: FnMut(&[f32]),
|
||||
{
|
||||
initialize_mta()
|
||||
.ok()
|
||||
.map_err(|err| Error::AudioCapture(format!("WASAPI init failed: {err}")))?;
|
||||
|
||||
let result: Result<()> = (|| {
|
||||
|
||||
@@ -302,11 +302,12 @@ async fn run_event_loop(app: AppHandle, mut rx: broadcast::Receiver<AppEvent>) {
|
||||
/// during Tauri's setup hook, before the main async runtime is fully initialized.
|
||||
///
|
||||
/// Returns the `JoinHandle` so the caller can wait for graceful shutdown.
|
||||
/// Returns `None` if thread spawning fails (logged as error).
|
||||
pub fn start_event_emitter(
|
||||
app: AppHandle,
|
||||
rx: broadcast::Receiver<AppEvent>,
|
||||
) -> std::thread::JoinHandle<()> {
|
||||
std::thread::Builder::new()
|
||||
) -> Option<std::thread::JoinHandle<()>> {
|
||||
match std::thread::Builder::new()
|
||||
.name("noteflow-event-emitter".to_string())
|
||||
.spawn(move || {
|
||||
let rt = match tokio::runtime::Builder::new_current_thread()
|
||||
@@ -326,8 +327,17 @@ pub fn start_event_emitter(
|
||||
|
||||
rt.block_on(run_event_loop(app, rx));
|
||||
tracing::debug!("Event emitter thread exiting");
|
||||
})
|
||||
.expect("Failed to spawn event emitter thread")
|
||||
}) {
|
||||
Ok(handle) => Some(handle),
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
subsystem = "event_emitter",
|
||||
"Failed to spawn event emitter thread - frontend events disabled"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -180,10 +180,10 @@ impl GrpcClient {
|
||||
.await?
|
||||
.into_inner();
|
||||
|
||||
Ok(response
|
||||
response
|
||||
.entity
|
||||
.map(convert_entity)
|
||||
.expect("UpdateEntity response should contain entity"))
|
||||
.ok_or_else(|| crate::error::Error::Stream("UpdateEntity response missing entity".into()))
|
||||
}
|
||||
|
||||
/// Delete a named entity.
|
||||
|
||||
@@ -144,34 +144,44 @@ impl IdentityStore {
|
||||
}
|
||||
|
||||
fn load_identity_from_keychain(&self) {
|
||||
if let Ok(identity_json) = self.get_keychain_value(identity_config::IDENTITY_KEY) {
|
||||
match serde_json::from_str::<StoredIdentity>(&identity_json) {
|
||||
Ok(identity) => {
|
||||
tracing::info!(
|
||||
user_id = %identity.user_id,
|
||||
is_local = identity.is_local,
|
||||
"Loaded identity from keychain"
|
||||
);
|
||||
*self.identity.write() = Some(identity);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to parse stored identity, using default");
|
||||
match self.get_keychain_value(identity_config::IDENTITY_KEY) {
|
||||
Ok(identity_json) => {
|
||||
match serde_json::from_str::<StoredIdentity>(&identity_json) {
|
||||
Ok(identity) => {
|
||||
tracing::info!(
|
||||
user_id = %identity.user_id,
|
||||
is_local = identity.is_local,
|
||||
"Loaded identity from keychain"
|
||||
);
|
||||
*self.identity.write() = Some(identity);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to parse stored identity, using default");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(error = %e, "No identity in keychain (first run or cleared)");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_tokens_from_keychain(&self) {
|
||||
if let Ok(token_json) = self.get_keychain_value(identity_config::AUTH_TOKEN_KEY) {
|
||||
match serde_json::from_str::<StoredTokens>(&token_json) {
|
||||
Ok(tokens) => {
|
||||
tracing::debug!("Loaded auth tokens from keychain");
|
||||
*self.tokens.write() = Some(tokens);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to parse stored tokens");
|
||||
match self.get_keychain_value(identity_config::AUTH_TOKEN_KEY) {
|
||||
Ok(token_json) => {
|
||||
match serde_json::from_str::<StoredTokens>(&token_json) {
|
||||
Ok(tokens) => {
|
||||
tracing::debug!("Loaded auth tokens from keychain");
|
||||
*self.tokens.write() = Some(tokens);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to parse stored tokens");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(error = %e, "No auth tokens in keychain");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -297,11 +297,12 @@ fn setup_app_state(
|
||||
|
||||
// Start event emitter and track the thread handle
|
||||
let app_handle = app.handle().clone();
|
||||
let event_emitter_handle = events::start_event_emitter(app_handle, event_tx.subscribe());
|
||||
|
||||
// Store event emitter handle in shutdown manager
|
||||
if let Some(shutdown_mgr) = app.try_state::<Arc<ShutdownManager>>() {
|
||||
shutdown_mgr.set_event_emitter_handle(event_emitter_handle);
|
||||
if let Some(event_emitter_handle) =
|
||||
events::start_event_emitter(app_handle, event_tx.subscribe())
|
||||
{
|
||||
if let Some(shutdown_mgr) = app.try_state::<Arc<ShutdownManager>>() {
|
||||
shutdown_mgr.set_event_emitter_handle(event_emitter_handle);
|
||||
}
|
||||
}
|
||||
|
||||
// Start trigger polling (foreground app + audio activity detection)
|
||||
|
||||
@@ -7,8 +7,7 @@ use std::collections::VecDeque;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use tokio::sync::Mutex;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
|
||||
use crate::audio::PlaybackHandle;
|
||||
use crate::config;
|
||||
@@ -436,7 +435,7 @@ impl AppState {
|
||||
pub fn get_trigger_status(&self) -> TriggerStatus {
|
||||
let (snoozed, snooze_remaining) = self
|
||||
.trigger_service
|
||||
.blocking_lock()
|
||||
.lock()
|
||||
.as_ref()
|
||||
.map(|service| (service.is_snoozed(), service.snooze_remaining_seconds()))
|
||||
.unwrap_or((false, 0.0));
|
||||
|
||||
@@ -99,9 +99,10 @@ function CalendarErrorState({ onRetry, isRetrying }: { onRetry: () => void; isRe
|
||||
}
|
||||
|
||||
export function UpcomingMeetings({ maxEvents = 10 }: UpcomingMeetingsProps) {
|
||||
const integrations = preferences.getIntegrations();
|
||||
const calendarIntegrations = integrations.filter((i) => i.type === 'calendar');
|
||||
const connectedCalendars = calendarIntegrations.filter((i) => i.status === 'connected');
|
||||
const connectedCalendars = useMemo(() => {
|
||||
const integrations = preferences.getIntegrations();
|
||||
return integrations.filter((i) => i.type === 'calendar' && i.status === 'connected');
|
||||
}, []);
|
||||
|
||||
// Use live calendar API instead of mock data
|
||||
const { state, fetchEvents } = useCalendarSync({
|
||||
|
||||
@@ -6,6 +6,7 @@ import { getAPI } from '@/api/interface';
|
||||
import { isTauriEnvironment } from '@/api';
|
||||
import type { GetCurrentUserResponse } from '@/api/types';
|
||||
import { toast } from '@/hooks/ui/use-toast';
|
||||
import { addClientLog } from '@/lib/observability/client';
|
||||
import { toastError } from '@/lib/observability/errors';
|
||||
import {
|
||||
extractOAuthCallback,
|
||||
@@ -141,14 +142,22 @@ export function useAuthFlow(): UseAuthFlowReturn {
|
||||
}
|
||||
};
|
||||
|
||||
void setupDeepLinkListener(handleDeepLinkCallback).then((c) => {
|
||||
if (unmounted) {
|
||||
// Component unmounted before setup completed, call cleanup immediately
|
||||
c();
|
||||
} else {
|
||||
cleanup = c;
|
||||
}
|
||||
});
|
||||
setupDeepLinkListener(handleDeepLinkCallback)
|
||||
.then((c) => {
|
||||
if (unmounted) {
|
||||
c();
|
||||
} else {
|
||||
cleanup = c;
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
addClientLog({
|
||||
level: 'warning',
|
||||
source: 'auth',
|
||||
message: 'Failed to setup deep link listener',
|
||||
details: err instanceof Error ? err.message : String(err),
|
||||
});
|
||||
});
|
||||
|
||||
return () => {
|
||||
unmounted = true;
|
||||
@@ -254,7 +263,13 @@ export function useAuthFlow(): UseAuthFlowReturn {
|
||||
}));
|
||||
|
||||
return userInfo;
|
||||
} catch {
|
||||
} catch (error) {
|
||||
addClientLog({
|
||||
level: 'warning',
|
||||
source: 'auth',
|
||||
message: 'Failed to check auth status',
|
||||
details: error instanceof Error ? error.message : String(error),
|
||||
});
|
||||
return null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
@@ -7,6 +7,7 @@ import { getAPI } from '@/api/interface';
|
||||
import { isTauriEnvironment } from '@/api';
|
||||
import type { OAuthConnection } from '@/api/types';
|
||||
import { toast } from '@/hooks/ui/use-toast';
|
||||
import { addClientLog } from '@/lib/observability/client';
|
||||
import { toastError } from '@/lib/observability/errors';
|
||||
import { extractOAuthCallback, setupDeepLinkListener, validateOAuthState } from '@/lib/integrations/oauth';
|
||||
|
||||
@@ -132,14 +133,22 @@ export function useOAuthFlow(): UseOAuthFlowReturn {
|
||||
|
||||
let unmounted = false;
|
||||
|
||||
void setupDeepLinkListener(handleDeepLinkCallback).then((c) => {
|
||||
if (unmounted) {
|
||||
// Component unmounted before setup completed, call cleanup immediately
|
||||
c();
|
||||
} else {
|
||||
cleanup = c;
|
||||
}
|
||||
});
|
||||
setupDeepLinkListener(handleDeepLinkCallback)
|
||||
.then((c) => {
|
||||
if (unmounted) {
|
||||
c();
|
||||
} else {
|
||||
cleanup = c;
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
addClientLog({
|
||||
level: 'warning',
|
||||
source: 'oauth',
|
||||
message: 'Failed to setup deep link listener',
|
||||
details: err instanceof Error ? err.message : String(err),
|
||||
});
|
||||
});
|
||||
|
||||
return () => {
|
||||
unmounted = true;
|
||||
|
||||
@@ -10,6 +10,7 @@ import { getAPI } from '@/api/interface';
|
||||
import type { DiarizationJobStatus, JobStatus } from '@/api/types';
|
||||
import { toast } from '@/hooks/ui/use-toast';
|
||||
import { PollingConfig } from '@/lib/config';
|
||||
import { addClientLog } from '@/lib/observability/client';
|
||||
import { errorMessageFrom, toastError } from '@/lib/observability/errors';
|
||||
|
||||
/** Diarization job state */
|
||||
@@ -448,8 +449,13 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
|
||||
}
|
||||
|
||||
return job;
|
||||
} catch {
|
||||
// Recovery failure is non-fatal - continue without recovery
|
||||
} catch (error) {
|
||||
addClientLog({
|
||||
level: 'warning',
|
||||
source: 'app',
|
||||
message: 'Diarization recovery failed (non-fatal)',
|
||||
details: error instanceof Error ? error.message : String(error),
|
||||
});
|
||||
return null;
|
||||
}
|
||||
}, [poll, pollInterval, showToasts]);
|
||||
|
||||
@@ -69,6 +69,14 @@ export function useEntityExtraction(
|
||||
meetingTitleRef.current = meetingTitle;
|
||||
}, [meetingTitle]);
|
||||
|
||||
const mountedRef = useRef(true);
|
||||
useEffect(() => {
|
||||
mountedRef.current = true;
|
||||
return () => {
|
||||
mountedRef.current = false;
|
||||
};
|
||||
}, []);
|
||||
|
||||
const extract = useCallback(
|
||||
async (forceRefresh = false) => {
|
||||
if (!meetingId) {
|
||||
@@ -84,7 +92,7 @@ export function useEntityExtraction(
|
||||
|
||||
try {
|
||||
const response = await getAPI().extractEntities(meetingId, forceRefresh);
|
||||
// Use ref to get current title without causing callback recreation
|
||||
if (!mountedRef.current) return;
|
||||
setEntitiesFromExtraction(response.entities, meetingTitleRef.current);
|
||||
setState((prev) => ({
|
||||
...prev,
|
||||
@@ -100,6 +108,7 @@ export function useEntityExtraction(
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
if (!mountedRef.current) return;
|
||||
const message = toastError({
|
||||
title: 'Extraction failed',
|
||||
error,
|
||||
|
||||
@@ -100,6 +100,14 @@ export function useRecordingSession(
|
||||
|
||||
// Transcription stream
|
||||
const streamRef = useRef<TranscriptionStream | null>(null);
|
||||
const mountedRef = useRef(true);
|
||||
|
||||
useEffect(() => {
|
||||
mountedRef.current = true;
|
||||
return () => {
|
||||
mountedRef.current = false;
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Toast helpers
|
||||
const toastSuccess = useCallback(
|
||||
@@ -264,6 +272,7 @@ export function useRecordingSession(
|
||||
title: meetingTitle || `Recording ${formatDateTime()}`,
|
||||
project_id: projectId,
|
||||
});
|
||||
if (!mountedRef.current) return;
|
||||
setMeeting(newMeeting);
|
||||
|
||||
let stream: TranscriptionStream;
|
||||
@@ -275,6 +284,10 @@ export function useRecordingSession(
|
||||
} else {
|
||||
stream = ensureTranscriptionStream(await api.startTranscription(newMeeting.id));
|
||||
}
|
||||
if (!mountedRef.current) {
|
||||
stream.close();
|
||||
return;
|
||||
}
|
||||
|
||||
streamRef.current = stream;
|
||||
stream.onUpdate(handleTranscriptUpdate);
|
||||
@@ -285,6 +298,7 @@ export function useRecordingSession(
|
||||
shouldSimulate ? 'Simulation is active' : 'Transcription is now active'
|
||||
);
|
||||
} catch (error) {
|
||||
if (!mountedRef.current) return;
|
||||
setRecordingState('idle');
|
||||
toastError('Failed to start recording', error);
|
||||
}
|
||||
@@ -324,6 +338,7 @@ export function useRecordingSession(
|
||||
streamRef.current = null;
|
||||
const api = shouldSimulate && !isConnected ? mockAPI : getAPI();
|
||||
const stoppedMeeting = await api.stopMeeting(meeting.id);
|
||||
if (!mountedRef.current) return;
|
||||
setMeeting(stoppedMeeting);
|
||||
setRecordingState('idle');
|
||||
toastSuccess(
|
||||
@@ -331,6 +346,7 @@ export function useRecordingSession(
|
||||
shouldSimulate ? 'Simulation finished' : 'Your meeting has been saved'
|
||||
);
|
||||
} catch (error) {
|
||||
if (!mountedRef.current) return;
|
||||
setRecordingState('recording');
|
||||
toastError('Failed to stop recording', error);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { isIntegrationNotFoundError } from '@/api';
|
||||
import { getAPI } from '@/api/interface';
|
||||
import type { CalendarEvent, CalendarProvider } from '@/api/types';
|
||||
import { toast } from '@/hooks/ui/use-toast';
|
||||
import { addClientLog } from '@/lib/observability/client';
|
||||
import { errorMessageFrom, toastError } from '@/lib/observability/errors';
|
||||
|
||||
export type CalendarSyncStatus = 'idle' | 'loading' | 'success' | 'error';
|
||||
@@ -84,8 +85,13 @@ export function useCalendarSync(options: UseCalendarSyncOptions = {}): UseCalend
|
||||
...prev,
|
||||
providers: response.providers,
|
||||
}));
|
||||
} catch {
|
||||
// Provider fetch failed - non-critical, UI will show empty providers
|
||||
} catch (error) {
|
||||
addClientLog({
|
||||
level: 'debug',
|
||||
source: 'sync',
|
||||
message: 'Calendar provider fetch failed (non-critical)',
|
||||
details: error instanceof Error ? error.message : String(error),
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
|
||||
18
client/src/lib/cache/meeting-cache.ts
vendored
18
client/src/lib/cache/meeting-cache.ts
vendored
@@ -397,6 +397,20 @@ export const meetingCache = {
|
||||
};
|
||||
|
||||
// Flush cache on page unload to prevent data loss
|
||||
if (typeof window !== 'undefined') {
|
||||
window.addEventListener('beforeunload', flushCache);
|
||||
let beforeUnloadListenerAdded = false;
|
||||
|
||||
function setupBeforeUnloadListener(): void {
|
||||
if (typeof window !== 'undefined' && !beforeUnloadListenerAdded) {
|
||||
window.addEventListener('beforeunload', flushCache);
|
||||
beforeUnloadListenerAdded = true;
|
||||
}
|
||||
}
|
||||
|
||||
export function cleanupCacheBeforeUnloadListener(): void {
|
||||
if (typeof window !== 'undefined' && beforeUnloadListenerAdded) {
|
||||
window.removeEventListener('beforeunload', flushCache);
|
||||
beforeUnloadListenerAdded = false;
|
||||
}
|
||||
}
|
||||
|
||||
setupBeforeUnloadListener();
|
||||
|
||||
@@ -140,12 +140,24 @@ export function clearClientLogs(): void {
|
||||
}
|
||||
|
||||
// Flush logs on page unload to prevent data loss
|
||||
let beforeUnloadListener: (() => void) | null = null;
|
||||
if (typeof window !== 'undefined') {
|
||||
beforeUnloadListener = () => flushLogs();
|
||||
window.addEventListener('beforeunload', beforeUnloadListener);
|
||||
let beforeUnloadListenerAdded = false;
|
||||
|
||||
function setupBeforeUnloadListener(): void {
|
||||
if (typeof window !== 'undefined' && !beforeUnloadListenerAdded) {
|
||||
window.addEventListener('beforeunload', flushLogs);
|
||||
beforeUnloadListenerAdded = true;
|
||||
}
|
||||
}
|
||||
|
||||
export function cleanupBeforeUnloadListener(): void {
|
||||
if (typeof window !== 'undefined' && beforeUnloadListenerAdded) {
|
||||
window.removeEventListener('beforeunload', flushLogs);
|
||||
beforeUnloadListenerAdded = false;
|
||||
}
|
||||
}
|
||||
|
||||
setupBeforeUnloadListener();
|
||||
|
||||
/**
|
||||
* Reset internal state for testing. Clears pending logs and cancels any scheduled writes.
|
||||
* @internal
|
||||
@@ -157,8 +169,5 @@ export function _resetClientLogsForTesting(): void {
|
||||
writeTimeout = null;
|
||||
}
|
||||
cachedLogs = null;
|
||||
if (beforeUnloadListener !== null && typeof window !== 'undefined') {
|
||||
window.removeEventListener('beforeunload', beforeUnloadListener);
|
||||
beforeUnloadListener = null;
|
||||
}
|
||||
cleanupBeforeUnloadListener();
|
||||
}
|
||||
|
||||
@@ -1,10 +1,4 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useLocation } from 'react-router-dom';
|
||||
|
||||
const NotFound = () => {
|
||||
const _location = useLocation();
|
||||
|
||||
useEffect(() => {}, []);
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-muted">
|
||||
|
||||
@@ -72,6 +72,9 @@ pdf = [
|
||||
ner = [
|
||||
"spacy>=3.8.11",
|
||||
]
|
||||
ner-gliner = [
|
||||
"gliner>=0.2.24",
|
||||
]
|
||||
calendar = [
|
||||
"google-api-python-client>=2.100",
|
||||
"google-auth>=2.23",
|
||||
@@ -109,8 +112,10 @@ optional = [
|
||||
"torch>=2.0",
|
||||
# PDF export
|
||||
"weasyprint>=67.0",
|
||||
# NER
|
||||
# NER (spaCy backend)
|
||||
"spacy>=3.8.11",
|
||||
# NER (GLiNER backend)
|
||||
"gliner>=0.2.24",
|
||||
# Calendar
|
||||
"google-api-python-client>=2.100",
|
||||
"google-auth>=2.23",
|
||||
@@ -122,7 +127,7 @@ optional = [
|
||||
"opentelemetry-exporter-otlp>=1.28",
|
||||
]
|
||||
all = [
|
||||
"noteflow[audio,dev,triggers,summarization,diarization,pdf,ner,calendar,observability]",
|
||||
"noteflow[audio,dev,triggers,summarization,diarization,pdf,ner,ner-gliner,calendar,observability]",
|
||||
]
|
||||
ollama = [
|
||||
"anthropic>=0.75.0",
|
||||
@@ -173,6 +178,7 @@ plugins = ["sqlalchemy.ext.mypy.plugin"]
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"diart.*",
|
||||
"gliner.*",
|
||||
"pyannote.*",
|
||||
"faster_whisper.*",
|
||||
"sounddevice.*",
|
||||
@@ -294,6 +300,7 @@ filterwarnings = [
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"basedpyright>=1.36.1",
|
||||
"gliner>=0.2.24",
|
||||
"grpc-stubs>=1.53.0.6",
|
||||
"protobuf>=6.33.2",
|
||||
"pyrefly>=0.46.1",
|
||||
@@ -301,6 +308,7 @@ dev = [
|
||||
"pytest-httpx>=0.36.0",
|
||||
"ruff>=0.14.9",
|
||||
"sourcery; sys_platform == 'darwin'",
|
||||
"spacy>=3.8.11",
|
||||
"types-grpcio==1.0.0.20251001",
|
||||
"watchfiles>=1.1.1",
|
||||
]
|
||||
|
||||
@@ -20,16 +20,28 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
DEFAULT_JOB_RETENTION_SECONDS: float = 300.0 # 5 minutes
|
||||
|
||||
|
||||
class AsrJobManager:
|
||||
"""Manage ASR configuration job lifecycle and state.
|
||||
|
||||
Handles job creation, status tracking, and cleanup during shutdown.
|
||||
Jobs are automatically removed from the registry after completion/failure
|
||||
to prevent memory leaks.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize job manager with empty job registry."""
|
||||
def __init__(self, job_retention_seconds: float = DEFAULT_JOB_RETENTION_SECONDS) -> None:
|
||||
"""Initialize job manager with empty job registry.
|
||||
|
||||
Args:
|
||||
job_retention_seconds: How long to keep completed/failed jobs before
|
||||
removing them from memory. Defaults to 5 minutes.
|
||||
"""
|
||||
self._jobs: dict[UUID, AsrConfigJob] = {}
|
||||
self._job_lock = asyncio.Lock()
|
||||
self._job_retention_seconds = job_retention_seconds
|
||||
self._cleanup_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
async def register_job(self, job: AsrConfigJob) -> None:
|
||||
"""Register a new job under the job lock.
|
||||
@@ -52,6 +64,17 @@ class AsrJobManager:
|
||||
coro: The coroutine to run.
|
||||
"""
|
||||
job.task = asyncio.create_task(coro)
|
||||
job.task.add_done_callback(lambda t: self._job_task_done_callback(t, job))
|
||||
|
||||
def _job_task_done_callback(self, task: asyncio.Task[None], job: AsrConfigJob) -> None:
|
||||
if task.cancelled():
|
||||
logger.debug("asr_job_task_cancelled", job_id=str(job.job_id))
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc is not None and job.phase != AsrConfigPhase.FAILED:
|
||||
logger.error("asr_job_task_unhandled_error", job_id=str(job.job_id), error=str(exc))
|
||||
wrapped_exc = exc if isinstance(exc, Exception) else RuntimeError(str(exc))
|
||||
self.mark_failed(job, wrapped_exc)
|
||||
|
||||
def get_job(self, job_id: UUID) -> AsrConfigJob | None:
|
||||
"""Get job status by ID.
|
||||
@@ -72,6 +95,7 @@ class AsrJobManager:
|
||||
job.phase = AsrConfigPhase.COMPLETED
|
||||
job.status = JOB_STATUS_COMPLETED
|
||||
job.progress_percent = 100.0
|
||||
self._schedule_job_cleanup(job.job_id)
|
||||
|
||||
def mark_failed(self, job: AsrConfigJob, error: Exception) -> None:
|
||||
"""Mark a job as failed with an error.
|
||||
@@ -84,6 +108,21 @@ class AsrJobManager:
|
||||
job.status = JOB_STATUS_FAILED
|
||||
job.error_message = str(error)
|
||||
logger.error("asr_reconfiguration_failed", error=str(error))
|
||||
self._schedule_job_cleanup(job.job_id)
|
||||
|
||||
def _schedule_job_cleanup(self, job_id: UUID) -> None:
|
||||
"""Schedule deferred removal of a completed/failed job from memory."""
|
||||
task = asyncio.create_task(self._cleanup_job_after_delay(job_id))
|
||||
self._cleanup_tasks.add(task)
|
||||
task.add_done_callback(self._cleanup_tasks.discard)
|
||||
|
||||
async def _cleanup_job_after_delay(self, job_id: UUID) -> None:
|
||||
"""Remove a job from the registry after the retention period."""
|
||||
await asyncio.sleep(self._job_retention_seconds)
|
||||
async with self._job_lock:
|
||||
removed = self._jobs.pop(job_id, None)
|
||||
if removed is not None:
|
||||
logger.debug("job_cleanup_completed", job_id=str(job_id))
|
||||
|
||||
def _collect_pending_tasks(self) -> list[asyncio.Task[None]]:
|
||||
"""Collect and cancel pending job tasks.
|
||||
|
||||
@@ -159,9 +159,9 @@ class NerService:
|
||||
self._model_helper = _ModelLifecycleHelper(ner_engine)
|
||||
self._extraction_helper = _ExtractionHelper(ner_engine, self._model_helper)
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Return True if the NER engine is loaded and ready."""
|
||||
return self._ner_engine.is_ready()
|
||||
def is_ner_ready(self) -> bool:
|
||||
engine = self._ner_engine
|
||||
return engine.is_ready()
|
||||
|
||||
async def extract_entities(
|
||||
self,
|
||||
|
||||
@@ -2,27 +2,30 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final
|
||||
|
||||
from noteflow.config.constants import SPACY_MODEL_LG, SPACY_MODEL_SM
|
||||
|
||||
from ._types import ModelInfo
|
||||
|
||||
# Constants to avoid magic strings
|
||||
_DEFAULT_MODEL = "spacy-en"
|
||||
CMD_DOWNLOAD = "download"
|
||||
_DEFAULT_MODEL: Final[str] = "spacy-en"
|
||||
_SPACY_EN_LG: Final[str] = "spacy-en-lg"
|
||||
CMD_DOWNLOAD: Final[str] = "download"
|
||||
_SPACY_MODULE: Final[str] = "spacy"
|
||||
|
||||
AVAILABLE_MODELS: dict[str, ModelInfo] = {
|
||||
_DEFAULT_MODEL: ModelInfo(
|
||||
name=_DEFAULT_MODEL,
|
||||
description=f"English NER model ({SPACY_MODEL_SM})",
|
||||
feature="ner",
|
||||
install_command=["python", "-m", "spacy", CMD_DOWNLOAD, SPACY_MODEL_SM],
|
||||
install_command=["python", "-m", _SPACY_MODULE, CMD_DOWNLOAD, SPACY_MODEL_SM],
|
||||
check_import=SPACY_MODEL_SM,
|
||||
),
|
||||
"spacy-en-lg": ModelInfo(
|
||||
name="spacy-en-lg",
|
||||
_SPACY_EN_LG: ModelInfo(
|
||||
name=_SPACY_EN_LG,
|
||||
description=f"English NER model - large ({SPACY_MODEL_LG})",
|
||||
feature="ner",
|
||||
install_command=["python", "-m", "spacy", CMD_DOWNLOAD, SPACY_MODEL_LG],
|
||||
install_command=["python", "-m", _SPACY_MODULE, CMD_DOWNLOAD, SPACY_MODEL_LG],
|
||||
check_import=SPACY_MODEL_LG,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -33,6 +33,8 @@ from noteflow.config.constants.domain import (
|
||||
EXPORT_FORMAT_HTML,
|
||||
FEATURE_NAME_PROJECTS,
|
||||
MEETING_TITLE_PREFIX,
|
||||
NER_BACKEND_GLINER,
|
||||
NER_BACKEND_SPACY,
|
||||
PROVIDER_NAME_OPENAI,
|
||||
RULE_FIELD_APP_MATCH_PATTERNS,
|
||||
RULE_FIELD_AUTO_START_ENABLED,
|
||||
@@ -173,6 +175,8 @@ __all__ = [
|
||||
"LOG_EVENT_WEBHOOK_UPDATE_FAILED",
|
||||
"MAX_GRPC_MESSAGE_SIZE",
|
||||
"MEETING_TITLE_PREFIX",
|
||||
"NER_BACKEND_GLINER",
|
||||
"NER_BACKEND_SPACY",
|
||||
"OAUTH_FIELD_ACCESS_TOKEN",
|
||||
"OAUTH_FIELD_EXPIRES_IN",
|
||||
"OAUTH_FIELD_REFRESH_TOKEN",
|
||||
|
||||
@@ -44,6 +44,16 @@ SPACY_MODEL_LG: Final[str] = "en_core_web_lg"
|
||||
SPACY_MODEL_TRF: Final[str] = "en_core_web_trf"
|
||||
"""Transformer-based English spaCy model for NER."""
|
||||
|
||||
# =============================================================================
|
||||
# NER Backend Names
|
||||
# =============================================================================
|
||||
|
||||
NER_BACKEND_SPACY: Final[str] = "spacy"
|
||||
"""spaCy NER backend identifier."""
|
||||
|
||||
NER_BACKEND_GLINER: Final[str] = "gliner"
|
||||
"""GLiNER NER backend identifier."""
|
||||
|
||||
# =============================================================================
|
||||
# Provider Names
|
||||
# =============================================================================
|
||||
|
||||
@@ -14,7 +14,7 @@ from noteflow.config.constants.core import (
|
||||
DEFAULT_OLLAMA_TIMEOUT_SECONDS,
|
||||
HOURS_PER_DAY,
|
||||
)
|
||||
from noteflow.config.constants.domain import DEFAULT_ANTHROPIC_MODEL
|
||||
from noteflow.config.constants.domain import DEFAULT_ANTHROPIC_MODEL, NER_BACKEND_SPACY
|
||||
from noteflow.config.settings._base import ENV_FILE, EXTRA_IGNORE
|
||||
from noteflow.config.settings._triggers import TriggerSettings
|
||||
|
||||
@@ -184,7 +184,9 @@ class Settings(TriggerSettings):
|
||||
]
|
||||
grpc_partial_cadence_seconds: Annotated[
|
||||
float,
|
||||
Field(default=2.0, ge=0.5, le=10.0, description="Interval for emitting partial transcripts"),
|
||||
Field(
|
||||
default=2.0, ge=0.5, le=10.0, description="Interval for emitting partial transcripts"
|
||||
),
|
||||
]
|
||||
grpc_min_partial_audio_seconds: Annotated[
|
||||
float,
|
||||
@@ -218,11 +220,18 @@ class Settings(TriggerSettings):
|
||||
]
|
||||
webhook_backoff_base: Annotated[
|
||||
float,
|
||||
Field(default=2.0, ge=1.1, le=5.0, description="Exponential backoff multiplier for webhook retries"),
|
||||
Field(
|
||||
default=2.0,
|
||||
ge=1.1,
|
||||
le=5.0,
|
||||
description="Exponential backoff multiplier for webhook retries",
|
||||
),
|
||||
]
|
||||
webhook_max_response_length: Annotated[
|
||||
int,
|
||||
Field(default=5 * 100, ge=100, le=10 * 1000, description="Maximum response body length to log"),
|
||||
Field(
|
||||
default=5 * 100, ge=100, le=10 * 1000, description="Maximum response body length to log"
|
||||
),
|
||||
]
|
||||
|
||||
# LLM/Summarization settings
|
||||
@@ -241,7 +250,9 @@ class Settings(TriggerSettings):
|
||||
]
|
||||
llm_default_anthropic_model: Annotated[
|
||||
str,
|
||||
Field(default=DEFAULT_ANTHROPIC_MODEL, description="Default Anthropic model for summarization"),
|
||||
Field(
|
||||
default=DEFAULT_ANTHROPIC_MODEL, description="Default Anthropic model for summarization"
|
||||
),
|
||||
]
|
||||
llm_timeout_seconds: Annotated[
|
||||
float,
|
||||
@@ -298,6 +309,31 @@ class Settings(TriggerSettings):
|
||||
),
|
||||
]
|
||||
|
||||
# NER settings
|
||||
ner_backend: Annotated[
|
||||
str,
|
||||
Field(
|
||||
default=NER_BACKEND_SPACY,
|
||||
description="NER backend: spacy (default) or gliner",
|
||||
),
|
||||
]
|
||||
ner_gliner_model: Annotated[
|
||||
str,
|
||||
Field(
|
||||
default="urchade/gliner_medium-v2.1",
|
||||
description="GLiNER model name from HuggingFace Hub",
|
||||
),
|
||||
]
|
||||
ner_gliner_threshold: Annotated[
|
||||
float,
|
||||
Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="GLiNER confidence threshold for entity extraction",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def database_url_str(self) -> str:
|
||||
"""Return database URL as string."""
|
||||
|
||||
@@ -27,15 +27,29 @@ class EntityCategory(Enum):
|
||||
|
||||
Note: TECHNICAL and ACRONYM are placeholders for future custom pattern
|
||||
matching (not currently mapped from spaCy's default NER model).
|
||||
|
||||
Meeting-specific categories (GLiNER):
|
||||
TIME_RELATIVE -> Relative time references (next week, tomorrow)
|
||||
DURATION -> Time durations (20 minutes, two weeks)
|
||||
EVENT -> Events and activities (meeting, demo, standup)
|
||||
TASK -> Action items and tasks (file a Jira, patch)
|
||||
DECISION -> Decisions made (decided to drop)
|
||||
TOPIC -> Discussion topics (budget, onboarding flow)
|
||||
"""
|
||||
|
||||
PERSON = "person"
|
||||
COMPANY = "company"
|
||||
PRODUCT = "product"
|
||||
TECHNICAL = "technical" # Future: custom pattern matching
|
||||
ACRONYM = "acronym" # Future: custom pattern matching
|
||||
TECHNICAL = "technical"
|
||||
ACRONYM = "acronym"
|
||||
LOCATION = ENTITY_LOCATION
|
||||
DATE = ENTITY_DATE
|
||||
TIME_RELATIVE = "time_relative"
|
||||
DURATION = "duration"
|
||||
EVENT = "event"
|
||||
TASK = "task"
|
||||
DECISION = "decision"
|
||||
TOPIC = "topic"
|
||||
OTHER = "other"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -70,7 +70,7 @@ def _append_ner_status(payload: dict[str, object], host: _ModelStatusHost) -> No
|
||||
if ner is None:
|
||||
payload["ner_ready"] = False
|
||||
return
|
||||
payload["ner_ready"] = ner.is_ready()
|
||||
payload["ner_ready"] = ner.is_ner_ready()
|
||||
|
||||
|
||||
def _append_summarization_status(
|
||||
@@ -93,8 +93,7 @@ def _append_summarization_status(
|
||||
return
|
||||
|
||||
summaries: list[str] = [
|
||||
_format_provider_summary(mode, provider)
|
||||
for mode, provider in service.providers.items()
|
||||
_format_provider_summary(mode, provider) for mode, provider in service.providers.items()
|
||||
]
|
||||
payload["summarization_providers"] = summaries
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ Sprint GAP-003: Error Handling Mismatches
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
@@ -19,7 +20,7 @@ logger = get_logger(__name__)
|
||||
def create_job_done_callback(
|
||||
job_id: str,
|
||||
tasks_dict: dict[str, asyncio.Task[None]],
|
||||
mark_failed: Callable[[str, str], Awaitable[None]],
|
||||
mark_failed: Callable[[str, str], Coroutine[Any, Any, None]],
|
||||
) -> Callable[[asyncio.Task[None]], None]:
|
||||
"""Create task done callback that marks jobs failed on exception.
|
||||
|
||||
@@ -54,7 +55,7 @@ def _handle_task_completion(
|
||||
task: asyncio.Task[None],
|
||||
job_id: str,
|
||||
tasks_dict: dict[str, asyncio.Task[None]],
|
||||
mark_failed: Callable[[str, str], Awaitable[None]],
|
||||
mark_failed: Callable[[str, str], Coroutine[Any, Any, None]],
|
||||
) -> None:
|
||||
"""Process task completion, logging failures and scheduling mark_failed.
|
||||
|
||||
@@ -85,7 +86,7 @@ def _handle_task_completion(
|
||||
def _log_and_schedule_failure(
|
||||
job_id: str,
|
||||
exc: BaseException,
|
||||
mark_failed: Callable[[str, str], Awaitable[None]],
|
||||
mark_failed: Callable[[str, str], Coroutine[Any, Any, None]],
|
||||
) -> None:
|
||||
"""Log task exception and schedule the mark_failed coroutine.
|
||||
|
||||
@@ -100,15 +101,25 @@ def _log_and_schedule_failure(
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
# Schedule the mark_failed coroutine (fire-and-forget)
|
||||
# Note: This requires an active event loop, which should exist
|
||||
# since we're being called from asyncio task completion
|
||||
try:
|
||||
asyncio.create_task(mark_failed(job_id, str(exc)))
|
||||
coro = mark_failed(job_id, str(exc))
|
||||
task: asyncio.Task[None] = asyncio.create_task(coro)
|
||||
task.add_done_callback(lambda t: _log_mark_failed_result(t, job_id))
|
||||
except RuntimeError as schedule_err:
|
||||
# If we can't schedule (e.g., loop closed), log but don't crash
|
||||
logger.error(
|
||||
"Failed to schedule mark_failed for job %s: %s",
|
||||
job_id,
|
||||
schedule_err,
|
||||
)
|
||||
|
||||
|
||||
def _log_mark_failed_result(task: asyncio.Task[None], job_id: str) -> None:
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
logger.error(
|
||||
"mark_failed task failed for job %s: %s",
|
||||
job_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
from ....proto import noteflow_pb2
|
||||
from ...converters import create_congestion_info
|
||||
from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
|
||||
@@ -11,6 +13,8 @@ from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
|
||||
if TYPE_CHECKING:
|
||||
from ...protocols import ServicerHost
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def calculate_congestion_info(
|
||||
host: ServicerHost,
|
||||
@@ -30,20 +34,17 @@ def calculate_congestion_info(
|
||||
processing_delay_ms = 0
|
||||
if receipt_times := host.chunk_receipt_times.get(meeting_id):
|
||||
try:
|
||||
# Access [0] can race with popleft() in decrement_pending_chunks
|
||||
oldest_receipt = receipt_times[0]
|
||||
processing_delay_ms = int((current_time - oldest_receipt) * 1000)
|
||||
except IndexError:
|
||||
# Deque was emptied by concurrent popleft() - no pending chunks
|
||||
pass
|
||||
logger.debug("congestion_race_access: deque emptied by concurrent popleft")
|
||||
|
||||
# Get queue depth (pending chunks not yet processed through ASR)
|
||||
queue_depth = host.pending_chunks.get(meeting_id, 0)
|
||||
|
||||
# Determine if throttle is recommended
|
||||
throttle_recommended = (
|
||||
processing_delay_ms > PROCESSING_DELAY_THRESHOLD_MS
|
||||
or queue_depth > QUEUE_DEPTH_THRESHOLD
|
||||
processing_delay_ms > PROCESSING_DELAY_THRESHOLD_MS or queue_depth > QUEUE_DEPTH_THRESHOLD
|
||||
)
|
||||
|
||||
return create_congestion_info(
|
||||
@@ -68,5 +69,4 @@ def decrement_pending_chunks(host: ServicerHost, meeting_id: str) -> None:
|
||||
try:
|
||||
receipt_times.popleft()
|
||||
except IndexError:
|
||||
# Deque already empty - can occur if multiple coroutines race
|
||||
pass
|
||||
logger.debug("congestion_race_popleft: deque already empty from concurrent access")
|
||||
|
||||
@@ -154,9 +154,18 @@ class NoteFlowServer:
|
||||
self._servicer: NoteFlowServicer | None = None
|
||||
self._warmup_task: asyncio.Task[None] | None = None
|
||||
|
||||
def _warmup_task_done_callback(self, task: asyncio.Task[None]) -> None:
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
logger.error("diarization_warmup_failed", error=str(exc))
|
||||
|
||||
async def _apply_persisted_asr_config(self) -> None:
|
||||
settings = get_settings()
|
||||
stored = await _load_asr_config_preference(self._state.session_factory, settings.meetings_dir)
|
||||
stored = await _load_asr_config_preference(
|
||||
self._state.session_factory, settings.meetings_dir
|
||||
)
|
||||
if stored is None:
|
||||
return
|
||||
|
||||
@@ -207,6 +216,7 @@ class NoteFlowServer:
|
||||
self._warmup_task = asyncio.create_task(
|
||||
warm_diarization_engine(self._state.diarization_engine)
|
||||
)
|
||||
self._warmup_task.add_done_callback(self._warmup_task_done_callback)
|
||||
self._servicer = build_servicer(self._state, asr_engine, self._streaming_config)
|
||||
await recover_orphaned_jobs(self._state.session_factory)
|
||||
|
||||
@@ -221,6 +231,14 @@ class NoteFlowServer:
|
||||
Args:
|
||||
grace_period: Time to wait for in-flight RPCs.
|
||||
"""
|
||||
if self._warmup_task is not None and not self._warmup_task.done():
|
||||
self._warmup_task.cancel()
|
||||
try:
|
||||
await self._warmup_task
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Diarization warmup task cancelled")
|
||||
self._warmup_task = None
|
||||
|
||||
await stop_server(self._server, self._servicer, self._db_engine, grace_period)
|
||||
if self._db_engine is not None:
|
||||
self._db_engine = None
|
||||
@@ -231,7 +249,6 @@ class NoteFlowServer:
|
||||
await self._server.wait_for_termination()
|
||||
|
||||
|
||||
|
||||
async def run_server_with_config(config: GrpcServerConfig) -> None:
|
||||
"""Run the async gRPC server with structured configuration.
|
||||
|
||||
|
||||
@@ -124,6 +124,7 @@ class ServicerStreamingStateMixin:
|
||||
chunk_counts: dict[str, int]
|
||||
chunk_receipt_times: dict[str, deque[float]]
|
||||
pending_chunks: dict[str, int]
|
||||
stop_requested: set[str]
|
||||
DEFAULT_SAMPLE_RATE: int
|
||||
def init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None:
|
||||
"""Initialize VAD, Segmenter, speaking state, and partial buffers for a meeting."""
|
||||
@@ -177,6 +178,7 @@ class ServicerStreamingStateMixin:
|
||||
|
||||
self.chunk_sequences.pop(meeting_id, None)
|
||||
self.chunk_counts.pop(meeting_id, None)
|
||||
self.stop_requested.discard(meeting_id)
|
||||
|
||||
if hasattr(self, "_chunk_receipt_times"):
|
||||
self.chunk_receipt_times.pop(meeting_id, None)
|
||||
@@ -298,8 +300,7 @@ class ServicerInfoMixin:
|
||||
|
||||
diarization_enabled = self.diarization_engine is not None
|
||||
diarization_ready = self.diarization_engine is not None and (
|
||||
self.diarization_engine.is_streaming_loaded
|
||||
or self.diarization_engine.is_offline_loaded
|
||||
self.diarization_engine.is_streaming_loaded or self.diarization_engine.is_offline_loaded
|
||||
)
|
||||
|
||||
if self.session_factory is not None:
|
||||
|
||||
@@ -10,12 +10,14 @@ from noteflow.application.services.calendar import CalendarService
|
||||
from noteflow.application.services.ner import NerService
|
||||
from noteflow.application.services.webhooks import WebhookService
|
||||
from noteflow.config.settings import Settings, get_calendar_settings, get_feature_flags
|
||||
from noteflow.config.constants.domain import NER_BACKEND_GLINER
|
||||
from noteflow.domain.constants.fields import CALENDAR, DEVICE
|
||||
from noteflow.domain.entities.integration import IntegrationStatus
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
from noteflow.infrastructure.diarization import DiarizationEngine
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
from noteflow.infrastructure.ner import NerEngine
|
||||
from noteflow.infrastructure.ner.backends.types import NerBackend
|
||||
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
|
||||
from noteflow.infrastructure.webhooks import WebhookExecutor
|
||||
|
||||
@@ -38,11 +40,7 @@ async def check_calendar_needed_from_db(uow: SqlAlchemyUnitOfWork) -> bool:
|
||||
return False
|
||||
|
||||
calendar_integrations = await uow.integrations.list_by_type(CALENDAR)
|
||||
if connected := [
|
||||
i
|
||||
for i in calendar_integrations
|
||||
if i.status == IntegrationStatus.CONNECTED
|
||||
]:
|
||||
if connected := [i for i in calendar_integrations if i.status == IntegrationStatus.CONNECTED]:
|
||||
logger.info(
|
||||
"Auto-enabling calendar: found %d connected OAuth integration(s)",
|
||||
len(connected),
|
||||
@@ -84,8 +82,9 @@ def create_ner_service(
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info("Initializing NER service (spaCy)...")
|
||||
ner_engine = NerEngine()
|
||||
backend = _create_ner_backend(settings)
|
||||
logger.info("Initializing NER service", backend=settings.ner_backend)
|
||||
ner_engine = NerEngine(backend=backend)
|
||||
ner_service = NerService(
|
||||
ner_engine=ner_engine,
|
||||
uow_factory=lambda: SqlAlchemyUnitOfWork(session_factory, settings.meetings_dir),
|
||||
@@ -94,6 +93,21 @@ def create_ner_service(
|
||||
return ner_service
|
||||
|
||||
|
||||
def _create_ner_backend(settings: Settings) -> NerBackend:
|
||||
"""Create NER backend based on settings."""
|
||||
if settings.ner_backend == NER_BACKEND_GLINER:
|
||||
from noteflow.infrastructure.ner.backends.gliner_backend import GLiNERBackend
|
||||
|
||||
return GLiNERBackend(
|
||||
model_name=settings.ner_gliner_model,
|
||||
threshold=settings.ner_gliner_threshold,
|
||||
)
|
||||
|
||||
from noteflow.infrastructure.ner.backends.spacy_backend import SpacyBackend
|
||||
|
||||
return SpacyBackend()
|
||||
|
||||
|
||||
async def create_calendar_service(
|
||||
session_factory: async_sessionmaker[AsyncSession] | None,
|
||||
settings: Settings,
|
||||
|
||||
352
src/noteflow/infrastructure/CLAUDE.md
Normal file
352
src/noteflow/infrastructure/CLAUDE.md
Normal file
@@ -0,0 +1,352 @@
|
||||
# Infrastructure Layer Development Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The infrastructure layer (`src/noteflow/infrastructure/`) contains adapters that implement domain ports. These connect the application to external systems: databases, ML models, cloud APIs, file systems.
|
||||
|
||||
**Architecture**: Hexagonal (Ports-and-Adapters)
|
||||
- Domain Ports define interfaces in `domain/ports/`
|
||||
- Infrastructure Adapters implement those interfaces here
|
||||
- **Rule**: Infrastructure imports domain; domain NEVER imports infrastructure
|
||||
|
||||
---
|
||||
|
||||
## Adapter Catalog
|
||||
|
||||
| Directory | Responsibility | Port/Protocol | Key Dependencies |
|
||||
|-----------|----------------|---------------|------------------|
|
||||
| `asr/` | Speech-to-text (Whisper) | `asr/protocols.AsrEngine` | faster-whisper |
|
||||
| `diarization/` | Speaker identification | internal protocols | pyannote, diart |
|
||||
| `ner/` | Named entity extraction | `domain/ports/ner.NerPort` | spacy, gliner |
|
||||
| `summarization/` | LLM summarization | `application/services/protocols` | anthropic, openai |
|
||||
| `persistence/` | Database (SQLAlchemy) | `domain/ports/unit_of_work.UnitOfWork` | sqlalchemy, asyncpg |
|
||||
| `calendar/` | OAuth + event sync | `domain/ports/calendar.CalendarPort` | httpx |
|
||||
| `webhooks/` | Event delivery | — | httpx |
|
||||
| `export/` | PDF/HTML/Markdown | `export/protocols.TranscriptExporter` | weasyprint, jinja2 |
|
||||
| `audio/` | Recording/playback | `audio/protocols.AudioCapture` | sounddevice |
|
||||
| `security/` | Encryption | `security/protocols.CryptoBox`, `KeyStore` | cryptography, keyring |
|
||||
| `gpu/` | GPU detection | — | torch |
|
||||
| `triggers/` | Signal providers | internal protocols | platform-specific |
|
||||
| `auth/` | OIDC management | — | httpx |
|
||||
| `converters/` | Layer bridging | — | — |
|
||||
| `logging/` | Structured logging | — | structlog |
|
||||
| `observability/` | OpenTelemetry & usage tracking | — | opentelemetry |
|
||||
| `metrics/` | Metrics collection | — | prometheus_client |
|
||||
|
||||
**Note**: Some protocols are domain-level ports (`domain/ports/`), others are infrastructure-level protocols defined within the subsystem.
|
||||
|
||||
---
|
||||
|
||||
## Common Implementation Patterns
|
||||
|
||||
### 1. Async Wrappers for Sync Libraries
|
||||
|
||||
Many ML libraries (spaCy, faster-whisper, pyannote) are synchronous. Always wrap:
|
||||
|
||||
```python
|
||||
async def extract_async(self, text: str) -> list[NamedEntity]:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, # Default ThreadPoolExecutor
|
||||
self._sync_extract,
|
||||
text
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Backend Selection via Factory
|
||||
|
||||
```python
|
||||
def create_ner_engine(config: NerConfig) -> NerPort:
|
||||
match config.backend:
|
||||
case "spacy":
|
||||
return SpacyBackend(model=config.model_name)
|
||||
case "gliner":
|
||||
return GlinerBackend(model=config.model_name)
|
||||
case _:
|
||||
raise ValueError(f"Unknown NER backend: {config.backend}")
|
||||
```
|
||||
|
||||
### 3. Protocol-Based Backends
|
||||
|
||||
```python
|
||||
# backends/types.py
|
||||
class NerBackend(Protocol):
|
||||
def extract(self, text: str) -> list[RawEntity]: ...
|
||||
def load_model(self) -> None: ...
|
||||
def unload(self) -> None: ...
|
||||
@property
|
||||
def model_loaded(self) -> bool: ...
|
||||
```
|
||||
|
||||
### 4. Capability Flags for Optional Features
|
||||
|
||||
```python
|
||||
class SqlAlchemyUnitOfWork(UnitOfWork):
|
||||
@property
|
||||
def supports_entities(self) -> bool:
|
||||
return True # Has EntityRepository
|
||||
|
||||
@property
|
||||
def supports_webhooks(self) -> bool:
|
||||
return True # Has WebhookRepository
|
||||
```
|
||||
|
||||
**Always check capability before accessing optional repository**:
|
||||
```python
|
||||
if uow.supports_entities:
|
||||
entities = await uow.entities.get_by_meeting(meeting_id)
|
||||
```
|
||||
|
||||
### 5. Mixin Composition for Complex Engines
|
||||
|
||||
```python
|
||||
class DiarizationEngine(
|
||||
DeviceMixin, # GPU/CPU detection
|
||||
StreamingMixin, # Real-time diarization (diart)
|
||||
OfflineMixin, # Batch diarization (pyannote)
|
||||
ProcessingMixin, # Audio frame processing
|
||||
LifecycleMixin, # Model loading/unloading
|
||||
):
|
||||
"""Composed from focused mixins for maintainability."""
|
||||
```
|
||||
|
||||
### 6. Data Transfer Objects (DTOs)
|
||||
|
||||
```python
|
||||
@dataclass(frozen=True)
|
||||
class RawEntity:
|
||||
text: str
|
||||
label: str
|
||||
start: int
|
||||
end: int
|
||||
confidence: float = 1.0
|
||||
```
|
||||
|
||||
Use frozen dataclasses for immutable DTOs between layers.
|
||||
|
||||
---
|
||||
|
||||
## Subsystem-Specific Guidance
|
||||
|
||||
### NER (`ner/`)
|
||||
|
||||
**Structure**:
|
||||
```
|
||||
ner/
|
||||
├── engine.py # NerEngine (composition class)
|
||||
├── mapper.py # RawEntity → NamedEntity
|
||||
├── post_processing.py # Pipeline: drop → infer → merge → dedupe
|
||||
└── backends/
|
||||
├── types.py # NerBackend protocol, RawEntity
|
||||
├── gliner_backend.py
|
||||
└── spacy_backend.py
|
||||
```
|
||||
|
||||
**Key classes**:
|
||||
- `NerEngine`: Composes backend + mapper + post-processing
|
||||
- `NerBackend`: Protocol for backend implementations
|
||||
- `RawEntity`: Backend output (frozen dataclass)
|
||||
|
||||
**Flow**: `Backend.extract()` → `map_raw_entities_to_named()` → `post_process()` → `dedupe_entities()`
|
||||
|
||||
### ASR (`asr/`)
|
||||
|
||||
**Key classes**:
|
||||
- `FasterWhisperEngine`: Whisper wrapper using CTranslate2
|
||||
- `AsrResult`: Dataclass with text, language, segments, duration
|
||||
- `StreamingVad`: Real-time voice activity detection
|
||||
- `Segmenter`: VAD-based audio chunking
|
||||
|
||||
**Device handling**: Uses `compute_type` (int8, float16, float32) + `device` (cpu, cuda, rocm)
|
||||
|
||||
### Diarization (`diarization/`)
|
||||
|
||||
**Modes**:
|
||||
- **Streaming** (diart): Real-time speaker detection during recording
|
||||
- **Offline** (pyannote): Post-meeting refinement with full audio
|
||||
|
||||
**Key classes**:
|
||||
- `DiarizationEngine`: Main engine with mixins
|
||||
- `DiarizationSession`: Session state for active diarization
|
||||
- `SpeakerTurn`: Dataclass with speaker_id, start, end
|
||||
|
||||
### Persistence (`persistence/`)
|
||||
|
||||
**Structure**:
|
||||
```
|
||||
persistence/
|
||||
├── unit_of_work/ # SqlAlchemyUnitOfWork (async context manager)
|
||||
├── models/ # SQLAlchemy ORM models
|
||||
├── repositories/ # Repository implementations
|
||||
├── memory/ # In-memory fallback for testing
|
||||
└── migrations/ # Alembic migrations
|
||||
```
|
||||
|
||||
**Unit of Work Pattern**:
|
||||
```python
|
||||
async with uow:
|
||||
meeting = await uow.meetings.get(meeting_id)
|
||||
segments = await uow.segments.get_by_meeting(meeting_id)
|
||||
# Auto-commit on success, rollback on exception
|
||||
```
|
||||
|
||||
**Base repository helpers**:
|
||||
- `_execute_scalar()`: Single result query with timing
|
||||
- `_execute_scalars()`: Multiple results with timing
|
||||
- `_add_and_flush()`: Persist single model
|
||||
- `_add_all_and_flush()`: Batch persist
|
||||
|
||||
### Summarization (`summarization/`)
|
||||
|
||||
**Providers**:
|
||||
- `CloudProvider`: Anthropic/OpenAI wrapper
|
||||
- `OllamaSummarizer`: Local LLM via Ollama
|
||||
- `MockSummarizer`: Testing fixture
|
||||
|
||||
**Key features**:
|
||||
- Factory auto-detects available providers
|
||||
- Citation verification links claims to segments
|
||||
- Template rendering via Jinja2
|
||||
|
||||
### Webhooks (`webhooks/`)
|
||||
|
||||
**Key features**:
|
||||
- HMAC-SHA256 signing (`X-Noteflow-Signature` header)
|
||||
- Exponential backoff with configurable multiplier
|
||||
- Async delivery with httpx connection pooling
|
||||
- Delivery tracking and metrics
|
||||
|
||||
### Calendar (`calendar/`)
|
||||
|
||||
**Adapters**: Google Calendar, Outlook (Microsoft Graph)
|
||||
|
||||
**Key classes**:
|
||||
- `GoogleCalendarAdapter`: Google Calendar API v3
|
||||
- `OutlookCalendarAdapter`: Microsoft Graph API
|
||||
- `OAuthManager`: OAuth flow orchestration
|
||||
|
||||
---
|
||||
|
||||
## Forbidden Patterns
|
||||
|
||||
### Direct database access outside persistence/
|
||||
```python
|
||||
# WRONG: Raw SQL in service layer
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(text("SELECT * FROM meetings"))
|
||||
```
|
||||
|
||||
### Hardcoded API keys
|
||||
```python
|
||||
# WRONG: Secrets in code
|
||||
client = anthropic.Anthropic(api_key="sk-ant-...")
|
||||
# RIGHT: Use environment variables or config
|
||||
```
|
||||
|
||||
### Synchronous I/O in async context
|
||||
```python
|
||||
# WRONG: Blocking the event loop
|
||||
def load_model(self):
|
||||
self.model = whisper.load_model("base") # Blocks!
|
||||
|
||||
# RIGHT: Use executor
|
||||
async def load_model_async(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._load_model_sync)
|
||||
```
|
||||
|
||||
### Domain mutations in infrastructure
|
||||
```python
|
||||
# WRONG: Infrastructure modifying domain objects
|
||||
from noteflow.domain.entities import Meeting
|
||||
meeting.state = "COMPLETED" # Don't mutate here!
|
||||
|
||||
# RIGHT: Return new data, let application layer handle state
|
||||
```
|
||||
|
||||
### Skipping capability checks
|
||||
```python
|
||||
# WRONG: Assuming optional features exist
|
||||
entities = await uow.entities.get_by_meeting(meeting_id)
|
||||
|
||||
# RIGHT: Check capability first
|
||||
if uow.supports_entities:
|
||||
entities = await uow.entities.get_by_meeting(meeting_id)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing Infrastructure Adapters
|
||||
|
||||
### Unit Tests: Mock Dependencies
|
||||
|
||||
```python
|
||||
# tests/infrastructure/ner/test_engine.py
|
||||
@pytest.fixture
|
||||
def mock_backend() -> NerBackend:
|
||||
backend = Mock(spec=NerBackend)
|
||||
backend.extract.return_value = [
|
||||
RawEntity(text="John", label="PERSON", start=0, end=4)
|
||||
]
|
||||
return backend
|
||||
|
||||
async def test_engine_uses_backend(mock_backend):
|
||||
engine = NerEngine(backend=mock_backend)
|
||||
result = await engine.extract_async("Hello John")
|
||||
mock_backend.extract.assert_called_once()
|
||||
```
|
||||
|
||||
### Integration Tests: Real Services
|
||||
|
||||
```python
|
||||
# tests/integration/test_ner_integration.py
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.requires_gpu
|
||||
async def test_gliner_real_extraction():
|
||||
backend = GlinerBackend(model="urchade/gliner_base")
|
||||
result = backend.extract("Microsoft CEO Satya Nadella announced...")
|
||||
assert any(e.label == "ORG" and "Microsoft" in e.text for e in result)
|
||||
```
|
||||
|
||||
### Use pytest markers:
|
||||
- `@pytest.mark.integration` — requires external services
|
||||
- `@pytest.mark.requires_gpu` — requires CUDA/ROCm
|
||||
- `@pytest.mark.slow` — long-running tests
|
||||
|
||||
---
|
||||
|
||||
## Adding a New Adapter
|
||||
|
||||
1. **Define port in domain** (`domain/ports/`) if not exists
|
||||
2. **Create adapter directory** (`infrastructure/<adapter_name>/`)
|
||||
3. **Implement the protocol** with proper async handling
|
||||
4. **Add factory function** for backend selection (if multiple implementations)
|
||||
5. **Write unit tests** with mocked dependencies
|
||||
6. **Write integration test** with real external service
|
||||
7. **Update gRPC startup** (`grpc/startup/services.py`) for dependency injection
|
||||
8. **Update this file** (add to Adapter Catalog table)
|
||||
|
||||
---
|
||||
|
||||
## Key Files Reference
|
||||
|
||||
| Pattern | Location |
|
||||
|---------|----------|
|
||||
| Main engine | `<subsystem>/engine.py` |
|
||||
| Backend protocol | `<subsystem>/backends/types.py` |
|
||||
| Backend implementations | `<subsystem>/backends/<name>_backend.py` |
|
||||
| External→Domain mapping | `<subsystem>/mapper.py` |
|
||||
| Output normalization | `<subsystem>/post_processing.py` |
|
||||
| UnitOfWork | `persistence/unit_of_work/unit_of_work.py` |
|
||||
| Repository base | `persistence/repositories/_base/_base.py` |
|
||||
| ORM models | `persistence/models/<domain>/*.py` |
|
||||
| Migrations | `persistence/migrations/versions/*.py` |
|
||||
|
||||
---
|
||||
|
||||
## See Also
|
||||
|
||||
- `/src/noteflow/domain/ports/` — Port definitions (interfaces)
|
||||
- `/src/noteflow/grpc/startup/services.py` — Dependency injection wiring
|
||||
- `/tests/infrastructure/` — Adapter test suites
|
||||
- `/src/noteflow/CLAUDE.md` — Python backend standards
|
||||
@@ -23,6 +23,8 @@ from noteflow.domain.constants.fields import START
|
||||
from noteflow.infrastructure.asr.dto import AsrResult, WordTiming
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
_KEY_LANGUAGE: str = "language"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@@ -73,6 +75,7 @@ class _WhisperModel(Protocol):
|
||||
"""Convert model to half precision."""
|
||||
...
|
||||
|
||||
|
||||
# Valid model sizes for openai-whisper
|
||||
PYTORCH_VALID_MODEL_SIZES: tuple[str, ...] = (
|
||||
"tiny",
|
||||
@@ -232,14 +235,14 @@ class WhisperPyTorchEngine:
|
||||
}
|
||||
|
||||
if language is not None:
|
||||
options["language"] = language
|
||||
options[_KEY_LANGUAGE] = language
|
||||
|
||||
# Transcribe
|
||||
result = self._model.transcribe(audio, **options)
|
||||
|
||||
# Convert to our segment format
|
||||
segments = result.get("segments", [])
|
||||
detected_language = result.get("language", "en")
|
||||
detected_language = result.get(_KEY_LANGUAGE, "en")
|
||||
|
||||
for segment in segments:
|
||||
words = self._extract_word_timings(segment)
|
||||
|
||||
@@ -6,6 +6,7 @@ Provide cross-platform audio input capture with device handling.
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import weakref
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Unpack
|
||||
|
||||
@@ -239,6 +240,8 @@ class SoundDeviceCapture:
|
||||
def _build_stream_callback(
|
||||
self, channels: int
|
||||
) -> Callable[[NDArray[np.float32], int, object, object], None]:
|
||||
weak_self = weakref.ref(self)
|
||||
|
||||
def _stream_callback(
|
||||
indata: NDArray[np.float32],
|
||||
frames: int,
|
||||
@@ -250,9 +253,10 @@ class SoundDeviceCapture:
|
||||
if status:
|
||||
logger.warning("Audio stream status: %s", status)
|
||||
|
||||
if self._callback is not None:
|
||||
capture = weak_self()
|
||||
if capture is not None and capture._callback is not None:
|
||||
audio_data = indata[:, 0].copy() if channels == 1 else indata.flatten()
|
||||
timestamp = time.monotonic()
|
||||
self._callback(audio_data, timestamp)
|
||||
capture._callback(audio_data, timestamp)
|
||||
|
||||
return _stream_callback
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Named Entity Recognition infrastructure."""
|
||||
|
||||
from noteflow.infrastructure.ner.backends import NerBackend, RawEntity
|
||||
from noteflow.infrastructure.ner.backends.gliner_backend import GLiNERBackend
|
||||
from noteflow.infrastructure.ner.backends.spacy_backend import SpacyBackend
|
||||
from noteflow.infrastructure.ner.engine import NerEngine
|
||||
|
||||
__all__ = ["NerEngine"]
|
||||
__all__ = ["GLiNERBackend", "NerBackend", "NerEngine", "RawEntity", "SpacyBackend"]
|
||||
|
||||
5
src/noteflow/infrastructure/ner/backends/__init__.py
Normal file
5
src/noteflow/infrastructure/ner/backends/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""NER backend implementations."""
|
||||
|
||||
from noteflow.infrastructure.ner.backends.types import NerBackend, RawEntity
|
||||
|
||||
__all__ = ["NerBackend", "RawEntity"]
|
||||
122
src/noteflow/infrastructure/ner/backends/gliner_backend.py
Normal file
122
src/noteflow/infrastructure/ner/backends/gliner_backend.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""GLiNER NER backend implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final, TypedDict, cast
|
||||
|
||||
from noteflow.domain.entities.named_entity import EntityCategory
|
||||
from noteflow.infrastructure.logging import get_logger, log_timing
|
||||
from noteflow.infrastructure.ner.backends.types import RawEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gliner import GLiNER
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class _GLiNEREntityDict(TypedDict):
|
||||
text: str
|
||||
label: str
|
||||
start: int
|
||||
end: int
|
||||
score: float
|
||||
|
||||
|
||||
DEFAULT_MODEL: Final[str] = "urchade/gliner_medium-v2.1"
|
||||
DEFAULT_THRESHOLD: Final[float] = 0.5
|
||||
|
||||
MEETING_LABELS: Final[tuple[str, ...]] = (
|
||||
EntityCategory.PERSON.value,
|
||||
"org",
|
||||
EntityCategory.PRODUCT.value,
|
||||
"app",
|
||||
EntityCategory.LOCATION.value,
|
||||
"time",
|
||||
EntityCategory.TIME_RELATIVE.value,
|
||||
EntityCategory.DURATION.value,
|
||||
EntityCategory.EVENT.value,
|
||||
EntityCategory.TASK.value,
|
||||
EntityCategory.DECISION.value,
|
||||
EntityCategory.TOPIC.value,
|
||||
)
|
||||
|
||||
|
||||
class GLiNERBackend:
|
||||
"""GLiNER-based NER backend for meeting transcripts.
|
||||
|
||||
Uses generalist zero-shot NER with custom label sets optimized
|
||||
for meeting content (people, products, events, decisions, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = DEFAULT_MODEL,
|
||||
labels: tuple[str, ...] = MEETING_LABELS,
|
||||
threshold: float = DEFAULT_THRESHOLD,
|
||||
) -> None:
|
||||
self._model_name = model_name
|
||||
self._labels = labels
|
||||
self._threshold = threshold
|
||||
self._model: GLiNER | None = None
|
||||
|
||||
def load_model(self) -> None:
|
||||
from gliner import GLiNER as GLiNERClass
|
||||
|
||||
with log_timing("gliner_model_load", model_name=self._model_name):
|
||||
self._model = GLiNERClass.from_pretrained(self._model_name)
|
||||
logger.info("GLiNER model loaded", model_name=self._model_name)
|
||||
|
||||
def _ensure_loaded(self) -> GLiNER:
|
||||
if self._model is None:
|
||||
self.load_model()
|
||||
assert self._model is not None
|
||||
return self._model
|
||||
|
||||
def model_loaded(self) -> bool:
|
||||
return self._model is not None
|
||||
|
||||
def unload(self) -> None:
|
||||
self._model = None
|
||||
logger.info("GLiNER model unloaded")
|
||||
|
||||
def _call_predict_entities(
|
||||
self,
|
||||
model: GLiNER,
|
||||
text: str,
|
||||
) -> list[_GLiNEREntityDict]:
|
||||
result = model.predict_entities(
|
||||
text,
|
||||
labels=list(self._labels),
|
||||
threshold=self._threshold,
|
||||
)
|
||||
return cast(list[_GLiNEREntityDict], result)
|
||||
|
||||
def extract(self, text: str) -> list[RawEntity]:
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
model = self._ensure_loaded()
|
||||
raw_entities = self._call_predict_entities(model, text)
|
||||
|
||||
return [
|
||||
RawEntity(
|
||||
text=ent["text"],
|
||||
label=ent["label"].lower(),
|
||||
start=ent["start"],
|
||||
end=ent["end"],
|
||||
confidence=ent["score"],
|
||||
)
|
||||
for ent in raw_entities
|
||||
]
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def labels(self) -> tuple[str, ...]:
|
||||
return self._labels
|
||||
|
||||
@property
|
||||
def threshold(self) -> float:
|
||||
return self._threshold
|
||||
152
src/noteflow/infrastructure/ner/backends/spacy_backend.py
Normal file
152
src/noteflow/infrastructure/ner/backends/spacy_backend.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""spaCy NER backend implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Final, Protocol
|
||||
|
||||
from noteflow.config.constants import (
|
||||
SPACY_MODEL_LG,
|
||||
SPACY_MODEL_MD,
|
||||
SPACY_MODEL_SM,
|
||||
SPACY_MODEL_TRF,
|
||||
)
|
||||
from noteflow.infrastructure.logging import get_logger, log_timing
|
||||
from noteflow.infrastructure.ner.backends.types import RawEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from spacy.language import Language
|
||||
from spacy.tokens import Span
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class _SpacyModule(Protocol):
|
||||
def load(self, name: str) -> Language: ...
|
||||
|
||||
|
||||
VALID_SPACY_MODELS: Final[tuple[str, ...]] = (
|
||||
SPACY_MODEL_SM,
|
||||
SPACY_MODEL_MD,
|
||||
SPACY_MODEL_LG,
|
||||
SPACY_MODEL_TRF,
|
||||
)
|
||||
|
||||
SKIP_ENTITY_TYPES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
"CARDINAL",
|
||||
"ORDINAL",
|
||||
"QUANTITY",
|
||||
"PERCENT",
|
||||
"MONEY",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SpacyBackend:
|
||||
"""spaCy-based NER backend.
|
||||
|
||||
Uses spaCy's pre-trained models for entity extraction. Supports
|
||||
lazy model loading and fallback from transformer to small model.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = SPACY_MODEL_TRF) -> None:
|
||||
if model_name not in VALID_SPACY_MODELS:
|
||||
raise ValueError(
|
||||
f"Invalid model name: {model_name}. Valid models: {', '.join(VALID_SPACY_MODELS)}"
|
||||
)
|
||||
self._model_name = model_name
|
||||
self._nlp: Language | None = None
|
||||
|
||||
def load_model(self) -> None:
|
||||
import spacy
|
||||
|
||||
self._warn_if_curated_transformers_missing()
|
||||
with log_timing("spacy_model_load", model_name=self._model_name):
|
||||
try:
|
||||
self._nlp = spacy.load(self._model_name)
|
||||
except OSError as exc:
|
||||
self._nlp = self._handle_model_load_failure(spacy, exc)
|
||||
|
||||
logger.info("spaCy model loaded", model_name=self._model_name)
|
||||
|
||||
def _warn_if_curated_transformers_missing(self) -> None:
|
||||
if self._model_name != SPACY_MODEL_TRF:
|
||||
return
|
||||
try:
|
||||
importlib.import_module("spacy_curated_transformers")
|
||||
except ModuleNotFoundError:
|
||||
logger.warning(
|
||||
"spaCy curated transformers not installed; transformer model load may fail",
|
||||
model_name=self._model_name,
|
||||
)
|
||||
|
||||
def _handle_model_load_failure(
|
||||
self,
|
||||
spacy_module: _SpacyModule,
|
||||
exc: OSError,
|
||||
) -> Language:
|
||||
if self._model_name == SPACY_MODEL_TRF:
|
||||
return self._load_fallback_model(spacy_module)
|
||||
msg = (
|
||||
f"Failed to load spaCy model '{self._model_name}'. "
|
||||
f"Run: python -m spacy download {self._model_name}"
|
||||
)
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
def _load_fallback_model(self, spacy_module: _SpacyModule) -> Language:
|
||||
fallback_model = SPACY_MODEL_SM
|
||||
logger.warning(
|
||||
"spaCy model '%s' unavailable; falling back to '%s'",
|
||||
self._model_name,
|
||||
fallback_model,
|
||||
)
|
||||
try:
|
||||
nlp = spacy_module.load(fallback_model)
|
||||
except OSError as fallback_error:
|
||||
msg = (
|
||||
f"Failed to load spaCy model '{self._model_name}' "
|
||||
f"and fallback '{fallback_model}'. "
|
||||
f"Run: python -m spacy download {fallback_model}"
|
||||
)
|
||||
raise RuntimeError(msg) from fallback_error
|
||||
self._model_name = fallback_model
|
||||
return nlp
|
||||
|
||||
def _ensure_loaded(self) -> Language:
|
||||
if self._nlp is None:
|
||||
self.load_model()
|
||||
assert self._nlp is not None
|
||||
return self._nlp
|
||||
|
||||
def model_loaded(self) -> bool:
|
||||
return self._nlp is not None
|
||||
|
||||
def unload(self) -> None:
|
||||
self._nlp = None
|
||||
logger.info("spaCy model unloaded")
|
||||
|
||||
def extract(self, text: str) -> list[RawEntity]:
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
nlp = self._ensure_loaded()
|
||||
doc = nlp(text)
|
||||
|
||||
return [
|
||||
_spacy_entity_to_raw(ent) for ent in doc.ents if ent.label_ not in SKIP_ENTITY_TYPES
|
||||
]
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
|
||||
def _spacy_entity_to_raw(ent: Span) -> RawEntity:
|
||||
return RawEntity(
|
||||
text=ent.text,
|
||||
label=ent.label_.lower(),
|
||||
start=ent.start_char,
|
||||
end=ent.end_char,
|
||||
confidence=None,
|
||||
)
|
||||
56
src/noteflow/infrastructure/ner/backends/types.py
Normal file
56
src/noteflow/infrastructure/ner/backends/types.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Backend-agnostic types for NER extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RawEntity:
|
||||
"""Backend-agnostic entity representation.
|
||||
|
||||
All backends convert their output to this format before post-processing.
|
||||
Labels must be lowercase (e.g., 'person', 'time_relative').
|
||||
"""
|
||||
|
||||
text: str
|
||||
label: str
|
||||
start: int
|
||||
end: int
|
||||
confidence: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.label != self.label.lower():
|
||||
object.__setattr__(self, "label", self.label.lower())
|
||||
|
||||
|
||||
class NerBackend(Protocol):
|
||||
"""Protocol for NER backend implementations.
|
||||
|
||||
Backends extract raw entities from text. Post-processing and mapping
|
||||
to domain entities happens in separate pipeline stages.
|
||||
"""
|
||||
|
||||
def extract(self, text: str) -> list[RawEntity]:
|
||||
"""Extract raw entities from text.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze.
|
||||
|
||||
Returns:
|
||||
List of raw entities with lowercase labels.
|
||||
"""
|
||||
...
|
||||
|
||||
def model_loaded(self) -> bool:
|
||||
"""Check if the backend model is loaded and ready."""
|
||||
...
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""Load the model (for lazy initialization)."""
|
||||
...
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload the model to free resources."""
|
||||
...
|
||||
@@ -1,264 +1,51 @@
|
||||
"""NER engine implementation using spaCy.
|
||||
"""NER engine with backend composition.
|
||||
|
||||
Provides named entity extraction with lazy model loading and segment tracking.
|
||||
Provides named entity extraction with configurable backends and shared
|
||||
post-processing pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Final, Protocol
|
||||
|
||||
from noteflow.config.constants import (
|
||||
SPACY_MODEL_LG,
|
||||
SPACY_MODEL_MD,
|
||||
SPACY_MODEL_SM,
|
||||
SPACY_MODEL_TRF,
|
||||
)
|
||||
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
|
||||
from noteflow.infrastructure.logging import get_logger, log_timing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from spacy.language import Language
|
||||
|
||||
|
||||
class _SpacyModule(Protocol):
|
||||
def load(self, name: str) -> Language: ...
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Map spaCy entity types to our categories
|
||||
_SPACY_CATEGORY_MAP: Final[dict[str, EntityCategory]] = {
|
||||
# People
|
||||
"PERSON": EntityCategory.PERSON,
|
||||
# Organizations
|
||||
"ORG": EntityCategory.COMPANY,
|
||||
# Products and creative works
|
||||
"PRODUCT": EntityCategory.PRODUCT,
|
||||
"WORK_OF_ART": EntityCategory.PRODUCT,
|
||||
# Locations
|
||||
"GPE": EntityCategory.LOCATION, # Geo-political entity (countries, cities)
|
||||
"LOC": EntityCategory.LOCATION, # Non-GPE locations (mountains, rivers)
|
||||
"FAC": EntityCategory.LOCATION, # Facilities (buildings, airports)
|
||||
# Dates and times
|
||||
"DATE": EntityCategory.DATE,
|
||||
"TIME": EntityCategory.DATE,
|
||||
# Others (filtered out or mapped to OTHER)
|
||||
"MONEY": EntityCategory.OTHER,
|
||||
"PERCENT": EntityCategory.OTHER,
|
||||
"CARDINAL": EntityCategory.OTHER,
|
||||
"ORDINAL": EntityCategory.OTHER,
|
||||
"QUANTITY": EntityCategory.OTHER,
|
||||
"NORP": EntityCategory.OTHER, # Nationalities, religions
|
||||
"EVENT": EntityCategory.OTHER,
|
||||
"LAW": EntityCategory.OTHER,
|
||||
"LANGUAGE": EntityCategory.OTHER,
|
||||
}
|
||||
|
||||
# Entity types to skip (low value for meeting context)
|
||||
_SKIP_ENTITY_TYPES: Final[frozenset[str]] = frozenset({
|
||||
"CARDINAL",
|
||||
"ORDINAL",
|
||||
"QUANTITY",
|
||||
"PERCENT",
|
||||
"MONEY",
|
||||
})
|
||||
|
||||
# Valid model names
|
||||
VALID_SPACY_MODELS: Final[tuple[str, ...]] = (
|
||||
SPACY_MODEL_SM,
|
||||
SPACY_MODEL_MD,
|
||||
SPACY_MODEL_LG,
|
||||
SPACY_MODEL_TRF,
|
||||
)
|
||||
from noteflow.domain.entities.named_entity import NamedEntity
|
||||
from noteflow.infrastructure.ner.backends.types import NerBackend
|
||||
from noteflow.infrastructure.ner.mapper import map_raw_entities_to_named
|
||||
from noteflow.infrastructure.ner.post_processing import post_process
|
||||
|
||||
|
||||
class NerEngine:
|
||||
"""Named entity recognition engine using spaCy.
|
||||
"""Named entity recognition engine with backend composition.
|
||||
|
||||
Lazy-loads the spaCy model on first use to avoid startup delay.
|
||||
Implements the NerPort protocol for hexagonal architecture.
|
||||
|
||||
Uses chunking by segment (speaker turn/paragraph) to avoid OOM on
|
||||
long transcripts while maintaining segment tracking.
|
||||
Composes a backend (spaCy or GLiNER) with shared post-processing
|
||||
and domain entity mapping. Implements NerPort protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = SPACY_MODEL_TRF) -> None:
|
||||
"""Initialize NER engine.
|
||||
|
||||
Args:
|
||||
model_name: spaCy model to use. Defaults to transformer model
|
||||
for higher accuracy.
|
||||
"""
|
||||
if model_name not in VALID_SPACY_MODELS:
|
||||
raise ValueError(
|
||||
f"Invalid model name: {model_name}. "
|
||||
f"Valid models: {', '.join(VALID_SPACY_MODELS)}"
|
||||
)
|
||||
self._model_name = model_name
|
||||
self._nlp: Language | None = None
|
||||
def __init__(self, backend: NerBackend) -> None:
|
||||
self._backend = backend
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""Load the spaCy model.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails.
|
||||
"""
|
||||
import spacy
|
||||
|
||||
self._warn_if_curated_transformers_missing()
|
||||
with log_timing("ner_model_load", model_name=self._model_name):
|
||||
try:
|
||||
self._nlp = spacy.load(self._model_name)
|
||||
except OSError as exc:
|
||||
self._nlp = self._handle_model_load_failure(spacy, exc)
|
||||
|
||||
logger.info("ner_model_loaded", model_name=self._model_name)
|
||||
|
||||
def _warn_if_curated_transformers_missing(self) -> None:
|
||||
if self._model_name != SPACY_MODEL_TRF:
|
||||
return
|
||||
try:
|
||||
importlib.import_module("spacy_curated_transformers")
|
||||
except ModuleNotFoundError:
|
||||
logger.warning(
|
||||
"spaCy curated transformers not installed; "
|
||||
"transformer model load may fail",
|
||||
model_name=self._model_name,
|
||||
)
|
||||
|
||||
def _handle_model_load_failure(self, spacy_module: _SpacyModule, exc: OSError) -> Language:
|
||||
if self._model_name == SPACY_MODEL_TRF:
|
||||
return self._load_fallback_model(spacy_module)
|
||||
msg = (
|
||||
f"Failed to load spaCy model '{self._model_name}'. "
|
||||
f"Run: python -m spacy download {self._model_name}"
|
||||
)
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
def _load_fallback_model(self, spacy_module: _SpacyModule) -> Language:
|
||||
fallback_model = SPACY_MODEL_SM
|
||||
logger.warning(
|
||||
"spaCy model '%s' unavailable; falling back to '%s'",
|
||||
self._model_name,
|
||||
fallback_model,
|
||||
)
|
||||
try:
|
||||
nlp = spacy_module.load(fallback_model)
|
||||
except OSError as fallback_error:
|
||||
msg = (
|
||||
f"Failed to load spaCy model '{self._model_name}' "
|
||||
f"and fallback '{fallback_model}'. "
|
||||
f"Run: python -m spacy download {fallback_model}"
|
||||
)
|
||||
raise RuntimeError(msg) from fallback_error
|
||||
self._model_name = fallback_model
|
||||
return nlp
|
||||
|
||||
def _ensure_loaded(self) -> Language:
|
||||
"""Ensure model is loaded, loading if necessary.
|
||||
|
||||
Returns:
|
||||
The loaded spaCy Language model.
|
||||
"""
|
||||
if self._nlp is None:
|
||||
self.load_model()
|
||||
assert self._nlp is not None, "load_model() should set self._nlp"
|
||||
return self._nlp
|
||||
self._backend.load_model()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self._nlp is not None
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload the model to free memory."""
|
||||
self._nlp = None
|
||||
logger.info("spaCy model unloaded")
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Return the model name."""
|
||||
return self._model_name
|
||||
backend = self._backend
|
||||
return backend.model_loaded()
|
||||
|
||||
def extract(self, text: str) -> list[NamedEntity]:
|
||||
"""Extract named entities from text.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze.
|
||||
|
||||
Returns:
|
||||
List of extracted entities (deduplicated by normalized text).
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
nlp = self._ensure_loaded()
|
||||
doc = nlp(text)
|
||||
raw_entities = self._backend.extract(text)
|
||||
processed = post_process(raw_entities, text)
|
||||
entities = map_raw_entities_to_named(processed)
|
||||
|
||||
entities: list[NamedEntity] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for ent in doc.ents:
|
||||
# Normalize for deduplication
|
||||
normalized = ent.text.lower().strip()
|
||||
if not normalized or normalized in seen:
|
||||
continue
|
||||
|
||||
# Skip low-value entity types
|
||||
if ent.label_ in _SKIP_ENTITY_TYPES:
|
||||
continue
|
||||
|
||||
seen.add(normalized)
|
||||
category = _SPACY_CATEGORY_MAP.get(ent.label_, EntityCategory.OTHER)
|
||||
|
||||
entities.append(
|
||||
NamedEntity.create(
|
||||
text=ent.text,
|
||||
category=category,
|
||||
segment_ids=[], # Filled by caller
|
||||
confidence=0.8, # spaCy doesn't provide per-entity confidence
|
||||
)
|
||||
)
|
||||
|
||||
return entities
|
||||
|
||||
def _merge_entity_into_collection(
|
||||
self,
|
||||
entity: NamedEntity,
|
||||
segment_id: int,
|
||||
all_entities: dict[str, NamedEntity],
|
||||
) -> None:
|
||||
"""Merge an entity into the collection, tracking segment occurrences.
|
||||
|
||||
Args:
|
||||
entity: Entity to merge.
|
||||
segment_id: Segment where entity was found.
|
||||
all_entities: Collection to merge into.
|
||||
"""
|
||||
key = entity.normalized_text
|
||||
if key in all_entities:
|
||||
all_entities[key].merge_segments([segment_id])
|
||||
else:
|
||||
entity.segment_ids = [segment_id]
|
||||
all_entities[key] = entity
|
||||
return _dedupe_by_normalized_text(entities)
|
||||
|
||||
def extract_from_segments(
|
||||
self,
|
||||
segments: list[tuple[int, str]],
|
||||
) -> list[NamedEntity]:
|
||||
"""Extract entities from multiple segments with segment tracking.
|
||||
|
||||
Processes each segment individually (chunking by speaker turn/paragraph)
|
||||
to avoid OOM on long transcripts. Entities appearing in multiple segments
|
||||
are deduplicated with merged segment lists.
|
||||
|
||||
Args:
|
||||
segments: List of (segment_id, text) tuples.
|
||||
|
||||
Returns:
|
||||
Entities with segment_ids populated (deduplicated across segments).
|
||||
"""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
@@ -267,41 +54,57 @@ class NerEngine:
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
for entity in self.extract(text):
|
||||
self._merge_entity_into_collection(entity, segment_id, all_entities)
|
||||
_merge_entity_into_collection(entity, segment_id, all_entities)
|
||||
|
||||
return list(all_entities.values())
|
||||
|
||||
async def extract_async(self, text: str) -> list[NamedEntity]:
|
||||
"""Extract entities asynchronously using executor.
|
||||
|
||||
Offloads blocking extraction to a thread pool executor to avoid
|
||||
blocking the asyncio event loop.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze.
|
||||
|
||||
Returns:
|
||||
List of extracted entities.
|
||||
"""
|
||||
return await self._run_in_executor(partial(self.extract, text))
|
||||
|
||||
async def extract_from_segments_async(
|
||||
self,
|
||||
segments: list[tuple[int, str]],
|
||||
) -> list[NamedEntity]:
|
||||
"""Extract entities from segments asynchronously.
|
||||
|
||||
Args:
|
||||
segments: List of (segment_id, text) tuples.
|
||||
|
||||
Returns:
|
||||
Entities with segment_ids populated.
|
||||
"""
|
||||
return await self._run_in_executor(partial(self.extract_from_segments, segments))
|
||||
|
||||
async def _run_in_executor(
|
||||
self,
|
||||
func: partial[list[NamedEntity]],
|
||||
) -> list[NamedEntity]:
|
||||
"""Run sync extraction in thread pool executor."""
|
||||
return await asyncio.get_running_loop().run_in_executor(None, func)
|
||||
|
||||
@property
|
||||
def backend(self) -> NerBackend:
|
||||
return self._backend
|
||||
|
||||
|
||||
def _merge_entity_into_collection(
|
||||
entity: NamedEntity,
|
||||
segment_id: int,
|
||||
all_entities: dict[str, NamedEntity],
|
||||
) -> None:
|
||||
key = entity.normalized_text
|
||||
if key not in all_entities:
|
||||
entity.segment_ids = [segment_id]
|
||||
all_entities[key] = entity
|
||||
return
|
||||
|
||||
existing = all_entities[key]
|
||||
existing.merge_segments([segment_id])
|
||||
if entity.confidence > existing.confidence:
|
||||
existing.confidence = entity.confidence
|
||||
existing.category = entity.category
|
||||
if len(entity.text) > len(existing.text):
|
||||
existing.text = entity.text
|
||||
|
||||
|
||||
def _dedupe_by_normalized_text(entities: list[NamedEntity]) -> list[NamedEntity]:
|
||||
seen: set[str] = set()
|
||||
result: list[NamedEntity] = []
|
||||
|
||||
for entity in entities:
|
||||
if entity.normalized_text not in seen:
|
||||
seen.add(entity.normalized_text)
|
||||
result.append(entity)
|
||||
|
||||
return result
|
||||
|
||||
65
src/noteflow/infrastructure/ner/mapper.py
Normal file
65
src/noteflow/infrastructure/ner/mapper.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Mapper from RawEntity to domain NamedEntity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final
|
||||
|
||||
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
|
||||
from noteflow.infrastructure.ner.backends.types import RawEntity
|
||||
|
||||
DEFAULT_CONFIDENCE: Final[float] = 0.8
|
||||
|
||||
LABEL_TO_CATEGORY: Final[dict[str, EntityCategory]] = {
|
||||
EntityCategory.PERSON.value: EntityCategory.PERSON,
|
||||
"org": EntityCategory.COMPANY,
|
||||
EntityCategory.COMPANY.value: EntityCategory.COMPANY,
|
||||
EntityCategory.PRODUCT.value: EntityCategory.PRODUCT,
|
||||
"app": EntityCategory.PRODUCT,
|
||||
EntityCategory.LOCATION.value: EntityCategory.LOCATION,
|
||||
"gpe": EntityCategory.LOCATION,
|
||||
"loc": EntityCategory.LOCATION,
|
||||
"fac": EntityCategory.LOCATION,
|
||||
EntityCategory.DATE.value: EntityCategory.DATE,
|
||||
"time": EntityCategory.DATE,
|
||||
EntityCategory.TIME_RELATIVE.value: EntityCategory.TIME_RELATIVE,
|
||||
EntityCategory.DURATION.value: EntityCategory.DURATION,
|
||||
EntityCategory.EVENT.value: EntityCategory.EVENT,
|
||||
EntityCategory.TASK.value: EntityCategory.TASK,
|
||||
EntityCategory.DECISION.value: EntityCategory.DECISION,
|
||||
EntityCategory.TOPIC.value: EntityCategory.TOPIC,
|
||||
"work_of_art": EntityCategory.PRODUCT,
|
||||
EntityCategory.TECHNICAL.value: EntityCategory.TECHNICAL,
|
||||
EntityCategory.ACRONYM.value: EntityCategory.ACRONYM,
|
||||
"norp": EntityCategory.OTHER,
|
||||
"money": EntityCategory.OTHER,
|
||||
"percent": EntityCategory.OTHER,
|
||||
"cardinal": EntityCategory.OTHER,
|
||||
"ordinal": EntityCategory.OTHER,
|
||||
"quantity": EntityCategory.OTHER,
|
||||
"law": EntityCategory.OTHER,
|
||||
"language": EntityCategory.OTHER,
|
||||
}
|
||||
|
||||
|
||||
def map_label_to_category(label: str) -> EntityCategory:
|
||||
"""Map a raw label string to EntityCategory."""
|
||||
normalized = label.lower().strip()
|
||||
return LABEL_TO_CATEGORY.get(normalized, EntityCategory.OTHER)
|
||||
|
||||
|
||||
def map_raw_to_named(raw: RawEntity) -> NamedEntity:
|
||||
"""Convert a single RawEntity to a NamedEntity."""
|
||||
confidence = raw.confidence if raw.confidence is not None else DEFAULT_CONFIDENCE
|
||||
category = map_label_to_category(raw.label)
|
||||
|
||||
return NamedEntity.create(
|
||||
text=raw.text,
|
||||
category=category,
|
||||
segment_ids=[],
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
|
||||
def map_raw_entities_to_named(entities: list[RawEntity]) -> list[NamedEntity]:
|
||||
"""Convert a list of RawEntity to NamedEntity objects."""
|
||||
return [map_raw_to_named(raw) for raw in entities]
|
||||
243
src/noteflow/infrastructure/ner/post_processing.py
Normal file
243
src/noteflow/infrastructure/ner/post_processing.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Shared post-processing pipeline for NER backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from noteflow.infrastructure.ner.backends.types import RawEntity
|
||||
|
||||
PROFANITY_WORDS: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
"shit",
|
||||
"fuck",
|
||||
"damn",
|
||||
"crap",
|
||||
"ass",
|
||||
"hell",
|
||||
"bitch",
|
||||
}
|
||||
)
|
||||
|
||||
DURATION_UNITS: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
"second",
|
||||
"seconds",
|
||||
"minute",
|
||||
"minutes",
|
||||
"hour",
|
||||
"hours",
|
||||
"day",
|
||||
"days",
|
||||
"week",
|
||||
"weeks",
|
||||
"month",
|
||||
"months",
|
||||
"year",
|
||||
"years",
|
||||
}
|
||||
)
|
||||
|
||||
TIME_LABELS: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
"time",
|
||||
"time_relative",
|
||||
"date",
|
||||
}
|
||||
)
|
||||
|
||||
MIN_ENTITY_LENGTH: Final[int] = 2
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""Normalize entity text for comparison and deduplication."""
|
||||
normalized = text.lower().strip()
|
||||
normalized = re.sub(r"\s+", " ", normalized)
|
||||
normalized = re.sub(r"['\"]", "", normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def _entity_key(entity: RawEntity) -> str:
|
||||
return f"{normalize_text(entity.text)}:{entity.label}"
|
||||
|
||||
|
||||
def _should_replace_existing(existing: RawEntity, new: RawEntity) -> bool:
|
||||
if new.confidence is None:
|
||||
return False
|
||||
if existing.confidence is None:
|
||||
return True
|
||||
return new.confidence > existing.confidence
|
||||
|
||||
|
||||
def dedupe_entities(entities: list[RawEntity]) -> list[RawEntity]:
|
||||
"""Remove duplicate entities, keeping highest confidence version."""
|
||||
seen: dict[str, RawEntity] = {}
|
||||
|
||||
for entity in entities:
|
||||
key = _entity_key(entity)
|
||||
if key not in seen or _should_replace_existing(seen[key], entity):
|
||||
seen[key] = entity
|
||||
|
||||
return list(seen.values())
|
||||
|
||||
|
||||
def is_low_signal_entity(entity: RawEntity) -> bool:
|
||||
normalized = normalize_text(entity.text)
|
||||
|
||||
if len(normalized) <= MIN_ENTITY_LENGTH:
|
||||
return True
|
||||
|
||||
if normalized.isdigit():
|
||||
return True
|
||||
|
||||
if normalized in PROFANITY_WORDS:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def drop_low_signal_entities(
|
||||
entities: list[RawEntity],
|
||||
original_text: str,
|
||||
) -> list[RawEntity]:
|
||||
_ = original_text # Reserved for future context-aware filtering
|
||||
return [e for e in entities if not is_low_signal_entity(e)]
|
||||
|
||||
|
||||
NUMBER_WORDS: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"a",
|
||||
"an",
|
||||
"few",
|
||||
"several",
|
||||
"couple",
|
||||
"half",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_numeric_duration(text: str) -> bool:
|
||||
words = text.lower().split()
|
||||
for i, word in enumerate(words):
|
||||
if word not in DURATION_UNITS:
|
||||
continue
|
||||
if i == 0:
|
||||
continue
|
||||
prev_word = words[i - 1]
|
||||
if prev_word.isdigit() or prev_word in NUMBER_WORDS:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def infer_duration(entities: list[RawEntity]) -> list[RawEntity]:
|
||||
"""Relabel entities containing duration units as 'duration'."""
|
||||
result: list[RawEntity] = []
|
||||
|
||||
for entity in entities:
|
||||
if _is_numeric_duration(entity.text):
|
||||
result.append(
|
||||
RawEntity(
|
||||
text=entity.text,
|
||||
label="duration",
|
||||
start=entity.start,
|
||||
end=entity.end,
|
||||
confidence=entity.confidence,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(entity)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _are_adjacent(e1: RawEntity, e2: RawEntity, original_text: str) -> bool:
|
||||
"""Check if two entities are adjacent in text (within 3 chars)."""
|
||||
gap_start = min(e1.end, e2.end)
|
||||
gap_end = max(e1.start, e2.start)
|
||||
|
||||
if gap_end <= gap_start:
|
||||
return True
|
||||
|
||||
gap_text = original_text[gap_start:gap_end]
|
||||
return len(gap_text.strip()) <= 1
|
||||
|
||||
|
||||
def _average_confidence(conf_a: float | None, conf_b: float | None) -> float | None:
|
||||
if conf_a is not None and conf_b is not None:
|
||||
return (conf_a + conf_b) / 2
|
||||
return conf_a if conf_a is not None else conf_b
|
||||
|
||||
|
||||
def _merge_two_entities(
|
||||
current: RawEntity,
|
||||
next_entity: RawEntity,
|
||||
original_text: str,
|
||||
) -> RawEntity:
|
||||
merged_start = min(current.start, next_entity.start)
|
||||
merged_end = max(current.end, next_entity.end)
|
||||
return RawEntity(
|
||||
text=original_text[merged_start:merged_end],
|
||||
label=current.label,
|
||||
start=merged_start,
|
||||
end=merged_end,
|
||||
confidence=_average_confidence(current.confidence, next_entity.confidence),
|
||||
)
|
||||
|
||||
|
||||
def _merge_adjacent_time_entities(
|
||||
time_entities: list[RawEntity],
|
||||
original_text: str,
|
||||
) -> list[RawEntity]:
|
||||
time_entities.sort(key=lambda e: e.start)
|
||||
merged: list[RawEntity] = []
|
||||
current = time_entities[0]
|
||||
|
||||
for next_entity in time_entities[1:]:
|
||||
if _are_adjacent(current, next_entity, original_text):
|
||||
current = _merge_two_entities(current, next_entity, original_text)
|
||||
else:
|
||||
merged.append(current)
|
||||
current = next_entity
|
||||
|
||||
merged.append(current)
|
||||
return merged
|
||||
|
||||
|
||||
def merge_time_phrases(
|
||||
entities: list[RawEntity],
|
||||
original_text: str,
|
||||
) -> list[RawEntity]:
|
||||
"""Merge adjacent time-like entities into single spans."""
|
||||
if not entities:
|
||||
return []
|
||||
|
||||
time_entities = [e for e in entities if e.label in TIME_LABELS]
|
||||
other_entities = [e for e in entities if e.label not in TIME_LABELS]
|
||||
|
||||
if len(time_entities) <= 1:
|
||||
return entities
|
||||
|
||||
merged_time = _merge_adjacent_time_entities(time_entities, original_text)
|
||||
return other_entities + merged_time
|
||||
|
||||
|
||||
def post_process(
|
||||
entities: list[RawEntity],
|
||||
original_text: str,
|
||||
) -> list[RawEntity]:
|
||||
"""Apply full post-processing pipeline to extracted entities."""
|
||||
entities = drop_low_signal_entities(entities, original_text)
|
||||
entities = infer_duration(entities)
|
||||
entities = merge_time_phrases(entities, original_text)
|
||||
entities = dedupe_entities(entities)
|
||||
return entities
|
||||
@@ -33,14 +33,29 @@ def _alter_segment_ids_to_int_array(table: str) -> None:
|
||||
op.execute(
|
||||
f"""
|
||||
ALTER TABLE noteflow.{table}
|
||||
ALTER COLUMN segment_ids DROP DEFAULT,
|
||||
ALTER COLUMN segment_ids TYPE INTEGER[]
|
||||
USING CASE
|
||||
ADD COLUMN segment_ids_temp INTEGER[] DEFAULT ARRAY[]::int[];
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE noteflow.{table}
|
||||
SET segment_ids_temp = CASE
|
||||
WHEN jsonb_typeof(segment_ids) = 'array'
|
||||
THEN ARRAY(SELECT jsonb_array_elements_text(segment_ids)::int)
|
||||
THEN (SELECT array_agg(elem::int) FROM jsonb_array_elements_text(segment_ids) AS elem)
|
||||
ELSE ARRAY[]::int[]
|
||||
END,
|
||||
ALTER COLUMN segment_ids SET DEFAULT ARRAY[]::int[];
|
||||
END;
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
ALTER TABLE noteflow.{table}
|
||||
DROP COLUMN segment_ids;
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
ALTER TABLE noteflow.{table}
|
||||
RENAME COLUMN segment_ids_temp TO segment_ids;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@@ -35,13 +35,28 @@ def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE noteflow.diarization_jobs
|
||||
ALTER COLUMN speaker_ids DROP DEFAULT,
|
||||
ALTER COLUMN speaker_ids TYPE TEXT[]
|
||||
USING CASE
|
||||
WHEN jsonb_typeof(speaker_ids) = 'array'
|
||||
THEN ARRAY(SELECT jsonb_array_elements_text(speaker_ids))
|
||||
ELSE ARRAY[]::text[]
|
||||
END,
|
||||
ALTER COLUMN speaker_ids SET DEFAULT ARRAY[]::text[];
|
||||
ADD COLUMN speaker_ids_temp TEXT[] DEFAULT ARRAY[]::text[];
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE noteflow.diarization_jobs
|
||||
SET speaker_ids_temp = CASE
|
||||
WHEN jsonb_typeof(speaker_ids) = 'array'
|
||||
THEN (SELECT array_agg(elem) FROM jsonb_array_elements_text(speaker_ids) AS elem)
|
||||
ELSE ARRAY[]::text[]
|
||||
END;
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE noteflow.diarization_jobs
|
||||
DROP COLUMN speaker_ids;
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE noteflow.diarization_jobs
|
||||
RENAME COLUMN speaker_ids_temp TO speaker_ids;
|
||||
"""
|
||||
)
|
||||
|
||||
26
src/noteflow/infrastructure/security/crypto/_base.py
Normal file
26
src/noteflow/infrastructure/security/crypto/_base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Base classes for crypto module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Self
|
||||
|
||||
|
||||
class ContextManagedClosable(ABC):
|
||||
"""Mixin for classes with close() that support context manager protocol."""
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the resource."""
|
||||
...
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
self.close()
|
||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, BinaryIO
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
from noteflow.infrastructure.security.protocols import EncryptedChunk
|
||||
|
||||
from ._base import ContextManagedClosable
|
||||
from ._binary_io import read_exact
|
||||
from ._constants import FILE_MAGIC, FILE_VERSION, MIN_CHUNK_LENGTH, NONCE_SIZE, TAG_SIZE
|
||||
|
||||
@@ -19,7 +20,7 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChunkedAssetReader:
|
||||
class ChunkedAssetReader(ContextManagedClosable):
|
||||
"""Streaming encrypted asset reader."""
|
||||
|
||||
def __init__(self, crypto: AesGcmCryptoBox) -> None:
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, BinaryIO
|
||||
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
from ._base import ContextManagedClosable
|
||||
from ._constants import FILE_MAGIC, FILE_VERSION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -16,7 +17,7 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChunkedAssetWriter:
|
||||
class ChunkedAssetWriter(ContextManagedClosable):
|
||||
"""Streaming encrypted asset writer.
|
||||
|
||||
File format:
|
||||
@@ -56,9 +57,14 @@ class ChunkedAssetWriter:
|
||||
self._handle = path.open("wb")
|
||||
self._bytes_written = 0
|
||||
|
||||
# Write header
|
||||
self._handle.write(FILE_MAGIC)
|
||||
self._handle.write(struct.pack("B", FILE_VERSION))
|
||||
try:
|
||||
self._handle.write(FILE_MAGIC)
|
||||
self._handle.write(struct.pack("B", FILE_VERSION))
|
||||
except OSError:
|
||||
self._handle.close()
|
||||
self._handle = None
|
||||
self._dek = None
|
||||
raise
|
||||
|
||||
logger.debug("Opened encrypted file for writing: %s", path)
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ def test_get_capabilities_returns_current_config(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""get_capabilities returns current ASR configuration."""
|
||||
with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False):
|
||||
with patch.object(
|
||||
asr_config_service.engine_manager, "detect_cuda_available", return_value=False
|
||||
):
|
||||
caps = asr_config_service.get_capabilities()
|
||||
|
||||
assert caps.model_size == "base", "model_size should be 'base' from engine"
|
||||
@@ -95,9 +97,12 @@ def test_get_capabilities_with_cuda_available(
|
||||
mock_asr_engine: MagicMock,
|
||||
) -> None:
|
||||
"""get_capabilities includes CUDA compute types when available."""
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
|
||||
mock_asr_engine.device = "cuda"
|
||||
with patch.object(
|
||||
asr_config_service.engine_manager, "detect_cuda_available", return_value=True
|
||||
with patch(
|
||||
"noteflow.application.services.asr_config._engine_manager.detect_gpu_backend",
|
||||
return_value=GpuBackend.CUDA,
|
||||
):
|
||||
caps = asr_config_service.get_capabilities()
|
||||
|
||||
@@ -117,7 +122,9 @@ def test_validate_configuration_valid_cpu_config(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""validate_configuration accepts valid CPU configuration."""
|
||||
with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False):
|
||||
with patch.object(
|
||||
asr_config_service.engine_manager, "detect_cuda_available", return_value=False
|
||||
):
|
||||
error = asr_config_service.validate_configuration(
|
||||
model_size="small",
|
||||
device=AsrDevice.CPU,
|
||||
@@ -145,7 +152,9 @@ def test_validate_configuration_cuda_unavailable(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""validate_configuration rejects CUDA when unavailable."""
|
||||
with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False):
|
||||
with patch.object(
|
||||
asr_config_service.engine_manager, "detect_cuda_available", return_value=False
|
||||
):
|
||||
error = asr_config_service.validate_configuration(
|
||||
model_size=None,
|
||||
device=AsrDevice.CUDA,
|
||||
@@ -193,7 +202,9 @@ async def test_start_reconfiguration_returns_job_id(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""start_reconfiguration returns job ID on success."""
|
||||
with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False):
|
||||
with patch.object(
|
||||
asr_config_service.engine_manager, "detect_cuda_available", return_value=False
|
||||
):
|
||||
job_id, error = await asr_config_service.start_reconfiguration(
|
||||
model_size="small",
|
||||
device=None,
|
||||
@@ -253,7 +264,9 @@ async def test_start_reconfiguration_validation_failure(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""start_reconfiguration fails on invalid configuration."""
|
||||
with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False):
|
||||
with patch.object(
|
||||
asr_config_service.engine_manager, "detect_cuda_available", return_value=False
|
||||
):
|
||||
job_id, error = await asr_config_service.start_reconfiguration(
|
||||
model_size="invalid-model",
|
||||
device=None,
|
||||
@@ -276,7 +289,9 @@ async def test_get_job_status_returns_job(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""get_job_status returns job info for valid ID."""
|
||||
with patch.object(asr_config_service.engine_manager, "detect_cuda_available", return_value=False):
|
||||
with patch.object(
|
||||
asr_config_service.engine_manager, "detect_cuda_available", return_value=False
|
||||
):
|
||||
job_id, _ = await asr_config_service.start_reconfiguration(
|
||||
model_size="small",
|
||||
device=None,
|
||||
@@ -317,10 +332,12 @@ def test_detect_cuda_available_with_cuda(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""detect_cuda_available returns True when CUDA is available."""
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.cuda.is_available.return_value = True
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
|
||||
with patch.dict("sys.modules", {"torch": mock_torch}):
|
||||
with patch(
|
||||
"noteflow.application.services.asr_config._engine_manager.detect_gpu_backend",
|
||||
return_value=GpuBackend.CUDA,
|
||||
):
|
||||
result = asr_config_service.detect_cuda_available()
|
||||
|
||||
assert result is True, "detect_cuda_available should return True when CUDA available"
|
||||
@@ -330,10 +347,12 @@ def test_detect_cuda_available_no_cuda(
|
||||
asr_config_service: AsrConfigService,
|
||||
) -> None:
|
||||
"""detect_cuda_available returns False when CUDA is not available."""
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.cuda.is_available.return_value = False
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
|
||||
with patch.dict("sys.modules", {"torch": mock_torch}):
|
||||
with patch(
|
||||
"noteflow.application.services.asr_config._engine_manager.detect_gpu_backend",
|
||||
return_value=GpuBackend.NONE,
|
||||
):
|
||||
result = asr_config_service.detect_cuda_available()
|
||||
|
||||
assert result is False, "detect_cuda_available should return False when CUDA unavailable"
|
||||
@@ -347,7 +366,9 @@ def test_detect_cuda_available_no_cuda(
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconfiguration_failure_keeps_active_engine(mock_asr_engine: MagicMock) -> None:
|
||||
"""Reconfiguration failure should not replace or unload the active engine."""
|
||||
updates: list[MagicMock] = []
|
||||
from noteflow.infrastructure.asr.engine import FasterWhisperEngine
|
||||
|
||||
updates: list[FasterWhisperEngine] = []
|
||||
service = AsrConfigService(asr_engine=mock_asr_engine, on_engine_update=updates.append)
|
||||
mgr = service.engine_manager
|
||||
|
||||
@@ -356,6 +377,7 @@ async def test_reconfiguration_failure_keeps_active_engine(mock_asr_engine: Magi
|
||||
patch.object(mgr, "load_model", side_effect=RuntimeError("boom")),
|
||||
):
|
||||
job_id, _ = await service.start_reconfiguration("small", None, None, False)
|
||||
assert job_id is not None, "job_id should not be None"
|
||||
job = service.get_job_status(job_id)
|
||||
assert job is not None and job.task is not None, "job should be created with task"
|
||||
await job.task
|
||||
@@ -368,7 +390,9 @@ async def test_reconfiguration_failure_keeps_active_engine(mock_asr_engine: Magi
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconfiguration_success_swaps_engine(mock_asr_engine: MagicMock) -> None:
|
||||
"""Successful reconfiguration should swap engine and unload the old one."""
|
||||
updates: list[MagicMock] = []
|
||||
from noteflow.infrastructure.asr.engine import FasterWhisperEngine
|
||||
|
||||
updates: list[FasterWhisperEngine] = []
|
||||
service = AsrConfigService(asr_engine=mock_asr_engine, on_engine_update=updates.append)
|
||||
new_engine, mgr = MagicMock(), service.engine_manager
|
||||
|
||||
@@ -377,6 +401,7 @@ async def test_reconfiguration_success_swaps_engine(mock_asr_engine: MagicMock)
|
||||
patch.object(mgr, "load_model", return_value=None),
|
||||
):
|
||||
job_id, _ = await service.start_reconfiguration("small", None, None, False)
|
||||
assert job_id is not None, "job_id should not be None"
|
||||
job = service.get_job_status(job_id)
|
||||
assert job is not None and job.task is not None, "job should be created with task"
|
||||
await job.task
|
||||
|
||||
@@ -19,14 +19,19 @@ class TestPartialAudioBufferInit:
|
||||
"""Tests for buffer initialization."""
|
||||
|
||||
def test_default_capacity(self) -> None:
|
||||
"""Buffer should have 5 seconds capacity by default at 16kHz."""
|
||||
"""Buffer should have 10 seconds capacity by default at 16kHz."""
|
||||
buffer = PartialAudioBuffer()
|
||||
assert buffer.capacity_samples == 5 * SAMPLE_RATE_16K, "Default capacity should be 5 seconds"
|
||||
expected_capacity = int(PartialAudioBuffer.DEFAULT_MAX_DURATION * SAMPLE_RATE_16K)
|
||||
assert buffer.capacity_samples == expected_capacity, (
|
||||
"Default capacity should match DEFAULT_MAX_DURATION"
|
||||
)
|
||||
|
||||
def test_custom_duration(self) -> None:
|
||||
"""Buffer should respect custom max duration."""
|
||||
buffer = PartialAudioBuffer(max_duration_seconds=3.0, sample_rate=SAMPLE_RATE_16K)
|
||||
assert buffer.capacity_samples == 3 * SAMPLE_RATE_16K, "Capacity should match custom duration"
|
||||
assert buffer.capacity_samples == 3 * SAMPLE_RATE_16K, (
|
||||
"Capacity should match custom duration"
|
||||
)
|
||||
|
||||
def test_custom_sample_rate(self) -> None:
|
||||
"""Buffer should respect custom sample rate."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for NER engine (spaCy wrapper)."""
|
||||
"""Tests for NER engine (composition-based)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -6,26 +6,32 @@ import pytest
|
||||
|
||||
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
|
||||
from noteflow.infrastructure.ner import NerEngine
|
||||
from noteflow.infrastructure.ner.backends.spacy_backend import SpacyBackend
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ner_engine() -> NerEngine:
|
||||
"""Create NER engine (module-scoped to avoid repeated model loads)."""
|
||||
return NerEngine(model_name="en_core_web_sm")
|
||||
def spacy_backend() -> SpacyBackend:
|
||||
"""Create spaCy backend (module-scoped to avoid repeated model loads)."""
|
||||
return SpacyBackend(model_name="en_core_web_sm")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ner_engine(spacy_backend: SpacyBackend) -> NerEngine:
|
||||
"""Create NER engine with spaCy backend."""
|
||||
return NerEngine(backend=spacy_backend)
|
||||
|
||||
|
||||
class TestNerEngineBasics:
|
||||
"""Basic NER engine functionality tests."""
|
||||
|
||||
def test_is_ready_before_load(self) -> None:
|
||||
"""Engine is not ready before first use."""
|
||||
engine = NerEngine()
|
||||
backend = SpacyBackend()
|
||||
engine = NerEngine(backend=backend)
|
||||
assert not engine.is_ready(), "Engine should not be ready before model is loaded"
|
||||
|
||||
def test_is_ready_after_extract(self, ner_engine: NerEngine) -> None:
|
||||
"""Engine is ready after extraction triggers lazy load."""
|
||||
ner_engine.extract("Hello, John.")
|
||||
assert ner_engine.is_ready(), "Engine should be ready after extraction triggers lazy load"
|
||||
assert ner_engine.is_ready(), "Engine should be ready after extraction"
|
||||
|
||||
|
||||
class TestEntityExtraction:
|
||||
@@ -58,7 +64,9 @@ class TestEntityExtraction:
|
||||
"""Extract returns a list of NamedEntity objects."""
|
||||
entities = ner_engine.extract("John works at Google.")
|
||||
assert isinstance(entities, list), "extract() should return a list"
|
||||
assert all(isinstance(e, NamedEntity) for e in entities), "All items should be NamedEntity instances"
|
||||
assert all(isinstance(e, NamedEntity) for e in entities), (
|
||||
"All items should be NamedEntity instances"
|
||||
)
|
||||
|
||||
def test_extract_empty_text_returns_empty(self, ner_engine: NerEngine) -> None:
|
||||
"""Empty text returns empty list."""
|
||||
@@ -68,7 +76,6 @@ class TestEntityExtraction:
|
||||
def test_extract_no_entities_returns_empty(self, ner_engine: NerEngine) -> None:
|
||||
"""Text with no entities returns empty list."""
|
||||
entities = ner_engine.extract("The quick brown fox.")
|
||||
# May still find entities depending on model, but should not crash
|
||||
assert isinstance(entities, list), "Result should be a list even with no entities"
|
||||
|
||||
|
||||
@@ -84,7 +91,6 @@ class TestSegmentExtraction:
|
||||
]
|
||||
entities = ner_engine.extract_from_segments(segments)
|
||||
|
||||
# Mary should appear in multiple segments
|
||||
mary_entities = [e for e in entities if "mary" in e.normalized_text]
|
||||
assert mary_entities, "Mary entity not found"
|
||||
mary = mary_entities[0]
|
||||
@@ -98,7 +104,6 @@ class TestSegmentExtraction:
|
||||
]
|
||||
entities = ner_engine.extract_from_segments(segments)
|
||||
|
||||
# Should have one John Smith entity (deduplicated by normalized text)
|
||||
john_entities = [e for e in entities if "john" in e.normalized_text]
|
||||
assert len(john_entities) == 1, "John Smith should be deduplicated"
|
||||
assert len(john_entities[0].segment_ids) == 2, "Should track both segments"
|
||||
@@ -116,11 +121,15 @@ class TestEntityNormalization:
|
||||
"""Normalized text should be lowercase."""
|
||||
entities = ner_engine.extract("John SMITH went to NYC.")
|
||||
non_lowercase = [e for e in entities if e.normalized_text != e.normalized_text.lower()]
|
||||
assert not non_lowercase, f"All normalized text should be lowercase, but found: {[e.normalized_text for e in non_lowercase]}"
|
||||
assert not non_lowercase, (
|
||||
f"All normalized text should be lowercase, but found: {[e.normalized_text for e in non_lowercase]}"
|
||||
)
|
||||
|
||||
def test_confidence_is_set(self, ner_engine: NerEngine) -> None:
|
||||
"""Entities should have confidence score."""
|
||||
entities = ner_engine.extract("Microsoft Corporation is based in Seattle.")
|
||||
assert entities, "Should find entities"
|
||||
invalid_confidence = [e for e in entities if not (0.0 <= e.confidence <= 1.0)]
|
||||
assert not invalid_confidence, f"All entities should have confidence between 0 and 1, but found: {[(e.text, e.confidence) for e in invalid_confidence]}"
|
||||
assert not invalid_confidence, (
|
||||
f"All entities should have confidence between 0 and 1, but found: {[(e.text, e.confidence) for e in invalid_confidence]}"
|
||||
)
|
||||
|
||||
219
tests/infrastructure/ner/test_gliner_backend.py
Normal file
219
tests/infrastructure/ner/test_gliner_backend.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Tests for GLiNER NER backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.infrastructure.ner.backends.gliner_backend import (
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_THRESHOLD,
|
||||
MEETING_LABELS,
|
||||
GLiNERBackend,
|
||||
)
|
||||
from noteflow.infrastructure.ner.backends.types import RawEntity
|
||||
|
||||
|
||||
def _create_backend_with_mock_model(
|
||||
model_name: str = DEFAULT_MODEL,
|
||||
labels: tuple[str, ...] = MEETING_LABELS,
|
||||
threshold: float = DEFAULT_THRESHOLD,
|
||||
mock_predictions: list[dict[str, object]] | None = None,
|
||||
) -> GLiNERBackend:
|
||||
backend = GLiNERBackend(model_name=model_name, labels=labels, threshold=threshold)
|
||||
mock_model = MagicMock()
|
||||
mock_model.predict_entities = MagicMock(return_value=mock_predictions or [])
|
||||
object.__setattr__(backend, "_model", mock_model)
|
||||
return backend
|
||||
|
||||
|
||||
class TestGLiNERBackendInit:
|
||||
def test_default_model_name(self) -> None:
|
||||
backend = GLiNERBackend()
|
||||
assert backend.model_name == DEFAULT_MODEL
|
||||
|
||||
def test_default_labels(self) -> None:
|
||||
backend = GLiNERBackend()
|
||||
assert backend.labels == MEETING_LABELS
|
||||
|
||||
def test_default_threshold(self) -> None:
|
||||
backend = GLiNERBackend()
|
||||
assert backend.threshold == DEFAULT_THRESHOLD
|
||||
|
||||
def test_init_with_custom_model_name(self) -> None:
|
||||
backend = GLiNERBackend(model_name="custom/model")
|
||||
assert backend.model_name == "custom/model"
|
||||
|
||||
def test_custom_labels(self) -> None:
|
||||
custom_labels = ("person", "location")
|
||||
backend = GLiNERBackend(labels=custom_labels)
|
||||
assert backend.labels == custom_labels
|
||||
|
||||
def test_custom_threshold(self) -> None:
|
||||
backend = GLiNERBackend(threshold=0.7)
|
||||
assert backend.threshold == 0.7
|
||||
|
||||
def test_not_loaded_initially(self) -> None:
|
||||
backend = GLiNERBackend()
|
||||
assert not backend.model_loaded()
|
||||
|
||||
|
||||
class TestGLiNERBackendModelState:
|
||||
def test_model_loaded_returns_true_when_model_set(self) -> None:
|
||||
backend = _create_backend_with_mock_model()
|
||||
assert backend.model_loaded()
|
||||
|
||||
def test_model_loaded_returns_false_initially(self) -> None:
|
||||
backend = GLiNERBackend()
|
||||
assert not backend.model_loaded()
|
||||
|
||||
|
||||
class TestGLiNERBackendExtraction:
|
||||
def test_extract_empty_string_returns_empty(self) -> None:
|
||||
backend = GLiNERBackend()
|
||||
assert backend.extract("") == []
|
||||
|
||||
def test_extract_whitespace_only_returns_empty(self) -> None:
|
||||
backend = GLiNERBackend()
|
||||
assert backend.extract(" ") == []
|
||||
|
||||
def test_extract_returns_correct_count(self) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95},
|
||||
{"text": "NYC", "label": "location", "start": 12, "end": 15, "score": 0.88},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("John lives in NYC")
|
||||
assert len(entities) == 2
|
||||
|
||||
def test_extract_returns_raw_entity_type(self) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("John")
|
||||
assert isinstance(entities[0], RawEntity)
|
||||
|
||||
def test_extract_normalizes_labels_to_lowercase(self) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.9},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("John")
|
||||
assert entities[0].label == "person"
|
||||
|
||||
def test_extract_includes_confidence_score(self) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": "Meeting", "label": "event", "start": 0, "end": 7, "score": 0.85},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("Meeting tomorrow")
|
||||
assert entities[0].confidence == 0.85
|
||||
|
||||
def test_extract_passes_threshold_to_predict(self) -> None:
|
||||
backend = _create_backend_with_mock_model(threshold=0.75)
|
||||
backend.extract("Some text")
|
||||
mock_model = getattr(backend, "_model")
|
||||
predict_entities = getattr(mock_model, "predict_entities")
|
||||
call_kwargs = predict_entities.call_args
|
||||
assert call_kwargs[1]["threshold"] == 0.75
|
||||
|
||||
def test_extract_passes_labels_to_predict(self) -> None:
|
||||
custom_labels = ("person", "task")
|
||||
backend = _create_backend_with_mock_model(labels=custom_labels)
|
||||
backend.extract("Some text")
|
||||
mock_model = getattr(backend, "_model")
|
||||
predict_entities = getattr(mock_model, "predict_entities")
|
||||
call_kwargs = predict_entities.call_args
|
||||
assert call_kwargs[1]["labels"] == ["person", "task"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "label", "start", "end", "score"),
|
||||
[
|
||||
("John Smith", "person", 0, 10, 0.92),
|
||||
("New York", "location", 5, 13, 0.88),
|
||||
("decision to proceed", "decision", 0, 19, 0.75),
|
||||
],
|
||||
)
|
||||
def test_extract_entity_text_mapping(
|
||||
self, text: str, label: str, start: int, end: int, score: float
|
||||
) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": text, "label": label.upper(), "start": start, "end": end, "score": score},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("input text")
|
||||
assert entities[0].text == text
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "label", "start", "end", "score"),
|
||||
[
|
||||
("John Smith", "person", 0, 10, 0.92),
|
||||
("New York", "location", 5, 13, 0.88),
|
||||
("decision to proceed", "decision", 0, 19, 0.75),
|
||||
],
|
||||
)
|
||||
def test_extract_entity_label_mapping(
|
||||
self, text: str, label: str, start: int, end: int, score: float
|
||||
) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": text, "label": label.upper(), "start": start, "end": end, "score": score},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("input text")
|
||||
assert entities[0].label == label
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "label", "start", "end", "score"),
|
||||
[
|
||||
("John Smith", "person", 0, 10, 0.92),
|
||||
("New York", "location", 5, 13, 0.88),
|
||||
("decision to proceed", "decision", 0, 19, 0.75),
|
||||
],
|
||||
)
|
||||
def test_extract_entity_position_mapping(
|
||||
self, text: str, label: str, start: int, end: int, score: float
|
||||
) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": text, "label": label.upper(), "start": start, "end": end, "score": score},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("input text")
|
||||
assert (entities[0].start, entities[0].end) == (start, end)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "label", "start", "end", "score"),
|
||||
[
|
||||
("John Smith", "person", 0, 10, 0.92),
|
||||
("New York", "location", 5, 13, 0.88),
|
||||
("decision to proceed", "decision", 0, 19, 0.75),
|
||||
],
|
||||
)
|
||||
def test_extract_entity_confidence_mapping(
|
||||
self, text: str, label: str, start: int, end: int, score: float
|
||||
) -> None:
|
||||
backend = _create_backend_with_mock_model(
|
||||
mock_predictions=[
|
||||
{"text": text, "label": label.upper(), "start": start, "end": end, "score": score},
|
||||
]
|
||||
)
|
||||
entities = backend.extract("input text")
|
||||
assert entities[0].confidence == score
|
||||
|
||||
|
||||
class TestGLiNERBackendMeetingLabels:
|
||||
def test_meeting_labels_include_core_categories(self) -> None:
|
||||
expected_labels = {"person", "org", "product", "app", "location", "time"}
|
||||
assert expected_labels.issubset(set(MEETING_LABELS))
|
||||
|
||||
def test_meeting_labels_include_meeting_specific_categories(self) -> None:
|
||||
expected_labels = {"task", "decision", "topic", "event"}
|
||||
assert expected_labels.issubset(set(MEETING_LABELS))
|
||||
100
tests/infrastructure/ner/test_mapper.py
Normal file
100
tests/infrastructure/ner/test_mapper.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Tests for NER entity mapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
|
||||
from noteflow.infrastructure.ner.backends.types import RawEntity
|
||||
from noteflow.infrastructure.ner.mapper import (
|
||||
DEFAULT_CONFIDENCE,
|
||||
map_label_to_category,
|
||||
map_raw_entities_to_named,
|
||||
map_raw_to_named,
|
||||
)
|
||||
|
||||
|
||||
class TestMapLabelToCategory:
|
||||
"""Tests for label to category mapping."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("label", "expected"),
|
||||
[
|
||||
("person", EntityCategory.PERSON),
|
||||
("PERSON", EntityCategory.PERSON),
|
||||
("org", EntityCategory.COMPANY),
|
||||
("company", EntityCategory.COMPANY),
|
||||
("product", EntityCategory.PRODUCT),
|
||||
("app", EntityCategory.PRODUCT),
|
||||
("location", EntityCategory.LOCATION),
|
||||
("gpe", EntityCategory.LOCATION),
|
||||
("loc", EntityCategory.LOCATION),
|
||||
("fac", EntityCategory.LOCATION),
|
||||
("date", EntityCategory.DATE),
|
||||
("time", EntityCategory.DATE),
|
||||
("time_relative", EntityCategory.TIME_RELATIVE),
|
||||
("duration", EntityCategory.DURATION),
|
||||
("event", EntityCategory.EVENT),
|
||||
("task", EntityCategory.TASK),
|
||||
("decision", EntityCategory.DECISION),
|
||||
("topic", EntityCategory.TOPIC),
|
||||
("work_of_art", EntityCategory.PRODUCT),
|
||||
("unknown_label", EntityCategory.OTHER),
|
||||
],
|
||||
)
|
||||
def test_label_mapping(self, label: str, expected: EntityCategory) -> None:
|
||||
assert map_label_to_category(label) == expected
|
||||
|
||||
|
||||
class TestMapRawToNamed:
|
||||
"""Tests for single entity mapping."""
|
||||
|
||||
def test_maps_text_and_category(self) -> None:
|
||||
raw = RawEntity(text="John Smith", label="person", start=0, end=10, confidence=0.9)
|
||||
result = map_raw_to_named(raw)
|
||||
|
||||
assert result.text == "John Smith", "Text should be preserved"
|
||||
assert result.category == EntityCategory.PERSON, "Category should be mapped"
|
||||
|
||||
def test_uses_provided_confidence(self) -> None:
|
||||
raw = RawEntity(text="Apple", label="product", start=0, end=5, confidence=0.85)
|
||||
result = map_raw_to_named(raw)
|
||||
|
||||
assert result.confidence == 0.85, "Should use provided confidence"
|
||||
|
||||
def test_uses_default_confidence_when_none(self) -> None:
|
||||
raw = RawEntity(text="Paris", label="location", start=0, end=5, confidence=None)
|
||||
result = map_raw_to_named(raw)
|
||||
|
||||
assert result.confidence == DEFAULT_CONFIDENCE, "Should use default confidence when None"
|
||||
|
||||
def test_segment_ids_empty_by_default(self) -> None:
|
||||
raw = RawEntity(text="Monday", label="date", start=0, end=6, confidence=0.8)
|
||||
result = map_raw_to_named(raw)
|
||||
|
||||
assert result.segment_ids == [], "Segment IDs should be empty by default"
|
||||
|
||||
def test_result_is_named_entity(self) -> None:
|
||||
raw = RawEntity(text="Google", label="company", start=0, end=6, confidence=0.9)
|
||||
result = map_raw_to_named(raw)
|
||||
|
||||
assert isinstance(result, NamedEntity), "Result should be NamedEntity"
|
||||
|
||||
|
||||
class TestMapRawEntitiesToNamed:
|
||||
"""Tests for batch entity mapping."""
|
||||
|
||||
def test_maps_multiple_entities(self) -> None:
|
||||
raw_entities = [
|
||||
RawEntity(text="John", label="person", start=0, end=4, confidence=0.9),
|
||||
RawEntity(text="Google", label="company", start=5, end=11, confidence=0.85),
|
||||
RawEntity(text="New York", label="location", start=12, end=20, confidence=0.95),
|
||||
]
|
||||
result = map_raw_entities_to_named(raw_entities)
|
||||
|
||||
assert len(result) == 3, "Should map all entities"
|
||||
assert all(isinstance(e, NamedEntity) for e in result), "All should be NamedEntity"
|
||||
|
||||
def test_empty_list_returns_empty(self) -> None:
|
||||
result = map_raw_entities_to_named([])
|
||||
assert result == [], "Empty input should return empty output"
|
||||
158
tests/infrastructure/ner/test_post_processing.py
Normal file
158
tests/infrastructure/ner/test_post_processing.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for NER post-processing pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.infrastructure.ner.backends.types import RawEntity
|
||||
from noteflow.infrastructure.ner.post_processing import (
|
||||
dedupe_entities,
|
||||
drop_low_signal_entities,
|
||||
infer_duration,
|
||||
merge_time_phrases,
|
||||
normalize_text,
|
||||
post_process,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizeText:
|
||||
"""Tests for text normalization."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input_text", "expected"),
|
||||
[
|
||||
("John Smith", "john smith"),
|
||||
(" UPPER CASE ", "upper case"),
|
||||
("'quoted'", "quoted"),
|
||||
('"double"', "double"),
|
||||
("spaces between words", "spaces between words"),
|
||||
],
|
||||
)
|
||||
def test_normalize_text(self, input_text: str, expected: str) -> None:
|
||||
assert normalize_text(input_text) == expected
|
||||
|
||||
|
||||
class TestDedupeEntities:
|
||||
"""Tests for entity deduplication."""
|
||||
|
||||
def test_deduplicates_by_normalized_text_and_label(self) -> None:
|
||||
entities = [
|
||||
RawEntity(text="John", label="person", start=0, end=4, confidence=0.8),
|
||||
RawEntity(text="john", label="person", start=10, end=14, confidence=0.9),
|
||||
]
|
||||
result = dedupe_entities(entities)
|
||||
assert len(result) == 1, "Should deduplicate same entity"
|
||||
|
||||
def test_keeps_higher_confidence(self) -> None:
|
||||
entities = [
|
||||
RawEntity(text="John", label="person", start=0, end=4, confidence=0.7),
|
||||
RawEntity(text="john", label="person", start=10, end=14, confidence=0.9),
|
||||
]
|
||||
result = dedupe_entities(entities)
|
||||
assert result[0].confidence == 0.9, "Should keep higher confidence"
|
||||
|
||||
def test_different_labels_not_deduped(self) -> None:
|
||||
entities = [
|
||||
RawEntity(text="Apple", label="product", start=0, end=5, confidence=0.8),
|
||||
RawEntity(text="Apple", label="company", start=10, end=15, confidence=0.9),
|
||||
]
|
||||
result = dedupe_entities(entities)
|
||||
assert len(result) == 2, "Different labels should not be deduped"
|
||||
|
||||
|
||||
class TestDropLowSignalEntities:
|
||||
"""Tests for low-signal entity filtering."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "label"),
|
||||
[
|
||||
("a", "person"),
|
||||
("to", "location"),
|
||||
("123", "product"),
|
||||
("shit", "topic"),
|
||||
],
|
||||
)
|
||||
def test_drops_low_signal_entities(self, text: str, label: str) -> None:
|
||||
entity = RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8)
|
||||
result = drop_low_signal_entities([entity], "some original text")
|
||||
assert len(result) == 0, f"'{text}' should be dropped"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "label"),
|
||||
[
|
||||
("John", "person"),
|
||||
("New York", "location"),
|
||||
],
|
||||
)
|
||||
def test_keeps_valid_entities(self, text: str, label: str) -> None:
|
||||
entity = RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8)
|
||||
result = drop_low_signal_entities([entity], "some original text")
|
||||
assert len(result) == 1, f"'{text}' should NOT be dropped"
|
||||
|
||||
|
||||
class TestInferDuration:
|
||||
"""Tests for duration inference."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "expected_label"),
|
||||
[
|
||||
("20 minutes", "duration"),
|
||||
("two weeks", "duration"),
|
||||
("3 hours", "duration"),
|
||||
("next week", "time_relative"),
|
||||
("tomorrow", "time_relative"),
|
||||
],
|
||||
)
|
||||
def test_infer_duration(self, text: str, expected_label: str) -> None:
|
||||
entity = RawEntity(text=text, label="time_relative", start=0, end=len(text), confidence=0.8)
|
||||
result = infer_duration([entity])
|
||||
assert result[0].label == expected_label, f"'{text}' should be labeled as {expected_label}"
|
||||
|
||||
|
||||
class TestMergeTimePhrases:
|
||||
"""Tests for time phrase merging."""
|
||||
|
||||
def test_merges_adjacent_time_entities(self) -> None:
|
||||
original_text = "last night we met"
|
||||
entities = [
|
||||
RawEntity(text="last", label="time", start=0, end=4, confidence=0.8),
|
||||
RawEntity(text="night", label="time", start=5, end=10, confidence=0.7),
|
||||
]
|
||||
result = merge_time_phrases(entities, original_text)
|
||||
|
||||
time_entities = [e for e in result if e.label == "time"]
|
||||
assert len(time_entities) == 1, "Should merge adjacent time entities"
|
||||
assert time_entities[0].text == "last night", "Merged text should be 'last night'"
|
||||
|
||||
def test_does_not_merge_non_adjacent(self) -> None:
|
||||
original_text = "Monday came and then Friday arrived"
|
||||
entities = [
|
||||
RawEntity(text="Monday", label="time", start=0, end=6, confidence=0.8),
|
||||
RawEntity(text="Friday", label="time", start=21, end=27, confidence=0.8),
|
||||
]
|
||||
result = merge_time_phrases(entities, original_text)
|
||||
|
||||
time_entities = [e for e in result if e.label == "time"]
|
||||
assert len(time_entities) == 2, "Should not merge non-adjacent time entities"
|
||||
|
||||
|
||||
class TestPostProcess:
|
||||
"""Tests for full post-processing pipeline."""
|
||||
|
||||
def test_full_pipeline(self) -> None:
|
||||
original_text = "John met shit in Paris for 20 minutes"
|
||||
entities = [
|
||||
RawEntity(text="John", label="person", start=0, end=4, confidence=0.9),
|
||||
RawEntity(text="shit", label="topic", start=9, end=13, confidence=0.5),
|
||||
RawEntity(text="Paris", label="location", start=17, end=22, confidence=0.95),
|
||||
RawEntity(text="20 minutes", label="time", start=27, end=37, confidence=0.8),
|
||||
]
|
||||
result = post_process(entities, original_text)
|
||||
|
||||
texts = {e.text for e in result}
|
||||
assert "shit" not in texts, "Profanity should be filtered"
|
||||
assert "John" in texts, "Valid person should remain"
|
||||
assert "Paris" in texts, "Valid location should remain"
|
||||
|
||||
duration_entities = [e for e in result if e.label == "duration"]
|
||||
assert len(duration_entities) == 1, "20 minutes should be labeled as duration"
|
||||
@@ -233,7 +233,11 @@ class TestHuggingFaceTokenGrpc:
|
||||
|
||||
assert delete_response.success is True, "delete token should succeed"
|
||||
|
||||
status = await _get_hf_status(hf_token_servicer, mock_grpc_context)
|
||||
with patch(
|
||||
"noteflow.application.services.huggingface.service.get_settings"
|
||||
) as mock_settings:
|
||||
mock_settings.return_value.diarization_hf_token = None
|
||||
status = await _get_hf_status(hf_token_servicer, mock_grpc_context)
|
||||
|
||||
assert status.is_configured is False, "token should be cleared after delete"
|
||||
assert status.is_validated is False, "validation should reset after delete"
|
||||
|
||||
35
typings/gliner/__init__.pyi
Normal file
35
typings/gliner/__init__.pyi
Normal file
@@ -0,0 +1,35 @@
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
class GLiNER:
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_id: str,
|
||||
revision: str | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
force_download: bool = False,
|
||||
proxies: dict[str, str] | None = None,
|
||||
resume_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: str | bool | None = None,
|
||||
map_location: str = "cpu",
|
||||
strict: bool = False,
|
||||
load_tokenizer: bool | None = None,
|
||||
resize_token_embeddings: bool | None = True,
|
||||
compile_torch_model: bool | None = False,
|
||||
load_onnx_model: bool | None = False,
|
||||
onnx_model_file: str | None = "model.onnx",
|
||||
max_length: int | None = None,
|
||||
max_width: int | None = None,
|
||||
post_fusion_schema: str | None = None,
|
||||
_attn_implementation: str | None = None,
|
||||
) -> Self: ...
|
||||
def predict_entities(
|
||||
self,
|
||||
text: str,
|
||||
labels: list[str],
|
||||
flat_ner: bool = True,
|
||||
threshold: float = 0.5,
|
||||
multi_label: bool = False,
|
||||
) -> list[dict[str, object]]: ...
|
||||
Reference in New Issue
Block a user