Update dependencies and enhance calendar integration features

- Updated `pyproject.toml` to include `authlib` as a dependency for OAuth integration.
- Modified `uv.lock` to reflect the addition of `authlib` and its versioning.
- Updated documentation to clarify existing components and their statuses in the calendar sync and webhook integration sprints.
- Refactored various methods and properties for improved clarity and consistency in the meeting and export services.
- Enhanced test coverage for export functionality, including assertions for PDF format.
- Updated client submodule reference to ensure alignment with the latest changes.
This commit is contained in:
2025-12-26 18:18:06 -05:00
parent ae594135f9
commit d658d60241
73 changed files with 7903 additions and 610 deletions

View File

@@ -0,0 +1,519 @@
# Sprint 4 gRPC Layer Validation Report
## Executive Summary
Sprint 4's proposed approach to implement `EntitiesMixin` for Named Entity Recognition aligns perfectly with NoteFlow's established gRPC architecture. The codebase already has:
- Proto definitions for entity extraction (`ExtractEntitiesRequest/Response` lines 593-630)
- ORM model (`NamedEntityModel`) and migration
- No NER service yet (but pattern is well-established)
**Key Finding**: The mixin pattern is mature and battle-tested across 7 existing mixins. Adding `EntitiesMixin` follows proven conventions with minimal surprises.
---
## 1. EXISTING MIXINS ARCHITECTURE
### Mixins Implemented
Located in `src/noteflow/grpc/_mixins/`:
1. **StreamingMixin** - Audio bidirectional streaming + ASR
2. **DiarizationMixin** - Speaker identification
3. **DiarizationJobMixin** - Background diarization job tracking
4. **MeetingMixin** - Meeting CRUD (create, get, list, delete, stop)
5. **SummarizationMixin** - Summary generation
6. **AnnotationMixin** - Annotation management (CRUD)
7. **ExportMixin** - Transcript export (Markdown/HTML/PDF)
All exported from `src/noteflow/grpc/_mixins/__init__.py`.
### NoteFlowServicer Composition
```python
class NoteFlowServicer(
    StreamingMixin,
    DiarizationMixin,
    DiarizationJobMixin,
    MeetingMixin,
    SummarizationMixin,
    AnnotationMixin,
    ExportMixin,
    noteflow_pb2_grpc.NoteFlowServiceServicer,
):
```
**Pattern**: Multiple inheritance ordering matters. Python's MRO searches bases left to right, so the first class that defines a method wins on conflicts; the generated `NoteFlowServiceServicer` base is listed last so every mixin overrides its default stubs.
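A minimal, NoteFlow-independent demonstration of that rule:
```python
class A:
    def who(self) -> str:
        return "A"

class B:
    def who(self) -> str:
        return "B"

class C(A, B):  # same shape as NoteFlowServicer's base list
    pass

assert C().who() == "A"  # leftmost base wins under Python's MRO
```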
---
## 2. MIXIN STRUCTURE PATTERN
### Standard Mixin Template (from AnnotationMixin)
**Class Definition**:
```python
class AnnotationMixin:
    """Docstring explaining what it does.

    Requires host to implement ServicerHost protocol.
    Works with both database and memory backends via RepositoryProvider.
    """
```
**Key Characteristics**:
- No `__init__` required (state lives in NoteFlowServicer)
- Type hint `self: ServicerHost` to access host attributes
- All service dependencies injected through protocol
- Works with both database and in-memory backends
### Method Signature Pattern
```python
async def RPCMethod(
    self: ServicerHost,
    request: noteflow_pb2.SomeRequest,
    context: grpc.aio.ServicerContext,
) -> noteflow_pb2.SomeResponse:
    """RPC implementation."""
    # Get meeting ID and validate
    meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
    # Use unified repository provider
    async with self._create_repository_provider() as repo:
        # Load data
        item = await repo.entities.get(item_id)
        if item is None:
            await abort_not_found(context, "Entity", request.id)
        # Process
        result = await repo.entities.update(item)
        await repo.commit()
    return converted_to_proto(result)
```
---
## 3. SERVICEHOST PROTOCOL (src/noteflow/grpc/_mixins/protocols.py)
The `ServicerHost` protocol defines the contract mixins expect from NoteFlowServicer:
### Configuration Attributes
- `_session_factory: async_sessionmaker[AsyncSession] | None`
- `_memory_store: MeetingStore | None`
- `_meetings_dir: Path`
- `_crypto: AesGcmCryptoBox`
### Service Engines
- `_asr_engine: FasterWhisperEngine | None`
- `_diarization_engine: DiarizationEngine | None`
- `_summarization_service: object | None` (generic type kept loose for flexibility)
**For EntitiesMixin**: Would need to add:
```python
_ner_service: NerService | None # Type will be added to protocol
```
### Key Methods Mixins Use
```python
def _use_database(self) -> bool
def _get_memory_store(self) -> MeetingStore
def _create_uow(self) -> SqlAlchemyUnitOfWork
def _create_repository_provider(self) -> UnitOfWork # Handles both DB and memory
def _next_segment_id(self, meeting_id: str, fallback: int = 0) -> int
def _init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None
def _cleanup_streaming_state(self, meeting_id: str) -> None
def _ensure_meeting_dek(self, meeting: Meeting) -> tuple[bytes, bytes, bool]
def _start_meeting_if_needed(self, meeting: Meeting) -> tuple[bool, str | None]
def _open_meeting_audio_writer(...) -> None
def _close_audio_writer(self, meeting_id: str) -> None
```
---
## 4. PROTO DEFINITIONS (noteflow.proto, lines 45-46, 593-630)
**Already defined for Sprint 4**:
### Service RPC
```protobuf
rpc ExtractEntities(ExtractEntitiesRequest) returns (ExtractEntitiesResponse);
```
### Request Message
```protobuf
message ExtractEntitiesRequest {
  string meeting_id = 1;
  bool force_refresh = 2;
}
```
### Response Messages
```protobuf
message ExtractedEntity {
  string id = 1;
  string text = 2;
  string category = 3;  // person, company, product, technical, acronym, location, date, other
  repeated int32 segment_ids = 4;
  float confidence = 5;
  bool is_pinned = 6;
}

message ExtractEntitiesResponse {
  repeated ExtractedEntity entities = 1;
  int32 total_count = 2;
  bool cached = 3;
}
```
**Status**: Proto defs complete, no regeneration needed unless modifying messages.
---
## 5. CONVERTER PATTERN (src/noteflow/grpc/_mixins/converters.py)
### Pattern Overview
All converters are **standalone functions** in a single module, not class methods:
```python
def parse_meeting_id_or_abort(meeting_id_str: str, context: object) -> MeetingId
def parse_annotation_id(annotation_id_str: str) -> AnnotationId
def word_to_proto(word: WordTiming) -> noteflow_pb2.WordTiming
def segment_to_final_segment_proto(segment: Segment) -> noteflow_pb2.FinalSegment
def meeting_to_proto(meeting: Meeting, ...) -> noteflow_pb2.Meeting
def summary_to_proto(summary: Summary) -> noteflow_pb2.Summary
def annotation_to_proto(annotation: Annotation) -> noteflow_pb2.Annotation
def create_segment_from_asr(meeting_id: MeetingId, ...) -> Segment
def proto_to_export_format(proto_format: int) -> ExportFormat
```
### For EntitiesMixin, Need to Add:
```python
def entity_to_proto(entity: Entity) -> noteflow_pb2.ExtractedEntity
def proto_to_entity(request: noteflow_pb2.ExtractEntitiesRequest, ...) -> Entity | None
def proto_entity_list_to_domain(protos: list[...]) -> list[Entity]
```
### Key Principles
1. **Stateless functions** - no dependencies on mixins
2. **Handle both directions** - proto→domain and domain→proto
3. **Consolidate patterns** - e.g., `parse_meeting_id_or_abort` replaces 11+ duplicated parse-and-abort call sites (a sketch follows this list)
4. **Use domain converters** - import from infrastructure when needed (e.g., `AsrConverter`)
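A minimal sketch of the consolidated parser from principle 3, assuming `MeetingId` can be constructed from a `uuid.UUID`; the abort body is illustrative, not the actual implementation:
```python
import uuid

import grpc

from noteflow.domain.value_objects import MeetingId

async def parse_meeting_id_or_abort(
    meeting_id_str: str,
    context: grpc.aio.ServicerContext,
) -> MeetingId:
    """Parse a meeting ID string, aborting the RPC with INVALID_ARGUMENT on failure."""
    try:
        return MeetingId(uuid.UUID(meeting_id_str))
    except ValueError:
        # context.abort() raises inside grpc.aio, so this never returns normally
        await context.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"Invalid meeting ID: {meeting_id_str!r}",
        )
        raise AssertionError("unreachable")  # keeps type checkers happy
```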
---
## 6. DEPENDENCY INJECTION PATTERN
### How Services Get Injected
**Example: SummarizationMixin**
```python
class SummarizationMixin:
    _summarization_service: SummarizationService | None

    async def GenerateSummary(self: ServicerHost, ...):
        # Access through self
        summary = await self._summarize_or_placeholder(meeting_id, segments, style_prompt)
```
**Pattern for NerService**:
1. Add a `ner_service` parameter to NoteFlowServicer's `__init__`:
```python
def __init__(
    self,
    ...,
    ner_service: NerService | None = None,
) -> None:
    self._ner_service = ner_service
```
2. Add to ServicerHost protocol:
```python
class ServicerHost(Protocol):
    _ner_service: NerService | None
```
3. Use in EntitiesMixin:
```python
async def ExtractEntities(self: ServicerHost, request, context):
    # Service may be None; handle gracefully or abort
    if self._ner_service is None:
        await abort_service_unavailable(context, "NER")
```
### Repository Provider Pattern
All mixins use `self._create_repository_provider()` to abstract DB vs. in-memory:
```python
async with self._create_repository_provider() as repo:
    entities = await repo.entities.get_by_meeting(meeting_id)  # Need to implement
```
---
## 7. HOW TO ADD ENTITIESMIXIN
### Step-by-Step Integration Points
**A. Create `src/noteflow/grpc/_mixins/entities.py`**
```python
class EntitiesMixin:
    """Mixin providing named entity extraction functionality."""

    async def ExtractEntities(
        self: ServicerHost,
        request: noteflow_pb2.ExtractEntitiesRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.ExtractEntitiesResponse:
        # Implementation
        ...
```
**B. Update `src/noteflow/grpc/_mixins/__init__.py`**
```python
from .entities import EntitiesMixin
__all__ = [..., "EntitiesMixin"]
```
**C. Update `src/noteflow/grpc/service.py`**
```python
from ._mixins import (
    ...,
    EntitiesMixin,  # Add
)

class NoteFlowServicer(
    ...,
    EntitiesMixin,  # Add before noteflow_pb2_grpc.NoteFlowServiceServicer
    noteflow_pb2_grpc.NoteFlowServiceServicer,
):
    def __init__(
        self,
        ...,
        ner_service: NerService | None = None,
    ):
        ...
        self._ner_service = ner_service
```
**D. Update `src/noteflow/grpc/_mixins/protocols.py`**
```python
class ServicerHost(Protocol):
    ...
    _ner_service: object | None  # or import NerService under TYPE_CHECKING
```
**E. Update `src/noteflow/grpc/_mixins/converters.py`**
Add entity conversion functions:
```python
def entity_to_proto(entity: Entity) -> noteflow_pb2.ExtractedEntity:
    """Convert domain Entity to protobuf."""
    return noteflow_pb2.ExtractedEntity(
        id=str(entity.id),
        text=entity.text,
        category=entity.category,
        segment_ids=entity.segment_ids,
        confidence=entity.confidence,
        is_pinned=entity.is_pinned,
    )
```
---
## 8. REPOSITORY INTEGRATION
### Current Repositories Available
Located in `src/noteflow/infrastructure/persistence/repositories/`:
- `annotation_repo.py` → `SqlAlchemyAnnotationRepository`
- `diarization_job_repo.py` → `SqlAlchemyDiarizationJobRepository`
- `meeting_repo.py` → `SqlAlchemyMeetingRepository`
- `preferences_repo.py` → `SqlAlchemyPreferencesRepository`
- `segment_repo.py` → `SqlAlchemySegmentRepository`
- `summary_repo.py` → `SqlAlchemySummaryRepository`
**Missing**: `entity_repo.py` (Sprint 4 task)
### UnitOfWork Integration
Current `SqlAlchemyUnitOfWork` (lines 87-127):
```python
@property
def annotations(self) -> SqlAlchemyAnnotationRepository: ...
@property
def meetings(self) -> SqlAlchemyMeetingRepository: ...
```
**Needed for Sprint 4**:
```python
@property
def entities(self) -> SqlAlchemyEntityRepository: ...
# And in __aenter__/__aexit__:
self._entities_repo = SqlAlchemyEntityRepository(self._session)
```
### Domain vs. In-Memory
NoteFlowServicer uses `_create_repository_provider()` which returns:
- `SqlAlchemyUnitOfWork` if database configured
- `MemoryUnitOfWork` if in-memory fallback
EntitiesMixin must handle both seamlessly (like AnnotationMixin does).
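A sketch of the in-memory side, assuming the repository contract named in this report (`get_by_meeting`, `save_batch`, `delete_by_meeting`); the real `MemoryUnitOfWork` wiring may differ:
```python
from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.domain.value_objects import MeetingId

class MemoryEntityRepository:
    """In-memory stand-in mirroring the planned SqlAlchemyEntityRepository API."""

    def __init__(self) -> None:
        self._by_meeting: dict[MeetingId, list[NamedEntity]] = {}

    async def get_by_meeting(self, meeting_id: MeetingId) -> list[NamedEntity]:
        return list(self._by_meeting.get(meeting_id, []))

    async def save_batch(self, entities: list[NamedEntity]) -> list[NamedEntity]:
        for entity in entities:
            if entity.meeting_id is not None:
                self._by_meeting.setdefault(entity.meeting_id, []).append(entity)
        return entities

    async def delete_by_meeting(self, meeting_id: MeetingId) -> int:
        return len(self._by_meeting.pop(meeting_id, []))
```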
---
## 9. ERROR HANDLING PATTERN
All mixins use helper functions from `src/noteflow/grpc/_mixins/errors.py`:
```python
async def abort_not_found(context, entity_type, entity_id) -> None
async def abort_invalid_argument(context, message) -> None
async def abort_database_required(context, feature_name) -> None
```
**Pattern in EntitiesMixin**:
```python
async def ExtractEntities(self: ServicerHost, request, context):
    meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
    async with self._create_repository_provider() as repo:
        meeting = await repo.meetings.get(meeting_id)
        if meeting is None:
            await abort_not_found(context, "Meeting", request.meeting_id)
        entities = await repo.entities.get_by_meeting(meeting_id)
        # Handle not found, validation errors, etc.
```
---
## 10. CACHING & REFRESH PATTERN
From proto: `force_refresh` flag in `ExtractEntitiesRequest`.
**SummarizationMixin pattern** (similar):
```python
existing = await repo.summaries.get_by_meeting(meeting_id)
if existing and not request.force_regenerate:
    return summary_to_proto(existing)
# Otherwise, generate fresh
```
**EntitiesMixin should follow**:
```python
async def ExtractEntities(self: ServicerHost, request, context):
    meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
    async with self._create_repository_provider() as repo:
        meeting = await repo.meetings.get(meeting_id)
        if meeting is None:
            await abort_not_found(context, "Meeting", request.meeting_id)
        # Check cache first
        cached_entities = await repo.entities.get_by_meeting(meeting_id)
        if cached_entities and not request.force_refresh:
            return noteflow_pb2.ExtractEntitiesResponse(
                entities=[entity_to_proto(e) for e in cached_entities],
                total_count=len(cached_entities),
                cached=True,
            )
        segments = await repo.segments.get_by_meeting(meeting_id)
    # Run the NER service outside the repo context, since it may be slow
    entities = await self._extract_entities(meeting_id, segments)
    # Persist in a fresh transaction
    async with self._create_repository_provider() as repo:
        saved = await repo.entities.save_batch(entities)
        await repo.commit()
    return noteflow_pb2.ExtractEntitiesResponse(
        entities=[entity_to_proto(e) for e in saved],
        total_count=len(saved),
        cached=False,
    )
```
---
## 11. KEY INSIGHTS & RISKS
### Strengths of Current Pattern
- **Modular**: Each mixin is self-contained, easy to test
- **Composable**: NoteFlowServicer inherits from all mixins without complexity
- **Database-agnostic**: `_create_repository_provider()` abstracts DB vs. memory
- **Proven**: 7 existing mixins provide battle-tested templates
- **Asyncio-native**: All methods are async, no blocking calls
- **Type-safe**: ServicerHost protocol ensures mixin contracts
### Potential Issues Sprint 4 Should Watch For
1. **Missing Domain Entity**: No `Entity` class in `src/noteflow/domain/entities/` yet
- Proto has `ExtractedEntity` (protobuf message)
- ORM has `NamedEntityModel` (persistence)
- Need to bridge: `Entity` domain class
2. **Missing Repository**: `entity_repo.py` doesn't exist
- Must implement CRUD methods used in mixin
- Must handle both database and in-memory backends
3. **NER Service Not Started**: Service layer (`NerService`) needs implementation
- How to get segments for extraction?
- What model/provider does NER use?
- Is it async? Does it block? (a hedged engine sketch follows this list)
4. **Proto Regeneration**: After any proto changes:
```bash
python -m grpc_tools.protoc -I src/noteflow/grpc/proto \
    --python_out=src/noteflow/grpc/proto \
    --grpc_python_out=src/noteflow/grpc/proto \
    src/noteflow/grpc/proto/noteflow.proto
```
Proto stubs are already generated; before regenerating, verify they match the current definitions at lines 45-46 and 593-630.
5. **Memory Store Fallback**: Annotation repository supports in-memory; entity repo must too
- See `src/noteflow/infrastructure/persistence/memory.py` for fallback pattern
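On point 3, a hedged sketch of what the blocking-engine wrapper could look like; it assumes spaCy's standard API, and the model name is a placeholder:
```python
import spacy

class NerEngine:
    """Lazily loads a spaCy pipeline; callers run extract() in an executor."""

    def __init__(self, model_name: str = "en_core_web_sm") -> None:
        self._model_name = model_name
        self._nlp: spacy.language.Language | None = None

    def is_ready(self) -> bool:
        return self._nlp is not None

    def extract(self, text: str) -> list[tuple[str, str]]:
        # spacy.load() and pipeline calls block, so keep this off the event loop
        if self._nlp is None:
            self._nlp = spacy.load(self._model_name)
        doc = self._nlp(text)
        return [(ent.text, ent.label_) for ent in doc.ents]
```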
---
## 12. PROTO SYNC CHECKLIST FOR SPRINT 4
- [x] `ExtractEntities` RPC defined (line 46)
- [x] `ExtractEntitiesRequest` message (lines 593-599)
- [x] `ExtractedEntity` message (lines 601-619)
- [x] `ExtractEntitiesResponse` message (lines 621-630)
- [ ] Regenerate Python stubs if proto modified
- [ ] Verify `noteflow_pb2.ExtractedEntity` exists in generated code
- [ ] Verify `noteflow_pb2_grpc.NoteFlowServiceServicer` includes an `ExtractEntities` method stub
---
## Summary Table
| Component | Status | File | Notes |
|-----------|--------|------|-------|
| **Mixin Class** | READY | `_mixins/entities.py` (new) | Follow AnnotationMixin pattern |
| **Service RPC** | ✅ EXISTS | `proto/noteflow.proto:46` | Already defined |
| **Request/Response** | ✅ EXISTS | `proto/noteflow.proto:593-630` | Complete messages |
| **ORM Model** | ✅ EXISTS | `persistence/models/entities/named_entity.py` | `NamedEntityModel` ready |
| **Migration** | ✅ EXISTS | `persistence/migrations/.../add_named_entities_table.py` | Table created |
| **Domain Entity** | ❌ MISSING | `domain/entities/` | Need `Entity` class |
| **Repository** | ❌ MISSING | `persistence/repositories/entity_repo.py` | Need CRUD implementation |
| **NER Service** | ❌ MISSING | `application/services/` | Need extraction logic |
| **Converter Functions** | PARTIAL | `_mixins/converters.py` | Need `entity_to_proto()` |
| **UnitOfWork** | READY | `persistence/unit_of_work.py` | Just add `entities` property |
| **ServicerHost** | READY | `_mixins/protocols.py` | Just add `_ner_service` field |
---
## VALIDATION CONCLUSION
**Status**: ✅ **SPRINT 4 APPROACH IS SOUND**
The proposed mixin-based architecture for EntitiesMixin:
1. Follows established patterns from 7 existing mixins
2. Leverages existing proto definitions
3. Has ORM and migration infrastructure ready
4. Fits cleanly into NoteFlowServicer composition
**Next Steps**:
1. Create domain `Entity` class (mirrors proto `ExtractedEntity`)
2. Implement `SqlAlchemyEntityRepository` with CRUD operations
3. Create `NerService` application service
4. Implement `EntitiesMixin.ExtractEntities()` RPC handler
5. Add converter functions `entity_to_proto()` etc.
6. Add `entities` property to `SqlAlchemyUnitOfWork`
7. Add `_ner_service` to `ServicerHost` protocol
No architectural surprises or deviations from established patterns.

View File

@@ -36,44 +36,56 @@ Automatically extract named entities (people, companies, products, locations) fr
## Current State Analysis
-### What Exists
+### What Already Exists
+#### Frontend Components (Complete)
| Component | Location | Status |
|-----------|----------|--------|
-| Entity UI | `client/src/components/EntityHighlightText.tsx` | Renders highlighted entities |
-| Entity Panel | `client/src/components/EntityManagementPanel.tsx` | Manual CRUD for entities |
-| Annotation RPC | `noteflow.proto` | `AddAnnotation` for manual entries |
+| Entity Highlight | `client/src/components/entity-highlight.tsx` | Renders inline entity highlights with tooltips (190 lines) |
+| Entity Panel | `client/src/components/entity-management-panel.tsx` | Manual CRUD with Sheet UI (388 lines) |
+| Entity Store | `client/src/lib/entity-store.ts` | Client-side state with observer pattern (169 lines) |
+| Entity Types | `client/src/types/entity.ts` | TypeScript types + color mappings (49 lines) |

-### Existing Persistence Infrastructure
-**Location**: `src/noteflow/infrastructure/persistence/`
+#### Backend Infrastructure (Complete)
+| Component | Location | Status |
+|-----------|----------|--------|
+| **ORM Model** | `infrastructure/persistence/models/entities/named_entity.py` | Complete - `NamedEntityModel` with all fields |
+| **Migration** | `migrations/versions/h2c3d4e5f6g7_add_named_entities_table.py` | Complete - table created with indices |
+| **Meeting Relationship** | `models/core/meeting.py:35-37` | Configured - `named_entities` with cascade delete |
+| **Proto RPC** | `noteflow.proto:46` | Defined - `rpc ExtractEntities(...)` |
+| **Proto Messages** | `noteflow.proto:593-630` | Defined - Request, Response, ExtractedEntity |
+| **Feature Flag** | `config/settings.py:223-226` | Defined - `NOTEFLOW_FEATURE_NER_ENABLED` |
+| **Dependency** | `pyproject.toml:63` | Declared - `spacy>=3.7` |
+
+> **Note**: The `named_entities` table and proto definitions are already implemented.
+> Sprint 0 completed this foundation work. No schema changes needed.

-**Decision Point**: Named entities can be stored in:
-1. **Existing `annotations` table** (simpler, less work):
-   - `annotation_type` = 'ENTITY_PERSON', 'ENTITY_COMPANY', etc.
-   - Reuses existing `AnnotationModel` and `SqlAlchemyAnnotationRepository`
-   - Schema: `docker/db/schema.sql:182-196`
-2. **New `named_entities` table** (recommended for richer features):
-   - Dedicated table with `category`, `confidence`, `is_pinned` columns
-   - Better for entity deduplication and cross-meeting entity linking
-   - Requires new Alembic migration
-**Recommendation**: Use option 2 (new table) for:
-- Entity confidence scores
-- Pinned/verified entity status
-- Future: cross-meeting entity linking and knowledge graph

-### Gap
-No backend NER capability:
-- No spaCy or other NER engine
-- No `ExtractEntities` RPC
-- No `NamedEntity` domain entity
-- **No `NerService` application layer** (violates hexagonal architecture)
-- Entities are purely client-side manual entries
-- No persistence for extracted entities
+### Gap: What Needs Implementation
+Backend NER processing pipeline:
+- No spaCy engine wrapper (`infrastructure/ner/`)
+- No `NamedEntity` domain entity class
+- No `NerPort` domain protocol
+- No `NerService` application layer
+- No `SqlAlchemyEntityRepository`
+- No `EntitiesMixin` gRPC handler
+- No Rust/Tauri command for extraction
+- Frontend not wired to backend extraction
+
+### Field Naming Alignment Warning
+> **⚠️ CRITICAL**: Frontend and backend use different field names for entity text:
+>
+> | Layer | Field | Location |
+> |-------|-------|----------|
+> | Frontend TS | `term` | `client/src/types/entity.ts:Entity.term` |
+> | Backend Proto | `text` | `noteflow.proto:ExtractedEntity.text` |
+> | ORM Model | `text` | `models/entities/named_entity.py:NamedEntityModel.text` |
+>
+> **Resolution required**: Either align frontend to use `text` or add mapping in Tauri commands.
+> The existing `entity-store.ts` uses `term` throughout—a rename may be needed for consistency.
---
@@ -120,29 +132,33 @@ All components use **Protocol-based dependency injection** for testability:
| `src/noteflow/application/services/ner_service.py` | **Application service layer** | ~150 |
| `src/noteflow/infrastructure/ner/__init__.py` | Module init | ~10 |
| `src/noteflow/infrastructure/ner/engine.py` | spaCy NER engine | ~120 |
-| `src/noteflow/infrastructure/persistence/models/core/named_entity.py` | ORM model | ~60 |
| `src/noteflow/infrastructure/persistence/repositories/entity_repo.py` | Entity repository | ~100 |
| `src/noteflow/infrastructure/converters/ner_converters.py` | ORM ↔ domain converters | ~50 |
| `src/noteflow/grpc/_mixins/entities.py` | gRPC mixin (calls NerService) | ~100 |
| `tests/application/test_ner_service.py` | Application layer tests | ~200 |
| `tests/infrastructure/ner/test_engine.py` | Engine unit tests | ~150 |
-| `client/src/hooks/useEntityExtraction.ts` | React hook with polling | ~80 |
-| `client/src/components/EntityExtractionPanel.tsx` | UI component | ~120 |
+| `client/src/hooks/use-entity-extraction.ts` | React hook with polling | ~80 |
| `client/src-tauri/src/commands/entities.rs` | Rust command | ~60 |

+> **Note**: ORM model (`NamedEntityModel`) already exists at `models/entities/named_entity.py`.
+> Frontend extraction panel can extend existing `entity-management-panel.tsx`.

### Files to Modify
| File | Change Type | Lines Est. |
|------|-------------|------------|
| `src/noteflow/grpc/service.py` | Add mixin + NerService initialization | +15 |
| `src/noteflow/grpc/server.py` | Initialize NerService | +10 |
-| `src/noteflow/infrastructure/persistence/models/__init__.py` | Export entity model | +2 |
-| `src/noteflow/infrastructure/persistence/models/core/__init__.py` | Export entity model | +2 |
+| `src/noteflow/grpc/_mixins/protocols.py` | Add `_ner_service` to ServicerHost | +5 |
| `src/noteflow/infrastructure/persistence/repositories/__init__.py` | Export entity repo | +2 |
-| `src/noteflow/infrastructure/persistence/unit_of_work.py` | Add entity repository | +10 |
-| `pyproject.toml` | Add spacy dependency | +2 |
+| `src/noteflow/infrastructure/persistence/unit_of_work.py` | Add entity repository property | +15 |
| `client/src/lib/tauri.ts` | Add extractEntities wrapper | +15 |
-| `client/src/pages/MeetingDetail.tsx` | Integrate EntityExtractionPanel | +20 |
+| `client/src/components/entity-management-panel.tsx` | Add extraction trigger button | +30 |
| `client/src-tauri/src/commands/mod.rs` | Export entities module | +2 |
| `client/src-tauri/src/lib.rs` | Register entity commands | +3 |
| `client/src-tauri/src/grpc/client.rs` | Add extract_entities method | +25 |

+> **Note**: `pyproject.toml` already has `spacy>=3.7`. Proto regeneration not needed.
---
@@ -493,59 +509,16 @@ class NerEngine:
### Task 5: Create Entity Persistence
-**File**: `src/noteflow/infrastructure/persistence/models/core/named_entity.py`
+> **✅ ORM Model Already Exists**: The `NamedEntityModel` is already implemented at
+> `infrastructure/persistence/models/entities/named_entity.py`. Do **not** recreate it.
+> Only the repository and converters need to be implemented.
+
+**Existing ORM** (reference only—do not modify):
+- Location: `src/noteflow/infrastructure/persistence/models/entities/named_entity.py`
+- Fields: `id`, `meeting_id`, `text`, `category`, `segment_ids`, `confidence`, `is_pinned`, `created_at`
+- Relationship: `meeting` → `MeetingModel.named_entities` (cascade delete configured)

-```python
-"""Named entity ORM model."""
-
-from __future__ import annotations
-
-from datetime import datetime
-from uuid import UUID
-
-from sqlalchemy import ARRAY, Boolean, DateTime, Float, ForeignKey, Integer, String
-from sqlalchemy.dialects.postgresql import UUID as PG_UUID
-from sqlalchemy.orm import Mapped, mapped_column, relationship
-
-from noteflow.infrastructure.persistence.models.base import Base
-
-class NamedEntityModel(Base):
-    """SQLAlchemy model for named entities."""
-
-    __tablename__ = "named_entities"
-    __table_args__ = {"schema": "noteflow"}
-
-    id: Mapped[UUID] = mapped_column(
-        PG_UUID(as_uuid=True),
-        primary_key=True,
-    )
-    meeting_id: Mapped[UUID] = mapped_column(
-        PG_UUID(as_uuid=True),
-        ForeignKey("noteflow.meetings.id", ondelete="CASCADE"),
-        nullable=False,
-        index=True,
-    )
-    text: Mapped[str] = mapped_column(String, nullable=False)
-    category: Mapped[str] = mapped_column(String(50), nullable=False)
-    segment_ids: Mapped[list[int]] = mapped_column(
-        ARRAY(Integer),
-        nullable=False,
-        default=list,
-    )
-    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
-    is_pinned: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True),
-        nullable=False,
-        default=datetime.utcnow,
-    )
-
-    # Relationship
-    meeting = relationship("MeetingModel", back_populates="named_entities")
-```

-**File**: `src/noteflow/infrastructure/persistence/repositories/entity_repo.py`
+**File to Create**: `src/noteflow/infrastructure/persistence/repositories/entity_repo.py`
```python
"""Named entity repository."""
@@ -1510,30 +1483,53 @@ pub async fn clear_entities(
| Pattern | Reference Location | Usage |
|---------|-------------------|-------|
-| ORM Model | `models/core/annotation.py` | Structure for `NamedEntityModel` |
-| Repository | `repositories/annotation_repo.py` | CRUD operations |
-| Unit of Work | `unit_of_work.py` | Add entity repository |
-| Converters | `converters/orm_converters.py` | ORM ↔ domain entity |
+| ORM Model | `models/entities/named_entity.py` | **Already exists** - no creation needed |
+| Repository | `repositories/annotation_repo.py` | Template for CRUD operations |
+| Base Repository | `repositories/_base.py` | Extend for helper methods |
+| Unit of Work | `unit_of_work.py` | Add entity repository property |
+| Converters | `converters/orm_converters.py` | Add `entity_to_domain()` method |

### Existing Entity UI Components

-**Location**: `client/src/components/EntityHighlightText.tsx`
+**Location**: `client/src/components/entity-highlight.tsx`

-Already renders entity highlights - connect to extracted entities.
+Already renders entity highlights with tooltips - connect to extracted entities.

-**Location**: `client/src/components/EntityManagementPanel.tsx`
+**Location**: `client/src/components/entity-management-panel.tsx`

-CRUD panel - extend to display auto-extracted entities.
+CRUD panel with Sheet UI - extend to display auto-extracted entities.

+**Location**: `client/src/lib/entity-store.ts`
+
+Client-side observer pattern store - wire to backend extraction results.
+
+> **Warning**: Color definitions are duplicated between `entity-highlight.tsx` (inline `categoryColors`)
+> and `types/entity.ts` (`ENTITY_CATEGORY_COLORS`). Use the shared constant from `types/entity.ts`.

### Application Service Pattern

-**Location**: `src/noteflow/application/services/diarization_service.py`
+**Location**: `src/noteflow/application/services/summarization_service.py`

Pattern for application service with:
-- Async locks for concurrency control
-- UoW factory for database operations
-- `run_in_executor` for CPU-bound work
-- Structured logging
+- Dataclass-based service with settings
+- Provider registration pattern (multi-backend support)
+- Lazy model loading via property getters
+- Callback-based persistence (not UoW-injected)

+**Location**: `src/noteflow/infrastructure/summarization/ollama_provider.py`
+
+Pattern for lazy model loading:
+- `self._client: T | None = None` initial state
+- `_get_client()` method for lazy initialization
+- `is_available` property for runtime checks
+- `asyncio.to_thread()` for CPU-bound inference
+
+**Location**: `src/noteflow/grpc/_mixins/diarization.py`
+
+Pattern for CPU-bound gRPC handlers:
+- `asyncio.Lock` for concurrency control
+- `loop.run_in_executor()` for blocking operations
+- Structured logging with meeting context
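The CPU-bound handler pattern from `diarization.py`, reduced to its skeleton (class and method names here are illustrative, not the actual mixin API):
```python
import asyncio

class CpuBoundHandlerSketch:
    """Serializes heavy work and keeps it off the event loop."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()

    async def run_blocking(self, fn, *args):
        async with self._lock:  # one heavy job at a time
            loop = asyncio.get_running_loop()
            # run_in_executor keeps the gRPC event loop responsive
            return await loop.run_in_executor(None, fn, *args)
```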
---

View File

@@ -180,7 +180,7 @@ The persistence layer is complete, but there's no application/domain layer to:
2. **Feature flag** verified:
```python
from noteflow.config.settings import get_settings
-assert get_settings().feature_flags.calendar_sync_enabled
+assert get_settings().feature_flags.calendar_enabled
```
**If any verification fails**: Complete Sprint 0 first.
@@ -1057,7 +1057,7 @@ class CalendarService:
            Authorization URL to redirect user to.
        """
        settings = get_settings()
-        if not settings.feature_flags.calendar_sync_enabled:
+        if not settings.feature_flags.calendar_enabled:
            msg = "Calendar sync is disabled by feature flag"
            raise RuntimeError(msg)

File diff suppressed because it is too large.

View File

@@ -31,6 +31,7 @@ dependencies = [
    # HTTP client for webhooks and integrations
    "httpx>=0.27",
    "weasyprint>=67.0",
+   "authlib>=1.6.6",
]
[project.optional-dependencies]

View File

@@ -0,0 +1,421 @@
"""Calendar integration service.
Orchestrates OAuth flow, token management, and calendar event fetching.
Uses existing Integration entity and IntegrationRepository for persistence.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from uuid import UUID
from noteflow.domain.entities.integration import Integration, IntegrationStatus, IntegrationType
from noteflow.domain.ports.calendar import CalendarEventInfo, OAuthConnectionInfo
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar import (
GoogleCalendarAdapter,
OAuthManager,
OutlookCalendarAdapter,
)
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
from noteflow.infrastructure.calendar.oauth_manager import OAuthError
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
if TYPE_CHECKING:
from collections.abc import Callable
from noteflow.config.settings import CalendarSettings
from noteflow.domain.ports.unit_of_work import UnitOfWork
logger = logging.getLogger(__name__)
class CalendarServiceError(Exception):
"""Calendar service operation failed."""
class CalendarService:
"""Calendar integration service.
Orchestrates OAuth flow and calendar event fetching. Uses:
- IntegrationRepository for Integration entity CRUD
- IntegrationRepository.get_secrets/set_secrets for encrypted token storage
- OAuthManager for PKCE OAuth flow
- GoogleCalendarAdapter/OutlookCalendarAdapter for provider APIs
"""
# Default workspace ID for single-user mode
DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000000")
def __init__(
self,
uow_factory: Callable[[], UnitOfWork],
settings: CalendarSettings,
oauth_manager: OAuthManager | None = None,
google_adapter: GoogleCalendarAdapter | None = None,
outlook_adapter: OutlookCalendarAdapter | None = None,
) -> None:
"""Initialize calendar service.
Args:
uow_factory: Factory function returning UnitOfWork instances.
settings: Calendar settings with OAuth credentials.
oauth_manager: Optional OAuth manager (created from settings if not provided).
google_adapter: Optional Google adapter (created if not provided).
outlook_adapter: Optional Outlook adapter (created if not provided).
"""
self._uow_factory = uow_factory
self._settings = settings
self._oauth_manager = oauth_manager or OAuthManager(settings)
self._google_adapter = google_adapter or GoogleCalendarAdapter()
self._outlook_adapter = outlook_adapter or OutlookCalendarAdapter()
async def initiate_oauth(
self,
provider: str,
redirect_uri: str | None = None,
) -> tuple[str, str]:
"""Start OAuth flow for a calendar provider.
Args:
provider: Provider name ('google' or 'outlook').
redirect_uri: Optional override for OAuth callback URI.
Returns:
Tuple of (authorization_url, state_token).
Raises:
CalendarServiceError: If provider is invalid or credentials not configured.
"""
oauth_provider = self._parse_provider(provider)
effective_redirect = redirect_uri or self._settings.redirect_uri
try:
auth_url, state = self._oauth_manager.initiate_auth(
provider=oauth_provider,
redirect_uri=effective_redirect,
)
logger.info("Initiated OAuth flow for provider=%s", provider)
return auth_url, state
except OAuthError as e:
raise CalendarServiceError(str(e)) from e
async def complete_oauth(
self,
provider: str,
code: str,
state: str,
) -> bool:
"""Complete OAuth flow and store tokens.
Creates or updates Integration entity with CALENDAR type and
stores encrypted tokens via IntegrationRepository.set_secrets.
Args:
provider: Provider name ('google' or 'outlook').
code: Authorization code from OAuth callback.
state: State parameter from OAuth callback.
Returns:
True if OAuth completed successfully.
Raises:
CalendarServiceError: If OAuth exchange fails.
"""
oauth_provider = self._parse_provider(provider)
try:
tokens = await self._oauth_manager.complete_auth(
provider=oauth_provider,
code=code,
state=state,
)
except OAuthError as e:
raise CalendarServiceError(f"OAuth failed: {e}") from e
# Get user email from provider
try:
email = await self._get_user_email(oauth_provider, tokens.access_token)
except (GoogleCalendarError, OutlookCalendarError) as e:
raise CalendarServiceError(f"Failed to get user email: {e}") from e
# Persist integration and tokens
async with self._uow_factory() as uow:
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.CALENDAR.value,
)
if integration is None:
integration = Integration.create(
workspace_id=self.DEFAULT_WORKSPACE_ID,
name=f"{provider.title()} Calendar",
integration_type=IntegrationType.CALENDAR,
config={"provider": provider},
)
await uow.integrations.create(integration)
else:
integration.config["provider"] = provider
integration.connect(provider_email=email)
await uow.integrations.update(integration)
# Store encrypted tokens
await uow.integrations.set_secrets(
integration_id=integration.id,
secrets=tokens.to_secrets_dict(),
)
await uow.commit()
logger.info("Completed OAuth for provider=%s, email=%s", provider, email)
return True
async def get_connection_status(self, provider: str) -> OAuthConnectionInfo:
"""Get OAuth connection status for a provider.
Args:
provider: Provider name ('google' or 'outlook').
Returns:
OAuthConnectionInfo with status and details.
"""
async with self._uow_factory() as uow:
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.CALENDAR.value,
)
if integration is None:
return OAuthConnectionInfo(
provider=provider,
status="disconnected",
)
# Check token expiry
secrets = await uow.integrations.get_secrets(integration.id)
expires_at = None
status = self._map_integration_status(integration.status)
if secrets and integration.is_connected:
try:
tokens = OAuthTokens.from_secrets_dict(secrets)
expires_at = tokens.expires_at
if tokens.is_expired():
status = "expired"
except (KeyError, ValueError):
status = "error"
return OAuthConnectionInfo(
provider=provider,
status=status,
email=integration.provider_email,
expires_at=expires_at,
error_message=integration.error_message,
)
async def disconnect(self, provider: str) -> bool:
"""Disconnect OAuth integration and revoke tokens.
Args:
provider: Provider name ('google' or 'outlook').
Returns:
True if disconnected successfully.
"""
oauth_provider = self._parse_provider(provider)
async with self._uow_factory() as uow:
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.CALENDAR.value,
)
if integration is None:
return False
# Get tokens before deletion for revocation
secrets = await uow.integrations.get_secrets(integration.id)
access_token = secrets.get("access_token") if secrets else None
# Delete integration (cascades to secrets)
await uow.integrations.delete(integration.id)
await uow.commit()
# Revoke tokens with provider (best-effort)
if access_token:
try:
await self._oauth_manager.revoke_tokens(oauth_provider, access_token)
except OAuthError as e:
logger.warning(
"Failed to revoke tokens for provider=%s: %s",
provider,
e,
)
logger.info("Disconnected provider=%s", provider)
return True
async def list_calendar_events(
self,
provider: str | None = None,
hours_ahead: int | None = None,
limit: int | None = None,
) -> list[CalendarEventInfo]:
"""Fetch calendar events from connected providers.
Args:
provider: Optional provider to fetch from (fetches all if None).
hours_ahead: Hours to look ahead (defaults to settings).
limit: Maximum events per provider (defaults to settings).
Returns:
List of calendar events sorted by start time.
Raises:
CalendarServiceError: If no providers connected or fetch fails.
"""
effective_hours = hours_ahead or self._settings.sync_hours_ahead
effective_limit = limit or self._settings.max_events
events: list[CalendarEventInfo] = []
if provider:
provider_events = await self._fetch_provider_events(
provider=provider,
hours_ahead=effective_hours,
limit=effective_limit,
)
events.extend(provider_events)
else:
# Fetch from all connected providers
for p in ["google", "outlook"]:
try:
provider_events = await self._fetch_provider_events(
provider=p,
hours_ahead=effective_hours,
limit=effective_limit,
)
events.extend(provider_events)
except CalendarServiceError:
continue # Skip disconnected providers
# Sort by start time
events.sort(key=lambda e: e.start_time)
return events
async def _fetch_provider_events(
self,
provider: str,
hours_ahead: int,
limit: int,
) -> list[CalendarEventInfo]:
"""Fetch events from a specific provider with token refresh."""
oauth_provider = self._parse_provider(provider)
async with self._uow_factory() as uow:
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.CALENDAR.value,
)
if integration is None or not integration.is_connected:
raise CalendarServiceError(f"Provider {provider} not connected")
secrets = await uow.integrations.get_secrets(integration.id)
if not secrets:
raise CalendarServiceError(f"No tokens for provider {provider}")
try:
tokens = OAuthTokens.from_secrets_dict(secrets)
except (KeyError, ValueError) as e:
raise CalendarServiceError(f"Invalid tokens: {e}") from e
# Refresh if expired
if tokens.is_expired() and tokens.refresh_token:
try:
tokens = await self._oauth_manager.refresh_tokens(
provider=oauth_provider,
refresh_token=tokens.refresh_token,
)
await uow.integrations.set_secrets(
integration_id=integration.id,
secrets=tokens.to_secrets_dict(),
)
await uow.commit()
except OAuthError as e:
integration.mark_error(f"Token refresh failed: {e}")
await uow.integrations.update(integration)
await uow.commit()
raise CalendarServiceError(f"Token refresh failed: {e}") from e
# Fetch events
try:
events = await self._fetch_events(
oauth_provider,
tokens.access_token,
hours_ahead,
limit,
)
integration.record_sync()
await uow.integrations.update(integration)
await uow.commit()
return events
except (GoogleCalendarError, OutlookCalendarError) as e:
integration.mark_error(str(e))
await uow.integrations.update(integration)
await uow.commit()
raise CalendarServiceError(str(e)) from e
async def _fetch_events(
self,
provider: OAuthProvider,
access_token: str,
hours_ahead: int,
limit: int,
) -> list[CalendarEventInfo]:
"""Fetch events from provider API."""
adapter = self._get_adapter(provider)
return await adapter.list_events(
access_token=access_token,
hours_ahead=hours_ahead,
limit=limit,
)
async def _get_user_email(
self,
provider: OAuthProvider,
access_token: str,
) -> str:
"""Get user email from provider API."""
adapter = self._get_adapter(provider)
return await adapter.get_user_email(access_token)
def _get_adapter(
self,
provider: OAuthProvider,
) -> GoogleCalendarAdapter | OutlookCalendarAdapter:
"""Get calendar adapter for provider."""
if provider == OAuthProvider.GOOGLE:
return self._google_adapter
return self._outlook_adapter
@staticmethod
def _parse_provider(provider: str) -> OAuthProvider:
"""Parse and validate provider string."""
try:
return OAuthProvider(provider.lower())
except ValueError as e:
raise CalendarServiceError(
f"Invalid provider: {provider}. Must be 'google' or 'outlook'."
) from e
@staticmethod
def _map_integration_status(status: IntegrationStatus) -> str:
"""Map IntegrationStatus to connection status string."""
mapping = {
IntegrationStatus.CONNECTED: "connected",
IntegrationStatus.DISCONNECTED: "disconnected",
IntegrationStatus.ERROR: "error",
}
return mapping.get(status, "disconnected")
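A usage sketch for the service above, assuming the `get_calendar_settings()` accessor added in this commit and a UnitOfWork factory wired elsewhere (the `uow_factory` argument here is a placeholder):
```python
from noteflow.config.settings import get_calendar_settings

async def start_google_connect(uow_factory) -> str:
    service = CalendarService(
        uow_factory=uow_factory,
        settings=get_calendar_settings(),
    )
    auth_url, _state = await service.initiate_oauth("google")
    return auth_url  # open in a browser; complete_oauth() handles the callback
```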

View File

@@ -84,13 +84,14 @@ class ExportService:
            ValueError: If meeting not found.
        """
        async with self._uow:
-            meeting = await self._uow.meetings.get(meeting_id)
-            if meeting is None:
-                raise ValueError(f"Meeting {meeting_id} not found")
+            found_meeting = await self._uow.meetings.get(meeting_id)
+            if not found_meeting:
+                msg = f"Meeting {meeting_id} not found"
+                raise ValueError(msg)
            segments = await self._uow.segments.get_by_meeting(meeting_id)
            exporter = self._get_exporter(fmt)
-            return exporter.export(meeting, segments)
+            return exporter.export(found_meeting, segments)

    async def export_to_file(
        self,
@@ -123,7 +124,10 @@ class ExportService:
            output_path = output_path.with_suffix(exporter.file_extension)
        output_path.parent.mkdir(parents=True, exist_ok=True)
-        output_path.write_text(content, encoding="utf-8")
+        if isinstance(content, bytes):
+            output_path.write_bytes(content)
+        else:
+            output_path.write_text(content, encoding="utf-8")
        return output_path

    def _infer_format_from_extension(self, extension: str) -> ExportFormat:
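The commit message mentions new assertions for PDF format; a hedged test sketch for the bytes branch (fixture names and the `export_to_file` argument order are assumptions, and `ExportFormat` comes from the export module):
```python
async def test_export_pdf_writes_bytes(export_service, meeting_id, tmp_path):
    output = await export_service.export_to_file(
        meeting_id,
        tmp_path / "transcript",
        ExportFormat.PDF,
    )
    assert output.suffix == ".pdf"
    assert output.read_bytes().startswith(b"%PDF")  # PDF magic bytes
```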

View File

@@ -90,13 +90,13 @@ class MeetingService:
"""List meetings with optional filtering.
Args:
states: Optional list of states to filter by.
limit: Maximum number of meetings to return.
offset: Number of meetings to skip.
sort_desc: Sort by created_at descending if True.
states: Filter to specific meeting states (None = all).
limit: Maximum results to return.
offset: Number of results to skip for pagination.
sort_desc: If True, newest meetings first.
Returns:
Tuple of (meetings list, total count).
Tuple of (meeting sequence, total matching count).
"""
async with self._uow:
return await self._uow.meetings.list_all(

View File

@@ -0,0 +1,263 @@
"""Named Entity Recognition application service.
Orchestrates NER extraction, caching, and persistence following hexagonal architecture:
gRPC mixin → NerService (application) → NerEngine (infrastructure)
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING
from noteflow.config.settings import get_settings
from noteflow.domain.entities.named_entity import NamedEntity
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from uuid import UUID
from noteflow.domain.ports.ner import NerPort
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
UoWFactory = Callable[[], SqlAlchemyUnitOfWork]
logger = logging.getLogger(__name__)
@dataclass
class ExtractionResult:
"""Result of entity extraction.
Attributes:
entities: List of extracted named entities.
cached: Whether the result came from cache.
total_count: Total number of entities.
"""
entities: Sequence[NamedEntity]
cached: bool
total_count: int
class NerService:
"""Application service for Named Entity Recognition.
Provides a clean interface for NER operations, abstracting away
the infrastructure details (spaCy engine, database persistence).
Orchestrates:
- Feature flag checking
- Cache lookup (return existing entities if available)
- Extraction via NerPort adapter
- Persistence via UnitOfWork
"""
def __init__(
self,
ner_engine: NerPort,
uow_factory: UoWFactory,
) -> None:
"""Initialize NER service.
Args:
ner_engine: NER engine implementation (infrastructure adapter).
uow_factory: Factory for creating Unit of Work instances.
"""
self._ner_engine = ner_engine
self._uow_factory = uow_factory
self._extraction_lock = asyncio.Lock()
self._model_load_lock = asyncio.Lock()
async def extract_entities(
self,
meeting_id: MeetingId,
force_refresh: bool = False,
) -> ExtractionResult:
"""Extract named entities from a meeting's transcript.
Checks for cached results first, unless force_refresh is True.
Persists new extractions to the database.
Args:
meeting_id: Meeting to extract entities from.
force_refresh: If True, re-extract even if cached results exist.
Returns:
ExtractionResult with entities and metadata.
Raises:
ValueError: If meeting not found or has no segments.
RuntimeError: If NER feature is disabled.
"""
self._check_feature_enabled()
# Check cache and get segments in one transaction
cached_or_segments = await self._get_cached_or_segments(meeting_id, force_refresh)
if isinstance(cached_or_segments, ExtractionResult):
return cached_or_segments
segments = cached_or_segments
# Extract and persist
entities = await self._extract_with_lock(segments)
for entity in entities:
entity.meeting_id = meeting_id
await self._persist_entities(meeting_id, entities, force_refresh)
logger.info(
"Extracted %d entities from meeting %s (%d segments)",
len(entities),
meeting_id,
len(segments),
)
return ExtractionResult(entities=entities, cached=False, total_count=len(entities))
def _check_feature_enabled(self) -> None:
"""Raise if NER feature is disabled."""
settings = get_settings()
if not settings.feature_flags.ner_enabled:
raise RuntimeError("NER extraction is disabled by feature flag")
async def _get_cached_or_segments(
self,
meeting_id: MeetingId,
force_refresh: bool,
) -> ExtractionResult | list[tuple[int, str]]:
"""Check cache and return cached result or segments for extraction."""
async with self._uow_factory() as uow:
if not force_refresh:
cached = await uow.entities.get_by_meeting(meeting_id)
if cached:
logger.debug("Returning %d cached entities for meeting %s", len(cached), meeting_id)
return ExtractionResult(entities=cached, cached=True, total_count=len(cached))
meeting = await uow.meetings.get(meeting_id)
if not meeting:
raise ValueError(f"Meeting {meeting_id} not found")
# Load segments separately (not eagerly loaded on meeting)
segments = await uow.segments.get_by_meeting(meeting_id)
if not segments:
logger.debug("Meeting %s has no segments", meeting_id)
return ExtractionResult(entities=[], cached=False, total_count=0)
return [(s.segment_id, s.text) for s in segments]
async def _persist_entities(
self,
meeting_id: MeetingId,
entities: list[NamedEntity],
force_refresh: bool,
) -> None:
"""Persist extracted entities to database."""
async with self._uow_factory() as uow:
if force_refresh:
await uow.entities.delete_by_meeting(meeting_id)
await uow.entities.save_batch(entities)
await uow.commit()
async def _extract_with_lock(
self,
segments: list[tuple[int, str]],
) -> list[NamedEntity]:
"""Extract entities with concurrency control.
Ensures only one extraction runs at a time and handles
lazy model loading safely.
Args:
segments: List of (segment_id, text) tuples.
Returns:
List of extracted entities.
"""
async with self._extraction_lock:
# Ensure model is loaded (thread-safe)
if not self._ner_engine.is_ready():
async with self._model_load_lock:
if not self._ner_engine.is_ready():
# Warm up model with a simple extraction
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: self._ner_engine.extract("warm up"),
)
# Extract entities in executor (CPU-bound)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
self._ner_engine.extract_from_segments,
segments,
)
async def get_entities(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
"""Get cached entities for a meeting (no extraction).
Args:
meeting_id: Meeting UUID.
Returns:
List of entities (empty if not extracted yet).
"""
async with self._uow_factory() as uow:
return await uow.entities.get_by_meeting(meeting_id)
async def pin_entity(self, entity_id: UUID, is_pinned: bool = True) -> bool:
"""Mark an entity as user-verified (pinned).
Args:
entity_id: Entity UUID.
is_pinned: New pinned status.
Returns:
True if entity was found and updated.
"""
async with self._uow_factory() as uow:
result = await uow.entities.update_pinned(entity_id, is_pinned)
if result:
await uow.commit()
return result
async def clear_entities(self, meeting_id: MeetingId) -> int:
"""Delete all entities for a meeting.
Args:
meeting_id: Meeting UUID.
Returns:
Number of deleted entities.
"""
async with self._uow_factory() as uow:
count = await uow.entities.delete_by_meeting(meeting_id)
await uow.commit()
logger.info("Cleared %d entities for meeting %s", count, meeting_id)
return count
async def has_entities(self, meeting_id: MeetingId) -> bool:
"""Check if a meeting has extracted entities.
Args:
meeting_id: Meeting UUID.
Returns:
True if at least one entity exists.
"""
async with self._uow_factory() as uow:
return await uow.entities.exists_for_meeting(meeting_id)
def is_engine_ready(self) -> bool:
"""Check if NER engine model is loaded and ready for extraction.
Verifies both that the spaCy model is loaded and that the NER
feature is enabled via configuration.
Returns:
True if the engine is ready and NER is enabled.
"""
settings = get_settings()
return settings.feature_flags.ner_enabled and self._ner_engine.is_ready()
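A wiring sketch for the service above, assuming an engine implementing `NerPort` and the same UoW factory shape (both arguments are placeholders here):
```python
async def extract_for_meeting(engine, uow_factory, meeting_id) -> None:
    service = NerService(ner_engine=engine, uow_factory=uow_factory)
    result = await service.extract_entities(meeting_id, force_refresh=False)
    for entity in result.entities:
        print(entity.text, entity.category, entity.confidence)
```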

View File

@@ -79,12 +79,24 @@ class DownloadReport:
    @property
    def success_count(self) -> int:
-        """Count of successful downloads."""
+        """Count of successful downloads.
+
+        Returns:
+            Number of results with success=True, 0 if no results.
+        """
+        if not self.results:
+            return 0
        return sum(1 for r in self.results if r.success)

    @property
    def failure_count(self) -> int:
-        """Count of failed downloads."""
+        """Count of failed downloads.
+
+        Returns:
+            Number of results with success=False, 0 if no results.
+        """
+        if not self.results:
+            return 0
        return sum(1 for r in self.results if not r.success)

View File

@@ -234,6 +234,68 @@ class FeatureFlags(BaseSettings):
]
class CalendarSettings(BaseSettings):
    """Calendar integration OAuth and sync settings.

    Environment variables use NOTEFLOW_CALENDAR_ prefix:
        NOTEFLOW_CALENDAR_GOOGLE_CLIENT_ID: Google OAuth client ID
        NOTEFLOW_CALENDAR_GOOGLE_CLIENT_SECRET: Google OAuth client secret
        NOTEFLOW_CALENDAR_OUTLOOK_CLIENT_ID: Microsoft OAuth client ID
        NOTEFLOW_CALENDAR_OUTLOOK_CLIENT_SECRET: Microsoft OAuth client secret
        NOTEFLOW_CALENDAR_REDIRECT_URI: OAuth callback URI
        NOTEFLOW_CALENDAR_SYNC_HOURS_AHEAD: Hours to look ahead for events
        NOTEFLOW_CALENDAR_MAX_EVENTS: Maximum events to fetch
        NOTEFLOW_CALENDAR_SYNC_INTERVAL_MINUTES: Sync interval in minutes
    """

    model_config = SettingsConfigDict(
        env_prefix="NOTEFLOW_CALENDAR_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Google OAuth
    google_client_id: Annotated[
        str,
        Field(default="", description="Google OAuth client ID"),
    ]
    google_client_secret: Annotated[
        str,
        Field(default="", description="Google OAuth client secret"),
    ]
    # Microsoft OAuth
    outlook_client_id: Annotated[
        str,
        Field(default="", description="Microsoft OAuth client ID"),
    ]
    outlook_client_secret: Annotated[
        str,
        Field(default="", description="Microsoft OAuth client secret"),
    ]
    # OAuth redirect
    redirect_uri: Annotated[
        str,
        Field(default="noteflow://oauth/callback", description="OAuth callback URI"),
    ]
    # Sync settings
    sync_hours_ahead: Annotated[
        int,
        Field(default=24, ge=1, le=168, description="Hours to look ahead for events"),
    ]
    max_events: Annotated[
        int,
        Field(default=20, ge=1, le=100, description="Maximum events to fetch"),
    ]
    sync_interval_minutes: Annotated[
        int,
        Field(default=15, ge=1, le=1440, description="Sync interval in minutes"),
    ]


class Settings(TriggerSettings):
    """Application settings loaded from environment variables.
@@ -347,6 +409,14 @@ class Settings(TriggerSettings):
        """Return database URL as string."""
        return str(self.database_url)

    @property
    def feature_flags(self) -> FeatureFlags:
        """Return cached feature flags instance.

        Provides convenient access to feature flags via settings object.
        """
        return get_feature_flags()


def _load_settings() -> Settings:
    """Load settings from environment.
@@ -398,3 +468,18 @@ def get_feature_flags() -> FeatureFlags:
        Cached FeatureFlags instance loaded from environment.
    """
    return _load_feature_flags()


def _load_calendar_settings() -> CalendarSettings:
    """Load calendar settings from environment."""
    return CalendarSettings.model_validate({})


@lru_cache
def get_calendar_settings() -> CalendarSettings:
    """Get cached calendar settings instance.

    Returns:
        Cached CalendarSettings instance loaded from environment.
    """
    return _load_calendar_settings()
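Because `get_calendar_settings()` is wrapped in `lru_cache`, environment overrides must be in place before the first call; a quick sketch:
```python
import os

os.environ["NOTEFLOW_CALENDAR_SYNC_HOURS_AHEAD"] = "48"

settings = get_calendar_settings()  # cached from here on
assert settings.sync_hours_ahead == 48
assert settings.redirect_uri == "noteflow://oauth/callback"  # default value
```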

View File

@@ -1,15 +1,22 @@
"""Domain entities for NoteFlow."""
from .annotation import Annotation
from .integration import Integration, IntegrationStatus, IntegrationType
from .meeting import Meeting
from .named_entity import EntityCategory, NamedEntity
from .segment import Segment, WordTiming
from .summary import ActionItem, KeyPoint, Summary
__all__ = [
"ActionItem",
"Annotation",
"EntityCategory",
"Integration",
"IntegrationStatus",
"IntegrationType",
"KeyPoint",
"Meeting",
"NamedEntity",
"Segment",
"Summary",
"WordTiming",

View File

@@ -0,0 +1,123 @@
"""Integration entity for external service connections."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from enum import StrEnum
from uuid import UUID, uuid4
from noteflow.domain.utils.time import utc_now
class IntegrationType(StrEnum):
"""Types of integrations supported."""
AUTH = "auth"
EMAIL = "email"
CALENDAR = "calendar"
PKM = "pkm"
CUSTOM = "custom"
class IntegrationStatus(StrEnum):
"""Status of an integration connection."""
DISCONNECTED = "disconnected"
CONNECTED = "connected"
ERROR = "error"
@dataclass
class Integration:
"""External service integration entity.
Represents a connection to an external service like Google Calendar,
Microsoft Outlook, Notion, etc.
"""
id: UUID
workspace_id: UUID
name: str
type: IntegrationType
status: IntegrationStatus = IntegrationStatus.DISCONNECTED
config: dict[str, object] = field(default_factory=dict)
last_sync: datetime | None = None
error_message: str | None = None
created_at: datetime = field(default_factory=utc_now)
updated_at: datetime = field(default_factory=utc_now)
@classmethod
def create(
cls,
workspace_id: UUID,
name: str,
integration_type: IntegrationType,
config: dict[str, object] | None = None,
) -> Integration:
"""Create a new integration.
Args:
workspace_id: ID of the workspace this integration belongs to.
name: Display name for the integration.
integration_type: Type of integration.
config: Optional configuration dictionary.
Returns:
New Integration instance.
"""
now = utc_now()
return cls(
id=uuid4(),
workspace_id=workspace_id,
name=name,
type=integration_type,
status=IntegrationStatus.DISCONNECTED,
config=config or {},
created_at=now,
updated_at=now,
)
def connect(self, provider_email: str | None = None) -> None:
"""Mark integration as connected.
Args:
provider_email: Optional email of the authenticated account.
"""
self.status = IntegrationStatus.CONNECTED
self.error_message = None
self.updated_at = utc_now()
if provider_email:
self.config["provider_email"] = provider_email
def disconnect(self) -> None:
"""Mark integration as disconnected."""
self.status = IntegrationStatus.DISCONNECTED
self.error_message = None
self.updated_at = utc_now()
def mark_error(self, message: str) -> None:
"""Mark integration as having an error.
Args:
message: Error message describing the issue.
"""
self.status = IntegrationStatus.ERROR
self.error_message = message
self.updated_at = utc_now()
def record_sync(self) -> None:
"""Record a successful sync timestamp."""
self.last_sync = utc_now()
self.updated_at = self.last_sync
@property
def is_connected(self) -> bool:
"""Check if integration is connected."""
return self.status == IntegrationStatus.CONNECTED
@property
def provider_email(self) -> str | None:
"""Get the provider email if available."""
email = self.config.get("provider_email")
return str(email) if email else None
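A lifecycle sketch of the entity above (the workspace ID is a placeholder):
```python
from uuid import uuid4

integration = Integration.create(
    workspace_id=uuid4(),
    name="Google Calendar",
    integration_type=IntegrationType.CALENDAR,
)
assert not integration.is_connected

integration.connect(provider_email="user@example.com")
assert integration.is_connected
assert integration.provider_email == "user@example.com"

integration.mark_error("token refresh failed")
assert integration.status is IntegrationStatus.ERROR
```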

View File

@@ -197,7 +197,7 @@ class Meeting:
"""Concatenate all segment text."""
return " ".join(s.text for s in self.segments)
def is_active(self) -> bool:
def is_in_active_state(self) -> bool:
"""Check if meeting is in an active state (created or recording).
Note: STOPPING is not considered active as it's transitioning to stopped.

View File

@@ -0,0 +1,146 @@
"""Named entity domain entity for NER extraction."""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
if TYPE_CHECKING:
from noteflow.domain.value_objects import MeetingId
class EntityCategory(Enum):
"""Categories for named entities.
Maps to spaCy entity types:
PERSON -> PERSON
ORG -> COMPANY
GPE/LOC/FAC -> LOCATION
PRODUCT/WORK_OF_ART -> PRODUCT
DATE/TIME -> DATE
Note: TECHNICAL and ACRONYM are placeholders for future custom pattern
matching (not currently mapped from spaCy's default NER model).
"""
PERSON = "person"
COMPANY = "company"
PRODUCT = "product"
TECHNICAL = "technical" # Future: custom pattern matching
ACRONYM = "acronym" # Future: custom pattern matching
LOCATION = "location"
DATE = "date"
OTHER = "other"
@classmethod
def from_string(cls, value: str) -> EntityCategory:
"""Convert string to EntityCategory.
Args:
value: Category string value.
Returns:
Corresponding EntityCategory.
Raises:
ValueError: If value is not a valid category.
"""
try:
return cls(value.lower())
except ValueError as e:
raise ValueError(f"Invalid entity category: {value}") from e
@dataclass
class NamedEntity:
"""A named entity extracted from a meeting transcript.
Represents a person, company, product, location, or other notable term
identified in the meeting transcript via NER processing.
"""
id: UUID = field(default_factory=uuid4)
meeting_id: MeetingId | None = None
text: str = ""
normalized_text: str = ""
category: EntityCategory = EntityCategory.OTHER
segment_ids: list[int] = field(default_factory=list)
confidence: float = 0.0
is_pinned: bool = False
# Database primary key (set after persistence)
db_id: UUID | None = None
def __post_init__(self) -> None:
"""Validate entity data and compute normalized_text."""
if not 0.0 <= self.confidence <= 1.0:
raise ValueError(f"confidence must be between 0 and 1, got {self.confidence}")
# Auto-compute normalized_text if not provided
if self.text and not self.normalized_text:
self.normalized_text = self.text.lower().strip()
@classmethod
def create(
cls,
text: str,
category: EntityCategory,
segment_ids: list[int],
confidence: float,
meeting_id: MeetingId | None = None,
) -> NamedEntity:
"""Create a new named entity with validation and normalization.
Factory method that handles text normalization, segment deduplication,
and confidence validation before entity construction.
Args:
text: The entity text as it appears in transcript.
category: Classification category.
segment_ids: Segments where entity appears (will be deduplicated and sorted).
confidence: Extraction confidence (0.0-1.0).
meeting_id: Optional meeting association.
Returns:
New NamedEntity instance with normalized fields.
Raises:
ValueError: If text is empty or confidence is out of range.
"""
# Validate required text
stripped_text = text.strip()
if not stripped_text:
raise ValueError("Entity text cannot be empty")
# Normalize and deduplicate segment_ids
unique_segments = sorted(set(segment_ids))
return cls(
text=stripped_text,
normalized_text=stripped_text.lower(),
category=category,
segment_ids=unique_segments,
confidence=confidence,
meeting_id=meeting_id,
)
@property
def occurrence_count(self) -> int:
"""Number of unique segments where this entity appears.
Returns:
Count of distinct segment IDs. Returns 0 if segment_ids is empty.
"""
return len(self.segment_ids)
def merge_segments(self, other_segment_ids: list[int]) -> None:
"""Merge segment IDs from another occurrence of this entity.
Args:
other_segment_ids: Segment IDs from another occurrence.
"""
self.segment_ids = sorted(set(self.segment_ids) | set(other_segment_ids))
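
The factory's normalization and the merge behavior are easiest to see in a short sketch of the code as written:

```python
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity

entity = NamedEntity.create(
    text="  Acme Corp ",
    category=EntityCategory.COMPANY,
    segment_ids=[3, 1, 3],   # duplicates dropped, order normalized
    confidence=0.92,
)
assert entity.text == "Acme Corp"
assert entity.normalized_text == "acme corp"
assert entity.segment_ids == [1, 3]

entity.merge_segments([2, 3])
assert entity.segment_ids == [1, 2, 3]
assert entity.occurrence_count == 3
```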

View File

@@ -1,13 +1,21 @@
"""Domain ports (interfaces) for NoteFlow."""
from .calendar import (
CalendarEventInfo,
CalendarPort,
OAuthConnectionInfo,
OAuthPort,
)
from .diarization import (
CancellationError,
DiarizationError,
JobAlreadyActiveError,
JobNotFoundError,
)
from .ner import NerPort
from .repositories import (
AnnotationRepository,
EntityRepository,
MeetingRepository,
SegmentRepository,
SummaryRepository,
@@ -16,11 +24,17 @@ from .unit_of_work import UnitOfWork
__all__ = [
"AnnotationRepository",
"CalendarEventInfo",
"CalendarPort",
"CancellationError",
"DiarizationError",
"EntityRepository",
"JobAlreadyActiveError",
"JobNotFoundError",
"MeetingRepository",
"NerPort",
"OAuthConnectionInfo",
"OAuthPort",
"SegmentRepository",
"SummaryRepository",
"UnitOfWork",

View File

@@ -0,0 +1,168 @@
"""Calendar integration port interfaces.
Defines protocols for OAuth operations and calendar event fetching.
Implementations live in infrastructure/calendar/.
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
@dataclass(frozen=True, slots=True)
class CalendarEventInfo:
"""Calendar event information from provider API.
Richer than trigger-layer CalendarEvent, used for API responses
and caching in CalendarEventModel.
"""
id: str
title: str
start_time: datetime
end_time: datetime
attendees: tuple[str, ...]
location: str | None = None
description: str | None = None
meeting_url: str | None = None
is_recurring: bool = False
is_all_day: bool = False
provider: str = ""
raw: dict[str, object] | None = None
@dataclass(frozen=True, slots=True)
class OAuthConnectionInfo:
"""OAuth connection status for a provider."""
provider: str
status: str # "connected", "disconnected", "error", "expired"
email: str | None = None
expires_at: datetime | None = None
error_message: str | None = None
class OAuthPort(Protocol):
"""Port for OAuth operations.
Handles PKCE flow for Google and Outlook calendar providers.
"""
def initiate_auth(
self,
provider: OAuthProvider,
redirect_uri: str,
) -> tuple[str, str]:
"""Generate OAuth authorization URL with PKCE.
Args:
provider: OAuth provider (google or outlook).
redirect_uri: Callback URL after authorization.
Returns:
Tuple of (authorization_url, state_token).
"""
...
async def complete_auth(
self,
provider: OAuthProvider,
code: str,
state: str,
) -> OAuthTokens:
"""Exchange authorization code for tokens.
Args:
provider: OAuth provider.
code: Authorization code from callback.
state: State parameter from callback.
Returns:
OAuth tokens.
Raises:
ValueError: If state doesn't match or code exchange fails.
"""
...
async def refresh_tokens(
self,
provider: OAuthProvider,
refresh_token: str,
) -> OAuthTokens:
"""Refresh expired access token.
Args:
provider: OAuth provider.
refresh_token: Refresh token from previous exchange.
Returns:
New OAuth tokens.
Raises:
ValueError: If refresh fails.
"""
...
async def revoke_tokens(
self,
provider: OAuthProvider,
access_token: str,
) -> bool:
"""Revoke OAuth tokens with provider.
Args:
provider: OAuth provider.
access_token: Access token to revoke.
Returns:
True if revoked successfully.
"""
...
class CalendarPort(Protocol):
"""Port for calendar event operations.
Each provider (Google, Outlook) implements this protocol.
"""
async def list_events(
self,
access_token: str,
hours_ahead: int = 24,
limit: int = 20,
) -> list[CalendarEventInfo]:
"""Fetch upcoming calendar events.
Args:
access_token: Provider OAuth access token.
hours_ahead: Hours to look ahead.
limit: Maximum number of events.
Returns:
Calendar events sorted chronologically.
Raises:
ValueError: If token is invalid or API call fails.
"""
...
async def get_user_email(self, access_token: str) -> str:
"""Get authenticated user's email address.
Args:
access_token: Valid OAuth access token.
Returns:
User's email address.
Raises:
ValueError: If token is invalid or API call fails.
"""
...
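
Because these ports are `typing.Protocol`s, adapters conform structurally with no inheritance. A hypothetical in-memory double for tests might look like:

```python
from datetime import UTC, datetime, timedelta

from noteflow.domain.ports.calendar import CalendarEventInfo


class FakeCalendarAdapter:
    """Hypothetical test double satisfying CalendarPort structurally."""

    def __init__(self, events: list[CalendarEventInfo] | None = None) -> None:
        self._events = events or []

    async def list_events(
        self,
        access_token: str,
        hours_ahead: int = 24,
        limit: int = 20,
    ) -> list[CalendarEventInfo]:
        # Mirror the contract: chronological order, window- and limit-bounded.
        cutoff = datetime.now(UTC) + timedelta(hours=hours_ahead)
        upcoming = [e for e in self._events if e.start_time <= cutoff]
        return sorted(upcoming, key=lambda e: e.start_time)[:limit]

    async def get_user_email(self, access_token: str) -> str:
        return "tester@example.com"
```

Structural typing keeps doubles like this out of the production inheritance tree while still type-checking against `CalendarPort`.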

View File

@@ -0,0 +1,53 @@
"""NER port interface (hexagonal architecture)."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from noteflow.domain.entities.named_entity import NamedEntity
class NerPort(Protocol):
"""Port for named entity recognition.
This is the domain port that the application layer uses.
Infrastructure adapters (like NerEngine) implement this protocol.
"""
def extract(self, text: str) -> list[NamedEntity]:
"""Extract named entities from text.
Args:
text: Input text to analyze.
Returns:
List of extracted entities (deduplicated by normalized text).
"""
...
def extract_from_segments(
self,
segments: list[tuple[int, str]],
) -> list[NamedEntity]:
"""Extract entities from multiple segments with tracking.
Processes each segment individually and tracks which segments
contain each entity via segment_ids. Entities appearing in
multiple segments are deduplicated with merged segment lists.
Args:
segments: List of (segment_id, text) tuples.
Returns:
Entities with segment_ids populated (deduplicated across segments).
"""
...
def is_ready(self) -> bool:
"""Check if the NER engine is loaded and ready.
Returns:
True if model is loaded and ready for inference.
"""
...
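
A minimal sketch of a conforming fake (a naive capitalization heuristic, not the spaCy-backed NerEngine) shows how per-segment extraction composes with `NamedEntity.merge_segments`:

```python
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity


class FakeNerAdapter:
    """Hypothetical test double satisfying NerPort structurally."""

    def extract(self, text: str) -> list[NamedEntity]:
        # Naive heuristic: treat capitalized tokens as OTHER entities,
        # deduplicated by normalized text as the port contract requires.
        seen: dict[str, NamedEntity] = {}
        for token in text.split():
            word = token.strip(".,;:!?")
            if word.istitle() and word.lower() not in seen:
                seen[word.lower()] = NamedEntity.create(
                    text=word,
                    category=EntityCategory.OTHER,
                    segment_ids=[],
                    confidence=0.5,
                )
        return list(seen.values())

    def extract_from_segments(
        self, segments: list[tuple[int, str]]
    ) -> list[NamedEntity]:
        # Track which segments contain each entity; merge across segments.
        merged: dict[str, NamedEntity] = {}
        for segment_id, text in segments:
            for entity in self.extract(text):
                existing = merged.get(entity.normalized_text)
                if existing is None:
                    entity.merge_segments([segment_id])
                    merged[entity.normalized_text] = entity
                else:
                    existing.merge_segments([segment_id])
        return list(merged.values())

    def is_ready(self) -> bool:
        return True
```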

View File

@@ -7,12 +7,12 @@ from datetime import datetime
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from noteflow.domain.entities import Annotation, Meeting, Segment, Summary
from uuid import UUID
from noteflow.domain.entities import Annotation, Integration, Meeting, Segment, Summary
from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.domain.value_objects import AnnotationId, MeetingId, MeetingState
from noteflow.infrastructure.persistence.repositories import (
DiarizationJob,
StreamingTurn,
)
from noteflow.infrastructure.persistence.repositories import DiarizationJob, StreamingTurn
class MeetingRepository(Protocol):
@@ -493,3 +493,186 @@ class PreferencesRepository(Protocol):
True if deleted, False if not found.
"""
...
class EntityRepository(Protocol):
"""Repository protocol for NamedEntity operations (NER results)."""
async def save(self, entity: NamedEntity) -> NamedEntity:
"""Save or update a named entity.
Args:
entity: Entity to save.
Returns:
Saved entity with db_id populated.
"""
...
async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
"""Save multiple entities efficiently.
Args:
entities: Entities to save.
Returns:
Saved entities with db_ids populated.
"""
...
async def get(self, entity_id: UUID) -> NamedEntity | None:
"""Get entity by ID.
Args:
entity_id: Entity UUID.
Returns:
Entity if found, None otherwise.
"""
...
async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
"""Get all entities for a meeting.
Args:
meeting_id: Meeting UUID.
Returns:
List of entities.
"""
...
async def delete_by_meeting(self, meeting_id: MeetingId) -> int:
"""Delete all entities for a meeting.
Args:
meeting_id: Meeting UUID.
Returns:
Number of deleted entities.
"""
...
async def update_pinned(self, entity_id: UUID, is_pinned: bool) -> bool:
"""Update the pinned status of an entity.
Args:
entity_id: Entity UUID.
is_pinned: New pinned status.
Returns:
True if entity was found and updated.
"""
...
async def exists_for_meeting(self, meeting_id: MeetingId) -> bool:
"""Check if any entities exist for a meeting.
Args:
meeting_id: Meeting UUID.
Returns:
True if at least one entity exists.
"""
...
class IntegrationRepository(Protocol):
"""Repository protocol for external service integrations.
Manages OAuth-connected services like calendars, email providers, and PKM tools.
"""
async def get(self, integration_id: UUID) -> Integration | None:
"""Retrieve an integration by ID.
Args:
integration_id: Integration UUID.
Returns:
Integration if found, None otherwise.
"""
...
async def get_by_provider(
self,
provider: str,
integration_type: str | None = None,
) -> Integration | None:
"""Retrieve an integration by provider name.
Args:
provider: Provider name (e.g., 'google', 'outlook').
integration_type: Optional type filter.
Returns:
Integration if found, None otherwise.
"""
...
async def create(self, integration: Integration) -> Integration:
"""Persist a new integration.
Args:
integration: Integration to create.
Returns:
Created integration.
"""
...
async def update(self, integration: Integration) -> Integration:
"""Update an existing integration.
Args:
integration: Integration with updated fields.
Returns:
Updated integration.
Raises:
ValueError: If integration does not exist.
"""
...
async def delete(self, integration_id: UUID) -> bool:
"""Delete an integration and its secrets.
Args:
integration_id: Integration UUID.
Returns:
True if deleted, False if not found.
"""
...
async def get_secrets(self, integration_id: UUID) -> dict[str, str] | None:
"""Get encrypted secrets for an integration.
Args:
integration_id: Integration UUID.
Returns:
Dictionary of secret key-value pairs, or None if not found.
"""
...
async def set_secrets(self, integration_id: UUID, secrets: dict[str, str]) -> None:
"""Store encrypted secrets for an integration.
Args:
integration_id: Integration UUID.
secrets: Dictionary of secret key-value pairs.
"""
...
async def list_by_type(self, integration_type: str) -> Sequence[Integration]:
"""List integrations by type.
Args:
integration_type: Integration type (e.g., 'calendar', 'email').
Returns:
List of integrations of the specified type.
"""
...
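
A hedged sketch of how application code might use the entity protocol, assuming the `supports_entities` flag added to `UnitOfWork` in the next hunk and a `commit()` method on the unit of work (consistent with the repository-provider pattern elsewhere in the codebase):

```python
from collections.abc import Callable

from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.domain.ports.unit_of_work import UnitOfWork


async def persist_entities(
    uow_factory: Callable[[], UnitOfWork],
    meeting_id,  # MeetingId; left untyped to keep the sketch self-contained
    entities: list[NamedEntity],
) -> None:
    """Replace a meeting's NER results atomically (hypothetical helper)."""
    async with uow_factory() as uow:
        if not uow.supports_entities:  # memory backend: nothing to persist
            return
        await uow.entities.delete_by_meeting(meeting_id)
        await uow.entities.save_batch(entities)
        await uow.commit()  # assumed commit(), per the codebase pattern
```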

View File

@@ -8,6 +8,8 @@ if TYPE_CHECKING:
from .repositories import (
AnnotationRepository,
DiarizationJobRepository,
EntityRepository,
IntegrationRepository,
MeetingRepository,
PreferencesRepository,
SegmentRepository,
@@ -65,6 +67,16 @@ class UnitOfWork(Protocol):
"""Access the preferences repository."""
...
@property
def entities(self) -> EntityRepository:
"""Access the entities repository for NER results."""
...
@property
def integrations(self) -> IntegrationRepository:
"""Access the integrations repository for OAuth connections."""
...
# Feature flags for DB-only capabilities
@property
def supports_annotations(self) -> bool:
@@ -90,6 +102,23 @@ class UnitOfWork(Protocol):
"""
...
@property
def supports_entities(self) -> bool:
"""Check if NER entity persistence is supported.
Returns False for memory-only implementations.
"""
...
@property
def supports_integrations(self) -> bool:
"""Check if OAuth integration persistence is supported.
Returns False for memory-only implementations that don't support
encrypted secret storage.
"""
...
async def __aenter__(self) -> Self:
"""Enter the unit of work context.

View File

@@ -58,8 +58,13 @@ class SummarizationResult:
@property
def is_success(self) -> bool:
"""Check if summarization succeeded with content."""
return bool(self.summary.executive_summary)
"""Check if summarization succeeded with meaningful content.
Returns:
True if executive_summary has non-whitespace content.
"""
summary_text = self.summary.executive_summary
return bool(summary_text and summary_text.strip())
@dataclass(frozen=True)

View File

@@ -16,6 +16,6 @@ def utc_now() -> datetime:
consistent timezone-aware datetime handling.
Returns:
Current datetime in UTC timezone.
Current datetime in UTC timezone with microsecond precision.
"""
return datetime.now(UTC)

View File

@@ -2,7 +2,9 @@
from __future__ import annotations
from enum import Enum, IntEnum
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, IntEnum, StrEnum
from typing import NewType
from uuid import UUID
@@ -79,3 +81,88 @@ class MeetingState(IntEnum):
MeetingState.ERROR: set(), # Terminal state
}
return target in valid_transitions.get(self, set())
# OAuth value objects for calendar integration
class OAuthProvider(StrEnum):
"""Supported OAuth providers for calendar integration."""
GOOGLE = "google"
OUTLOOK = "outlook"
@dataclass(frozen=True, slots=True)
class OAuthState:
"""CSRF state for OAuth PKCE flow.
Stored in-memory with TTL for security. Contains the code_verifier
needed to complete the PKCE exchange.
"""
state: str
provider: OAuthProvider
redirect_uri: str
code_verifier: str
created_at: datetime
expires_at: datetime
def is_expired(self) -> bool:
"""Check if the state has expired."""
return datetime.now(self.created_at.tzinfo) > self.expires_at
@dataclass(frozen=True, slots=True)
class OAuthTokens:
"""OAuth tokens returned from provider.
Stored encrypted in IntegrationSecretModel via the repository.
"""
access_token: str
refresh_token: str | None
token_type: str
expires_at: datetime
scope: str
def is_expired(self) -> bool:
"""Check if the access token has expired."""
return datetime.now(self.expires_at.tzinfo) > self.expires_at
def to_secrets_dict(self) -> dict[str, str]:
"""Convert to dictionary for encrypted storage."""
result: dict[str, str] = {
"access_token": self.access_token,
"token_type": self.token_type,
"expires_at": self.expires_at.isoformat(),
"scope": self.scope,
}
if self.refresh_token:
result["refresh_token"] = self.refresh_token
return result
@classmethod
def from_secrets_dict(cls, secrets: dict[str, str]) -> OAuthTokens:
"""Create from dictionary retrieved from encrypted storage.
Args:
secrets: Token dictionary from keystore.
Returns:
OAuthTokens with parsed expiration datetime.
Raises:
KeyError: If access_token or expires_at missing.
ValueError: If expires_at is not valid ISO format.
"""
# Parse expiration datetime from ISO string
expires_at = datetime.fromisoformat(secrets["expires_at"])
return cls(
access_token=secrets["access_token"],
refresh_token=secrets.get("refresh_token"),
token_type=secrets.get("token_type", "Bearer"),
expires_at=expires_at,
scope=secrets.get("scope", ""),
)
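
The secrets round-trip is worth a sanity check; all values below are illustrative:

```python
from datetime import UTC, datetime, timedelta

from noteflow.domain.value_objects import OAuthTokens

tokens = OAuthTokens(
    access_token="ya29.example",         # illustrative value
    refresh_token="1//refresh-example",
    token_type="Bearer",
    expires_at=datetime.now(UTC) + timedelta(hours=1),
    scope="calendar.readonly",
)

secrets = tokens.to_secrets_dict()       # all values serialized as strings
restored = OAuthTokens.from_secrets_dict(secrets)
assert restored == tokens                # frozen dataclass: value equality
assert not restored.is_expired()
```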

View File

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
import numpy as np
from numpy.typing import NDArray
from noteflow.grpc._types import ConnectionCallback, TranscriptCallback
from noteflow.grpc._types import ConnectionCallback, TranscriptCallback, TranscriptSegment
from noteflow.grpc.proto import noteflow_pb2_grpc
@@ -52,7 +52,7 @@ class ClientHost(Protocol):
"""Background thread for audio streaming."""
...
def _notify_transcript(self, segment: object) -> None:
def _notify_transcript(self, segment: TranscriptSegment) -> None:
"""Notify transcript callback."""
...

View File

@@ -1,8 +1,10 @@
"""gRPC service mixins for NoteFlowServicer."""
from .annotation import AnnotationMixin
from .calendar import CalendarMixin
from .diarization import DiarizationMixin
from .diarization_job import DiarizationJobMixin
from .entities import EntitiesMixin
from .export import ExportMixin
from .meeting import MeetingMixin
from .streaming import StreamingMixin
@@ -10,8 +12,10 @@ from .summarization import SummarizationMixin
__all__ = [
"AnnotationMixin",
"CalendarMixin",
"DiarizationJobMixin",
"DiarizationMixin",
"EntitiesMixin",
"ExportMixin",
"MeetingMixin",
"StreamingMixin",

View File

@@ -0,0 +1,176 @@
"""Calendar integration mixin for gRPC service."""
from __future__ import annotations
from typing import TYPE_CHECKING
import grpc.aio
from noteflow.application.services.calendar_service import CalendarServiceError
from ..proto import noteflow_pb2
from .errors import abort_internal, abort_invalid_argument, abort_unavailable
if TYPE_CHECKING:
from .protocols import ServicerHost
class CalendarMixin:
"""Mixin providing calendar integration functionality.
Requires host to implement ServicerHost protocol with _calendar_service.
Provides OAuth flow and calendar event fetching for Google and Outlook.
"""
async def ListCalendarEvents(
self: ServicerHost,
request: noteflow_pb2.ListCalendarEventsRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.ListCalendarEventsResponse:
"""List upcoming calendar events from connected providers."""
if self._calendar_service is None:
await abort_unavailable(context, "Calendar integration not enabled")
provider = request.provider if request.provider else None
hours_ahead = request.hours_ahead if request.hours_ahead > 0 else None
limit = request.limit if request.limit > 0 else None
try:
events = await self._calendar_service.list_calendar_events(
provider=provider,
hours_ahead=hours_ahead,
limit=limit,
)
except CalendarServiceError as e:
await abort_internal(context, str(e))
proto_events = [
noteflow_pb2.CalendarEvent(
id=event.id,
title=event.title,
start_time=int(event.start_time.timestamp()),
end_time=int(event.end_time.timestamp()),
location=event.location or "",
attendees=list(event.attendees),
meeting_url=event.meeting_url or "",
is_recurring=event.is_recurring,
provider=event.provider,
)
for event in events
]
return noteflow_pb2.ListCalendarEventsResponse(
events=proto_events,
total_count=len(proto_events),
)
async def GetCalendarProviders(
self: ServicerHost,
request: noteflow_pb2.GetCalendarProvidersRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.GetCalendarProvidersResponse:
"""Get available calendar providers with authentication status."""
if self._calendar_service is None:
await abort_unavailable(context, "Calendar integration not enabled")
providers = []
for provider_name, display_name in [
("google", "Google Calendar"),
("outlook", "Microsoft Outlook"),
]:
status = await self._calendar_service.get_connection_status(provider_name)
providers.append(
noteflow_pb2.CalendarProvider(
name=provider_name,
is_authenticated=status.status == "connected",
display_name=display_name,
)
)
return noteflow_pb2.GetCalendarProvidersResponse(providers=providers)
async def InitiateOAuth(
self: ServicerHost,
request: noteflow_pb2.InitiateOAuthRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.InitiateOAuthResponse:
"""Start OAuth flow for a calendar provider."""
if self._calendar_service is None:
await abort_unavailable(context, "Calendar integration not enabled")
try:
auth_url, state = await self._calendar_service.initiate_oauth(
provider=request.provider,
redirect_uri=request.redirect_uri if request.redirect_uri else None,
)
except CalendarServiceError as e:
await abort_invalid_argument(context, str(e))
return noteflow_pb2.InitiateOAuthResponse(
auth_url=auth_url,
state=state,
)
async def CompleteOAuth(
self: ServicerHost,
request: noteflow_pb2.CompleteOAuthRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.CompleteOAuthResponse:
"""Complete OAuth flow with authorization code."""
if self._calendar_service is None:
await abort_unavailable(context, "Calendar integration not enabled")
try:
success = await self._calendar_service.complete_oauth(
provider=request.provider,
code=request.code,
state=request.state,
)
except CalendarServiceError as e:
return noteflow_pb2.CompleteOAuthResponse(
success=False,
error_message=str(e),
)
# Get the provider email after successful connection
status = await self._calendar_service.get_connection_status(request.provider)
return noteflow_pb2.CompleteOAuthResponse(
success=success,
provider_email=status.email or "",
)
async def GetOAuthConnectionStatus(
self: ServicerHost,
request: noteflow_pb2.GetOAuthConnectionStatusRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.GetOAuthConnectionStatusResponse:
"""Get OAuth connection status for a provider."""
if self._calendar_service is None:
await abort_unavailable(context, "Calendar integration not enabled")
status = await self._calendar_service.get_connection_status(request.provider)
connection = noteflow_pb2.OAuthConnection(
provider=status.provider,
status=status.status,
email=status.email or "",
expires_at=int(status.expires_at.timestamp()) if status.expires_at else 0,
error_message=status.error_message or "",
integration_type=request.integration_type or "calendar",
)
return noteflow_pb2.GetOAuthConnectionStatusResponse(connection=connection)
async def DisconnectOAuth(
self: ServicerHost,
request: noteflow_pb2.DisconnectOAuthRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.DisconnectOAuthResponse:
"""Disconnect OAuth integration and revoke tokens."""
if self._calendar_service is None:
await abort_unavailable(context, "Calendar integration not enabled")
success = await self._calendar_service.disconnect(request.provider)
return noteflow_pb2.DisconnectOAuthResponse(success=success)
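
Client-side, the four RPCs compose into a complete flow. A sketch using the generated sync stub; the server address/port and the redirect URI are assumptions:

```python
import webbrowser

import grpc

from noteflow.grpc.proto import noteflow_pb2, noteflow_pb2_grpc

# Assumed address for a local dev server; adjust to your deployment.
channel = grpc.insecure_channel("localhost:50051")
stub = noteflow_pb2_grpc.NoteFlowServiceStub(channel)

init = stub.InitiateOAuth(
    noteflow_pb2.InitiateOAuthRequest(
        provider="google",
        redirect_uri="http://localhost:8765/callback",
        integration_type="calendar",
    )
)
webbrowser.open(init.auth_url)

# A local callback handler would capture ?code=...&state=... from the redirect.
code = input("Paste the authorization code: ")
done = stub.CompleteOAuth(
    noteflow_pb2.CompleteOAuthRequest(provider="google", code=code, state=init.state)
)
if done.success:
    status = stub.GetOAuthConnectionStatus(
        noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google")
    )
    print(status.connection.email, status.connection.status)
```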

View File

@@ -0,0 +1,98 @@
"""Entity extraction gRPC mixin."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import grpc.aio
from ..proto import noteflow_pb2
from .converters import parse_meeting_id_or_abort
from .errors import abort_failed_precondition, abort_not_found
if TYPE_CHECKING:
from noteflow.application.services.ner_service import NerService
from noteflow.domain.entities.named_entity import NamedEntity
from .protocols import ServicerHost
logger = logging.getLogger(__name__)
class EntitiesMixin:
"""Mixin for entity extraction RPC methods.
Implements the ExtractEntities RPC using the NerService application layer.
Architecture: gRPC → NerService (application) → NerEngine (infrastructure)
"""
_ner_service: NerService | None
async def ExtractEntities(
self: ServicerHost,
request: noteflow_pb2.ExtractEntitiesRequest,
context: grpc.aio.ServicerContext,
) -> noteflow_pb2.ExtractEntitiesResponse:
"""Extract named entities from meeting transcript.
Delegates to NerService for extraction, caching, and persistence.
Returns cached results if available, unless force_refresh is True.
"""
meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
if self._ner_service is None:
await abort_failed_precondition(
context,
"NER service not configured. Set NOTEFLOW_FEATURE_NER_ENABLED=true",
)
try:
result = await self._ner_service.extract_entities(
meeting_id=meeting_id,
force_refresh=request.force_refresh,
)
except ValueError:
# Meeting not found
await abort_not_found(context, "Meeting", request.meeting_id)
except RuntimeError as e:
# Feature disabled
await abort_failed_precondition(context, str(e))
# Convert to proto via the shared converter
proto_entities = [entity_to_proto(entity) for entity in result.entities]
return noteflow_pb2.ExtractEntitiesResponse(
entities=proto_entities,
total_count=result.total_count,
cached=result.cached,
)
def entity_to_proto(entity: NamedEntity) -> noteflow_pb2.ExtractedEntity:
"""Convert domain NamedEntity to proto ExtractedEntity.
Args:
entity: NamedEntity domain object.
Returns:
Proto ExtractedEntity message.
"""
return noteflow_pb2.ExtractedEntity(
id=str(entity.id),
text=entity.text,
category=entity.category.value,
segment_ids=list(entity.segment_ids),
confidence=entity.confidence,
is_pinned=entity.is_pinned,
)
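
A quick check of the converter against the domain factory:

```python
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.grpc._mixins.entities import entity_to_proto

entity = NamedEntity.create(
    text="NoteFlow",
    category=EntityCategory.PRODUCT,
    segment_ids=[0, 4],
    confidence=0.88,
)
proto = entity_to_proto(entity)
assert proto.category == "product"
assert list(proto.segment_ids) == [0, 4]
```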

View File

@@ -12,6 +12,8 @@ from numpy.typing import NDArray
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from noteflow.application.services.calendar_service import CalendarService
from noteflow.application.services.ner_service import NerService
from noteflow.domain.entities import Meeting
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.infrastructure.asr import FasterWhisperEngine, Segmenter, StreamingVad
@@ -41,10 +43,12 @@ class ServicerHost(Protocol):
_meetings_dir: Path
_crypto: AesGcmCryptoBox
# Engines
# Engines and services
_asr_engine: FasterWhisperEngine | None
_diarization_engine: DiarizationEngine | None
_summarization_service: object | None
_ner_service: NerService | None
_calendar_service: CalendarService | None
_diarization_refinement_enabled: bool
# Audio writers

View File

@@ -185,7 +185,7 @@ class StreamingMixin:
await repo.meetings.update(meeting)
await repo.commit()
next_segment_id = await repo.segments.get_next_segment_id(meeting.id)
next_segment_id = await repo.segments.compute_next_segment_id(meeting.id)
self._open_meeting_audio_writer(
meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path
)

View File

@@ -80,7 +80,7 @@ class StreamingSession:
return self._meeting_id
@property
def is_active(self) -> bool:
def is_streaming(self) -> bool:
"""Check if the session is currently streaming."""
return self._started and self._thread is not None and self._thread.is_alive()

View File

@@ -93,13 +93,13 @@ class MeetingStore:
"""List meetings with optional filtering.
Args:
states: Optional list of states to filter by.
limit: Maximum number of meetings to return.
offset: Number of meetings to skip.
sort_desc: Sort by created_at descending if True.
states: Filter by these states (all if None).
limit: Max meetings per page.
offset: Pagination offset.
sort_desc: Sort by created_at descending.
Returns:
Tuple of (meetings list, total count).
Tuple of (paginated meeting list, total matching count).
"""
with self._lock:
meetings = list(self._meetings.values())
@@ -166,8 +166,8 @@ class MeetingStore:
meeting.add_segment(segment)
return meeting
def get_segments(self, meeting_id: str) -> list[Segment]:
"""Get a copy of segments for a meeting.
def fetch_segments(self, meeting_id: str) -> list[Segment]:
"""Fetch a copy of segments for in-memory meeting.
Args:
meeting_id: Meeting ID.
@@ -277,8 +277,8 @@ class MeetingStore:
meeting.ended_at = end_time
return True
def get_next_segment_id(self, meeting_id: str) -> int:
"""Get next segment ID for a meeting.
def compute_next_segment_id(self, meeting_id: str) -> int:
"""Compute next segment ID for in-memory meeting.
Args:
meeting_id: Meeting ID.

View File

@@ -48,8 +48,12 @@ service NoteFlowService {
// Calendar integration (Sprint 5)
rpc ListCalendarEvents(ListCalendarEventsRequest) returns (ListCalendarEventsResponse);
rpc GetCalendarProviders(GetCalendarProvidersRequest) returns (GetCalendarProvidersResponse);
rpc InitiateCalendarAuth(InitiateCalendarAuthRequest) returns (InitiateCalendarAuthResponse);
rpc CompleteCalendarAuth(CompleteCalendarAuthRequest) returns (CompleteCalendarAuthResponse);
// OAuth integration (generic for calendar, email, PKM, etc.)
rpc InitiateOAuth(InitiateOAuthRequest) returns (InitiateOAuthResponse);
rpc CompleteOAuth(CompleteOAuthRequest) returns (CompleteOAuthResponse);
rpc GetOAuthConnectionStatus(GetOAuthConnectionStatusRequest) returns (GetOAuthConnectionStatusResponse);
rpc DisconnectOAuth(DisconnectOAuthRequest) returns (DisconnectOAuthResponse);
}
// =============================================================================
@@ -702,15 +706,22 @@ message GetCalendarProvidersResponse {
repeated CalendarProvider providers = 1;
}
message InitiateCalendarAuthRequest {
// Provider to authenticate: google, outlook
// =============================================================================
// OAuth Integration Messages (generic for calendar, email, PKM, etc.)
// =============================================================================
message InitiateOAuthRequest {
// Provider to authenticate: google, outlook, notion, etc.
string provider = 1;
// Redirect URI for OAuth callback
string redirect_uri = 2;
// Integration type: calendar, email, pkm, custom
string integration_type = 3;
}
message InitiateCalendarAuthResponse {
message InitiateOAuthResponse {
// Authorization URL to redirect user to
string auth_url = 1;
@@ -718,7 +729,7 @@ message InitiateCalendarAuthResponse {
string state = 2;
}
message CompleteCalendarAuthRequest {
message CompleteOAuthRequest {
// Provider being authenticated
string provider = 1;
@@ -729,13 +740,62 @@ message CompleteCalendarAuthRequest {
string state = 3;
}
message CompleteCalendarAuthResponse {
message CompleteOAuthResponse {
// Whether authentication succeeded
bool success = 1;
// Error message if failed
string error_message = 2;
// Email of authenticated calendar account
// Email of authenticated account
string provider_email = 3;
}
message OAuthConnection {
// Provider name: google, outlook, notion
string provider = 1;
// Connection status: disconnected, connected, error
string status = 2;
// Email of authenticated account
string email = 3;
// Token expiration timestamp (Unix epoch seconds)
int64 expires_at = 4;
// Error message if status is error
string error_message = 5;
// Integration type: calendar, email, pkm, custom
string integration_type = 6;
}
message GetOAuthConnectionStatusRequest {
// Provider to check: google, outlook, notion
string provider = 1;
// Optional integration type filter
string integration_type = 2;
}
message GetOAuthConnectionStatusResponse {
// Connection details
OAuthConnection connection = 1;
}
message DisconnectOAuthRequest {
// Provider to disconnect
string provider = 1;
// Optional integration type
string integration_type = 2;
}
message DisconnectOAuthResponse {
// Whether disconnection succeeded
bool success = 1;
// Error message if failed
string error_message = 2;
}

File diff suppressed because one or more lines are too long

View File

@@ -620,15 +620,17 @@ class GetCalendarProvidersResponse(_message.Message):
providers: _containers.RepeatedCompositeFieldContainer[CalendarProvider]
def __init__(self, providers: _Optional[_Iterable[_Union[CalendarProvider, _Mapping]]] = ...) -> None: ...
class InitiateCalendarAuthRequest(_message.Message):
__slots__ = ("provider", "redirect_uri")
class InitiateOAuthRequest(_message.Message):
__slots__ = ("provider", "redirect_uri", "integration_type")
PROVIDER_FIELD_NUMBER: _ClassVar[int]
REDIRECT_URI_FIELD_NUMBER: _ClassVar[int]
INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
provider: str
redirect_uri: str
def __init__(self, provider: _Optional[str] = ..., redirect_uri: _Optional[str] = ...) -> None: ...
integration_type: str
def __init__(self, provider: _Optional[str] = ..., redirect_uri: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...
class InitiateCalendarAuthResponse(_message.Message):
class InitiateOAuthResponse(_message.Message):
__slots__ = ("auth_url", "state")
AUTH_URL_FIELD_NUMBER: _ClassVar[int]
STATE_FIELD_NUMBER: _ClassVar[int]
@@ -636,7 +638,7 @@ class InitiateCalendarAuthResponse(_message.Message):
state: str
def __init__(self, auth_url: _Optional[str] = ..., state: _Optional[str] = ...) -> None: ...
class CompleteCalendarAuthRequest(_message.Message):
class CompleteOAuthRequest(_message.Message):
__slots__ = ("provider", "code", "state")
PROVIDER_FIELD_NUMBER: _ClassVar[int]
CODE_FIELD_NUMBER: _ClassVar[int]
@@ -646,7 +648,7 @@ class CompleteCalendarAuthRequest(_message.Message):
state: str
def __init__(self, provider: _Optional[str] = ..., code: _Optional[str] = ..., state: _Optional[str] = ...) -> None: ...
class CompleteCalendarAuthResponse(_message.Message):
class CompleteOAuthResponse(_message.Message):
__slots__ = ("success", "error_message", "provider_email")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
@@ -655,3 +657,49 @@ class CompleteCalendarAuthResponse(_message.Message):
error_message: str
provider_email: str
def __init__(self, success: bool = ..., error_message: _Optional[str] = ..., provider_email: _Optional[str] = ...) -> None: ...
class OAuthConnection(_message.Message):
__slots__ = ("provider", "status", "email", "expires_at", "error_message", "integration_type")
PROVIDER_FIELD_NUMBER: _ClassVar[int]
STATUS_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
EXPIRES_AT_FIELD_NUMBER: _ClassVar[int]
ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
provider: str
status: str
email: str
expires_at: int
error_message: str
integration_type: str
def __init__(self, provider: _Optional[str] = ..., status: _Optional[str] = ..., email: _Optional[str] = ..., expires_at: _Optional[int] = ..., error_message: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...
class GetOAuthConnectionStatusRequest(_message.Message):
__slots__ = ("provider", "integration_type")
PROVIDER_FIELD_NUMBER: _ClassVar[int]
INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
provider: str
integration_type: str
def __init__(self, provider: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...
class GetOAuthConnectionStatusResponse(_message.Message):
__slots__ = ("connection",)
CONNECTION_FIELD_NUMBER: _ClassVar[int]
connection: OAuthConnection
def __init__(self, connection: _Optional[_Union[OAuthConnection, _Mapping]] = ...) -> None: ...
class DisconnectOAuthRequest(_message.Message):
__slots__ = ("provider", "integration_type")
PROVIDER_FIELD_NUMBER: _ClassVar[int]
INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
provider: str
integration_type: str
def __init__(self, provider: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...
class DisconnectOAuthResponse(_message.Message):
__slots__ = ("success", "error_message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
error_message: str
def __init__(self, success: bool = ..., error_message: _Optional[str] = ...) -> None: ...

View File

@@ -143,15 +143,25 @@ class NoteFlowServiceStub(object):
request_serializer=noteflow__pb2.GetCalendarProvidersRequest.SerializeToString,
response_deserializer=noteflow__pb2.GetCalendarProvidersResponse.FromString,
_registered_method=True)
self.InitiateCalendarAuth = channel.unary_unary(
'/noteflow.NoteFlowService/InitiateCalendarAuth',
request_serializer=noteflow__pb2.InitiateCalendarAuthRequest.SerializeToString,
response_deserializer=noteflow__pb2.InitiateCalendarAuthResponse.FromString,
self.InitiateOAuth = channel.unary_unary(
'/noteflow.NoteFlowService/InitiateOAuth',
request_serializer=noteflow__pb2.InitiateOAuthRequest.SerializeToString,
response_deserializer=noteflow__pb2.InitiateOAuthResponse.FromString,
_registered_method=True)
self.CompleteCalendarAuth = channel.unary_unary(
'/noteflow.NoteFlowService/CompleteCalendarAuth',
request_serializer=noteflow__pb2.CompleteCalendarAuthRequest.SerializeToString,
response_deserializer=noteflow__pb2.CompleteCalendarAuthResponse.FromString,
self.CompleteOAuth = channel.unary_unary(
'/noteflow.NoteFlowService/CompleteOAuth',
request_serializer=noteflow__pb2.CompleteOAuthRequest.SerializeToString,
response_deserializer=noteflow__pb2.CompleteOAuthResponse.FromString,
_registered_method=True)
self.GetOAuthConnectionStatus = channel.unary_unary(
'/noteflow.NoteFlowService/GetOAuthConnectionStatus',
request_serializer=noteflow__pb2.GetOAuthConnectionStatusRequest.SerializeToString,
response_deserializer=noteflow__pb2.GetOAuthConnectionStatusResponse.FromString,
_registered_method=True)
self.DisconnectOAuth = channel.unary_unary(
'/noteflow.NoteFlowService/DisconnectOAuth',
request_serializer=noteflow__pb2.DisconnectOAuthRequest.SerializeToString,
response_deserializer=noteflow__pb2.DisconnectOAuthResponse.FromString,
_registered_method=True)
@@ -297,13 +307,26 @@ class NoteFlowServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def InitiateCalendarAuth(self, request, context):
def InitiateOAuth(self, request, context):
"""OAuth integration (generic for calendar, email, PKM, etc.)
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CompleteOAuth(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CompleteCalendarAuth(self, request, context):
def GetOAuthConnectionStatus(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def DisconnectOAuth(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
@@ -417,15 +440,25 @@ def add_NoteFlowServiceServicer_to_server(servicer, server):
request_deserializer=noteflow__pb2.GetCalendarProvidersRequest.FromString,
response_serializer=noteflow__pb2.GetCalendarProvidersResponse.SerializeToString,
),
'InitiateCalendarAuth': grpc.unary_unary_rpc_method_handler(
servicer.InitiateCalendarAuth,
request_deserializer=noteflow__pb2.InitiateCalendarAuthRequest.FromString,
response_serializer=noteflow__pb2.InitiateCalendarAuthResponse.SerializeToString,
'InitiateOAuth': grpc.unary_unary_rpc_method_handler(
servicer.InitiateOAuth,
request_deserializer=noteflow__pb2.InitiateOAuthRequest.FromString,
response_serializer=noteflow__pb2.InitiateOAuthResponse.SerializeToString,
),
'CompleteCalendarAuth': grpc.unary_unary_rpc_method_handler(
servicer.CompleteCalendarAuth,
request_deserializer=noteflow__pb2.CompleteCalendarAuthRequest.FromString,
response_serializer=noteflow__pb2.CompleteCalendarAuthResponse.SerializeToString,
'CompleteOAuth': grpc.unary_unary_rpc_method_handler(
servicer.CompleteOAuth,
request_deserializer=noteflow__pb2.CompleteOAuthRequest.FromString,
response_serializer=noteflow__pb2.CompleteOAuthResponse.SerializeToString,
),
'GetOAuthConnectionStatus': grpc.unary_unary_rpc_method_handler(
servicer.GetOAuthConnectionStatus,
request_deserializer=noteflow__pb2.GetOAuthConnectionStatusRequest.FromString,
response_serializer=noteflow__pb2.GetOAuthConnectionStatusResponse.SerializeToString,
),
'DisconnectOAuth': grpc.unary_unary_rpc_method_handler(
servicer.DisconnectOAuth,
request_deserializer=noteflow__pb2.DisconnectOAuthRequest.FromString,
response_serializer=noteflow__pb2.DisconnectOAuthResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@@ -1010,7 +1043,7 @@ class NoteFlowService(object):
_registered_method=True)
@staticmethod
def InitiateCalendarAuth(request,
def InitiateOAuth(request,
target,
options=(),
channel_credentials=None,
@@ -1023,9 +1056,9 @@ class NoteFlowService(object):
return grpc.experimental.unary_unary(
request,
target,
'/noteflow.NoteFlowService/InitiateCalendarAuth',
noteflow__pb2.InitiateCalendarAuthRequest.SerializeToString,
noteflow__pb2.InitiateCalendarAuthResponse.FromString,
'/noteflow.NoteFlowService/InitiateOAuth',
noteflow__pb2.InitiateOAuthRequest.SerializeToString,
noteflow__pb2.InitiateOAuthResponse.FromString,
options,
channel_credentials,
insecure,
@@ -1037,7 +1070,7 @@ class NoteFlowService(object):
_registered_method=True)
@staticmethod
def CompleteCalendarAuth(request,
def CompleteOAuth(request,
target,
options=(),
channel_credentials=None,
@@ -1050,9 +1083,63 @@ class NoteFlowService(object):
return grpc.experimental.unary_unary(
request,
target,
'/noteflow.NoteFlowService/CompleteCalendarAuth',
noteflow__pb2.CompleteCalendarAuthRequest.SerializeToString,
noteflow__pb2.CompleteCalendarAuthResponse.FromString,
'/noteflow.NoteFlowService/CompleteOAuth',
noteflow__pb2.CompleteOAuthRequest.SerializeToString,
noteflow__pb2.CompleteOAuthResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetOAuthConnectionStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/noteflow.NoteFlowService/GetOAuthConnectionStatus',
noteflow__pb2.GetOAuthConnectionStatusRequest.SerializeToString,
noteflow__pb2.GetOAuthConnectionStatusResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def DisconnectOAuth(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/noteflow.NoteFlowService/DisconnectOAuth',
noteflow__pb2.DisconnectOAuthRequest.SerializeToString,
noteflow__pb2.DisconnectOAuthResponse.FromString,
options,
channel_credentials,
insecure,

View File

@@ -7,17 +7,19 @@ import asyncio
import logging
import signal
import time
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, TypedDict
import grpc.aio
from pydantic import ValidationError
from noteflow.application.services import RecoveryService
from noteflow.application.services.ner_service import NerService
from noteflow.application.services.summarization_service import SummarizationService
from noteflow.config.settings import Settings, get_settings
from noteflow.infrastructure.asr import FasterWhisperEngine
from noteflow.infrastructure.asr.engine import VALID_MODEL_SIZES
from noteflow.infrastructure.diarization import DiarizationEngine
from noteflow.infrastructure.ner import NerEngine
from noteflow.infrastructure.persistence.database import (
create_async_session_factory,
ensure_schema_ready,
@@ -41,6 +43,16 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class _DiarizationEngineKwargs(TypedDict, total=False):
"""Type-safe kwargs for DiarizationEngine initialization."""
device: str
hf_token: str | None
streaming_latency: float
min_speakers: int
max_speakers: int
class NoteFlowServer:
"""Async gRPC server for NoteFlow."""
@@ -54,6 +66,7 @@ class NoteFlowServer:
summarization_service: SummarizationService | None = None,
diarization_engine: DiarizationEngine | None = None,
diarization_refinement_enabled: bool = True,
ner_service: NerService | None = None,
) -> None:
"""Initialize the server.
@@ -66,6 +79,7 @@ class NoteFlowServer:
summarization_service: Optional summarization service for generating summaries.
diarization_engine: Optional diarization engine for speaker identification.
diarization_refinement_enabled: Whether to allow diarization refinement RPCs.
ner_service: Optional NER service for entity extraction.
"""
self._port = port
self._asr_model = asr_model
@@ -75,6 +89,7 @@ class NoteFlowServer:
self._summarization_service = summarization_service
self._diarization_engine = diarization_engine
self._diarization_refinement_enabled = diarization_refinement_enabled
self._ner_service = ner_service
self._server: grpc.aio.Server | None = None
self._servicer: NoteFlowServicer | None = None
@@ -105,13 +120,14 @@ class NoteFlowServer:
self._summarization_service = create_summarization_service()
logger.info("Summarization service initialized (default factory)")
# Create servicer with session factory, summarization, and diarization
# Create servicer with session factory, summarization, diarization, and NER
self._servicer = NoteFlowServicer(
asr_engine=asr_engine,
session_factory=self._session_factory,
summarization_service=self._summarization_service,
diarization_engine=self._diarization_engine,
diarization_refinement_enabled=self._diarization_refinement_enabled,
ner_service=self._ner_service,
)
# Create async gRPC server
@@ -217,6 +233,24 @@ async def run_server_with_config(config: GrpcServerConfig) -> None:
summarization_service.on_consent_change = persist_consent
# Create NER service if enabled
ner_service: NerService | None = None
settings = get_settings()
if settings.feature_flags.ner_enabled:
if session_factory:
logger.info("Initializing NER service (spaCy)...")
ner_engine = NerEngine()
ner_service = NerService(
ner_engine=ner_engine,
uow_factory=lambda: SqlAlchemyUnitOfWork(session_factory),
)
logger.info("NER service initialized (model loaded on demand)")
else:
logger.warning(
"NER feature enabled but no database configured. "
"NER requires database for entity persistence."
)
# Create diarization engine if enabled
diarization_engine: DiarizationEngine | None = None
diarization = config.diarization
@@ -228,7 +262,7 @@ async def run_server_with_config(config: GrpcServerConfig) -> None:
)
else:
logger.info("Initializing diarization engine on %s...", diarization.device)
diarization_kwargs: dict[str, Any] = {
diarization_kwargs: _DiarizationEngineKwargs = {
"device": diarization.device,
"hf_token": diarization.hf_token,
}
@@ -250,6 +284,7 @@ async def run_server_with_config(config: GrpcServerConfig) -> None:
summarization_service=summarization_service,
diarization_engine=diarization_engine,
diarization_refinement_enabled=diarization.refinement_enabled,
ner_service=ner_service,
)
# Set up graceful shutdown

View File

@@ -27,8 +27,10 @@ from noteflow.infrastructure.security.keystore import KeyringKeyStore
from ._mixins import (
AnnotationMixin,
CalendarMixin,
DiarizationJobMixin,
DiarizationMixin,
EntitiesMixin,
ExportMixin,
MeetingMixin,
StreamingMixin,
@@ -41,6 +43,8 @@ if TYPE_CHECKING:
from numpy.typing import NDArray
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from noteflow.application.services.calendar_service import CalendarService
from noteflow.application.services.ner_service import NerService
from noteflow.application.services.summarization_service import SummarizationService
from noteflow.infrastructure.asr import FasterWhisperEngine
from noteflow.infrastructure.diarization import DiarizationEngine, SpeakerTurn
@@ -56,6 +60,8 @@ class NoteFlowServicer(
SummarizationMixin,
AnnotationMixin,
ExportMixin,
EntitiesMixin,
CalendarMixin,
noteflow_pb2_grpc.NoteFlowServiceServicer,
):
"""Async gRPC service implementation for NoteFlow with PostgreSQL persistence."""
@@ -75,6 +81,8 @@ class NoteFlowServicer(
summarization_service: SummarizationService | None = None,
diarization_engine: DiarizationEngine | None = None,
diarization_refinement_enabled: bool = True,
ner_service: NerService | None = None,
calendar_service: CalendarService | None = None,
) -> None:
"""Initialize the service.
@@ -87,12 +95,16 @@ class NoteFlowServicer(
summarization_service: Optional summarization service for generating summaries.
diarization_engine: Optional diarization engine for speaker identification.
diarization_refinement_enabled: Whether to allow post-meeting diarization refinement.
ner_service: Optional NER service for entity extraction.
calendar_service: Optional calendar service for OAuth and event fetching.
"""
self._asr_engine = asr_engine
self._session_factory = session_factory
self._summarization_service = summarization_service
self._diarization_engine = diarization_engine
self._diarization_refinement_enabled = diarization_refinement_enabled
self._ner_service = ner_service
self._calendar_service = calendar_service
self._start_time = time.time()
# Fallback to in-memory store if no database configured
self._memory_store: MeetingStore | None = (

View File

@@ -0,0 +1,14 @@
"""Calendar integration infrastructure.
OAuth management and provider adapters for Google Calendar and Outlook.
"""
from .google_adapter import GoogleCalendarAdapter
from .oauth_manager import OAuthManager
from .outlook_adapter import OutlookCalendarAdapter
__all__ = [
"GoogleCalendarAdapter",
"OAuthManager",
"OutlookCalendarAdapter",
]

View File

@@ -0,0 +1,198 @@
"""Google Calendar API adapter.
Implements CalendarPort for Google Calendar using the Google Calendar API v3.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
import httpx
from noteflow.domain.ports.calendar import CalendarEventInfo, CalendarPort
logger = logging.getLogger(__name__)
class GoogleCalendarError(Exception):
"""Google Calendar API error."""
class GoogleCalendarAdapter(CalendarPort):
"""Google Calendar API adapter.
Fetches calendar events and user info using Google Calendar API v3.
Requires a valid OAuth access token with calendar.readonly scope.
"""
CALENDAR_API_BASE = "https://www.googleapis.com/calendar/v3"
USERINFO_API_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
async def list_events(
self,
access_token: str,
hours_ahead: int = 24,
limit: int = 20,
) -> list[CalendarEventInfo]:
"""Fetch upcoming calendar events from Google Calendar.
Args:
access_token: Google OAuth token with calendar.readonly scope.
hours_ahead: Number of hours to look ahead.
limit: Maximum events to fetch (capped by API).
Returns:
List of Google Calendar events ordered by start time.
Raises:
GoogleCalendarError: If API call fails.
"""
now = datetime.now(UTC)
time_min = now.isoformat()
time_max = (now + timedelta(hours=hours_ahead)).isoformat()
url = f"{self.CALENDAR_API_BASE}/calendars/primary/events"
params: dict[str, str | int] = {
"timeMin": time_min,
"timeMax": time_max,
"maxResults": limit,
"singleEvents": "true",
"orderBy": "startTime",
}
headers = {"Authorization": f"Bearer {access_token}"}
async with httpx.AsyncClient() as client:
response = await client.get(url, params=params, headers=headers)
if response.status_code == 401:
raise GoogleCalendarError("Access token expired or invalid")
if response.status_code != 200:
error_msg = response.text
logger.error("Google Calendar API error: %s", error_msg)
raise GoogleCalendarError(f"API error: {error_msg}")
data = response.json()
items = data.get("items", [])
return [self._parse_event(item) for item in items]
async def get_user_email(self, access_token: str) -> str:
"""Get authenticated user's email address.
Args:
access_token: Valid OAuth access token.
Returns:
User's email address.
Raises:
GoogleCalendarError: If API call fails.
"""
headers = {"Authorization": f"Bearer {access_token}"}
async with httpx.AsyncClient() as client:
response = await client.get(self.USERINFO_API_URL, headers=headers)
if response.status_code == 401:
raise GoogleCalendarError("Access token expired or invalid")
if response.status_code != 200:
error_msg = response.text
logger.error("Google userinfo API error: %s", error_msg)
raise GoogleCalendarError(f"API error: {error_msg}")
data = response.json()
email = data.get("email")
if not email:
raise GoogleCalendarError("No email in userinfo response")
return str(email)
def _parse_event(self, item: dict[str, object]) -> CalendarEventInfo:
"""Parse Google Calendar event into CalendarEventInfo."""
event_id = str(item.get("id", ""))
title = str(item.get("summary", "Untitled"))
# Parse start/end times
start_data = item.get("start", {})
end_data = item.get("end", {})
is_all_day = "date" in start_data if isinstance(start_data, dict) else False
start_time = self._parse_datetime(start_data)
end_time = self._parse_datetime(end_data)
# Parse attendees
attendees_data = item.get("attendees", [])
attendees = tuple(
str(a.get("email", ""))
for a in attendees_data
if isinstance(a, dict) and a.get("email")
) if isinstance(attendees_data, list) else ()
# Extract meeting URL from conferenceData or hangoutLink
meeting_url = self._extract_meeting_url(item)
# Check if recurring
is_recurring = bool(item.get("recurringEventId"))
location = item.get("location")
description = item.get("description")
return CalendarEventInfo(
id=event_id,
title=title,
start_time=start_time,
end_time=end_time,
attendees=attendees,
location=str(location) if location else None,
description=str(description) if description else None,
meeting_url=meeting_url,
is_recurring=is_recurring,
is_all_day=is_all_day,
provider="google",
raw=dict(item) if isinstance(item, dict) else None,
)
def _parse_datetime(self, dt_data: object) -> datetime:
"""Parse datetime from Google Calendar format."""
if not isinstance(dt_data, dict):
return datetime.now(UTC)
# All-day events use "date", timed events use "dateTime"
dt_str = dt_data.get("dateTime") or dt_data.get("date")
if not dt_str or not isinstance(dt_str, str):
return datetime.now(UTC)
# Handle Z suffix for UTC
if dt_str.endswith("Z"):
dt_str = f"{dt_str[:-1]}+00:00"
try:
return datetime.fromisoformat(dt_str)
except ValueError:
logger.warning("Failed to parse datetime: %s", dt_str)
return datetime.now(UTC)
def _extract_meeting_url(self, item: dict[str, object]) -> str | None:
"""Extract video meeting URL from event data."""
# Try hangoutLink first (Google Meet)
hangout_link = item.get("hangoutLink")
if isinstance(hangout_link, str) and hangout_link:
return hangout_link
# Try conferenceData for other providers
conference_data = item.get("conferenceData")
if isinstance(conference_data, dict):
entry_points = conference_data.get("entryPoints", [])
if isinstance(entry_points, list):
for entry in entry_points:
if isinstance(entry, dict) and entry.get("entryPointType") == "video":
uri = entry.get("uri")
if isinstance(uri, str) and uri:
return uri
return None
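
For reference, a hedged sample of the payload shape `_parse_event` expects, plus the Z-suffix normalization `_parse_datetime` applies (required by `fromisoformat` on Python < 3.11):

```python
from datetime import datetime

# Google timed events carry RFC 3339 "dateTime"; all-day events carry "date".
sample_item: dict[str, object] = {
    "id": "evt_123",                      # illustrative values throughout
    "summary": "Sprint review",
    "start": {"dateTime": "2025-01-15T10:00:00Z"},
    "end": {"dateTime": "2025-01-15T10:30:00Z"},
    "attendees": [{"email": "a@example.com"}, {"email": "b@example.com"}],
    "hangoutLink": "https://meet.google.com/abc-defg-hij",
}

# The adapter rewrites the trailing Z as an explicit UTC offset before parsing.
raw = "2025-01-15T10:00:00Z"
parsed = datetime.fromisoformat(f"{raw[:-1]}+00:00")
assert parsed.tzinfo is not None
```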

View File

@@ -0,0 +1,424 @@
"""OAuth manager with PKCE support for calendar providers.
Implements the OAuthPort protocol using httpx for async HTTP requests.
Uses PKCE (Proof Key for Code Exchange) for secure OAuth 2.0 flow.
"""
from __future__ import annotations
import base64
import hashlib
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, ClassVar
from urllib.parse import urlencode
import httpx
from noteflow.domain.ports.calendar import OAuthPort
from noteflow.domain.value_objects import OAuthProvider, OAuthState, OAuthTokens
if TYPE_CHECKING:
from noteflow.config.settings import CalendarSettings
logger = logging.getLogger(__name__)
class OAuthError(Exception):
"""OAuth operation failed."""
class OAuthManager(OAuthPort):
"""OAuth manager implementing PKCE flow for Google and Outlook.
Manages OAuth authorization URLs, token exchange, refresh, and revocation.
State tokens are stored in-memory with TTL for security.
Deployment Note:
State tokens are stored in-memory (self._pending_states dict). This works
correctly for single-worker deployments (current Tauri desktop model) but
would require database-backed state storage for multi-worker/load-balanced
deployments where OAuth initiate and complete requests may hit different
workers.
"""
# OAuth endpoints
GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
GOOGLE_REVOKE_URL = "https://oauth2.googleapis.com/revoke"
OUTLOOK_AUTH_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
OUTLOOK_TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
OUTLOOK_REVOKE_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/logout"
# OAuth scopes
GOOGLE_SCOPES: ClassVar[list[str]] = [
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/userinfo.email",
"openid",
]
OUTLOOK_SCOPES: ClassVar[list[str]] = [
"Calendars.Read",
"User.Read",
"offline_access",
"openid",
]
# State TTL (10 minutes)
STATE_TTL_SECONDS = 600
def __init__(self, settings: CalendarSettings) -> None:
"""Initialize OAuth manager with calendar settings.
Args:
settings: Calendar settings with OAuth credentials.
"""
self._settings = settings
self._pending_states: dict[str, OAuthState] = {}
def initiate_auth(
self,
provider: OAuthProvider,
redirect_uri: str,
) -> tuple[str, str]:
"""Generate OAuth authorization URL with PKCE.
Args:
provider: OAuth provider (google or outlook).
redirect_uri: Callback URL after authorization.
Returns:
Tuple of (authorization_url, state_token).
Raises:
OAuthError: If provider credentials are not configured.
"""
self._cleanup_expired_states()
self._validate_provider_config(provider)
# Generate PKCE code verifier and challenge
code_verifier = self._generate_code_verifier()
code_challenge = self._generate_code_challenge(code_verifier)
# Generate state token for CSRF protection
state_token = secrets.token_urlsafe(32)
# Store state for validation during callback
now = datetime.now(UTC)
oauth_state = OAuthState(
state=state_token,
provider=provider,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
created_at=now,
expires_at=now + timedelta(seconds=self.STATE_TTL_SECONDS),
)
self._pending_states[state_token] = oauth_state
# Build authorization URL
auth_url = self._build_auth_url(
provider=provider,
redirect_uri=redirect_uri,
state=state_token,
code_challenge=code_challenge,
)
logger.info("Initiated OAuth flow for provider=%s", provider.value)
return auth_url, state_token
async def complete_auth(
self,
provider: OAuthProvider,
code: str,
state: str,
) -> OAuthTokens:
"""Exchange authorization code for tokens.
Args:
provider: OAuth provider.
code: Authorization code from callback.
state: State parameter from callback.
Returns:
OAuth tokens.
Raises:
OAuthError: If state is invalid, expired, or token exchange fails.
"""
# Validate and retrieve state
oauth_state = self._pending_states.pop(state, None)
if oauth_state is None:
raise OAuthError("Invalid or expired state token")
if oauth_state.is_expired():
raise OAuthError("State token has expired")
if oauth_state.provider != provider:
raise OAuthError(
f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
)
# Exchange code for tokens
tokens = await self._exchange_code(
provider=provider,
code=code,
redirect_uri=oauth_state.redirect_uri,
code_verifier=oauth_state.code_verifier,
)
logger.info("Completed OAuth flow for provider=%s", provider.value)
return tokens
async def refresh_tokens(
self,
provider: OAuthProvider,
refresh_token: str,
) -> OAuthTokens:
"""Refresh expired access token.
Args:
provider: OAuth provider.
refresh_token: Refresh token from previous exchange.
Returns:
New OAuth tokens.
Raises:
OAuthError: If refresh fails.
"""
token_url = self._get_token_url(provider)
client_id, client_secret = self._get_credentials(provider)
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
}
# Google requires client_secret for refresh
if provider == OAuthProvider.GOOGLE:
data["client_secret"] = client_secret
async with httpx.AsyncClient() as client:
response = await client.post(token_url, data=data)
if response.status_code != 200:
error_detail = response.text
logger.error(
"Token refresh failed for provider=%s: %s",
provider.value,
error_detail,
)
raise OAuthError(f"Token refresh failed: {error_detail}")
token_data = response.json()
tokens = self._parse_token_response(token_data, refresh_token)
logger.info("Refreshed tokens for provider=%s", provider.value)
return tokens
async def revoke_tokens(
self,
provider: OAuthProvider,
access_token: str,
) -> bool:
"""Revoke OAuth tokens with provider.
Args:
provider: OAuth provider.
access_token: Access token to revoke.
Returns:
True if revoked successfully.
"""
revoke_url = self._get_revoke_url(provider)
async with httpx.AsyncClient() as client:
if provider == OAuthProvider.GOOGLE:
response = await client.post(
revoke_url,
params={"token": access_token},
)
else:
# Microsoft identity platform has no token revocation endpoint;
# fall back to the logout endpoint with a redirect
response = await client.get(
revoke_url,
params={"post_logout_redirect_uri": self._settings.redirect_uri},
)
if response.status_code in (200, 204):
logger.info("Revoked tokens for provider=%s", provider.value)
return True
logger.warning(
"Token revocation returned status=%d for provider=%s",
response.status_code,
provider.value,
)
return False
def _validate_provider_config(self, provider: OAuthProvider) -> None:
"""Validate that provider credentials are configured."""
client_id, client_secret = self._get_credentials(provider)
if not client_id or not client_secret:
raise OAuthError(
f"OAuth credentials not configured for {provider.value}"
)
def _get_credentials(self, provider: OAuthProvider) -> tuple[str, str]:
"""Get client credentials for provider."""
if provider == OAuthProvider.GOOGLE:
return (
self._settings.google_client_id,
self._settings.google_client_secret,
)
return (
self._settings.outlook_client_id,
self._settings.outlook_client_secret,
)
def _get_auth_url(self, provider: OAuthProvider) -> str:
"""Get authorization URL for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_AUTH_URL
return self.OUTLOOK_AUTH_URL
def _get_token_url(self, provider: OAuthProvider) -> str:
"""Get token URL for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_TOKEN_URL
return self.OUTLOOK_TOKEN_URL
def _get_revoke_url(self, provider: OAuthProvider) -> str:
"""Get revoke URL for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_REVOKE_URL
return self.OUTLOOK_REVOKE_URL
def _get_scopes(self, provider: OAuthProvider) -> list[str]:
"""Get OAuth scopes for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_SCOPES
return self.OUTLOOK_SCOPES
def _build_auth_url(
self,
provider: OAuthProvider,
redirect_uri: str,
state: str,
code_challenge: str,
) -> str:
"""Build OAuth authorization URL with PKCE parameters."""
client_id, _ = self._get_credentials(provider)
scopes = self._get_scopes(provider)
base_url = self._get_auth_url(provider)
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": " ".join(scopes),
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
# Provider-specific parameters
if provider == OAuthProvider.GOOGLE:
params["access_type"] = "offline"
params["prompt"] = "consent"
elif provider == OAuthProvider.OUTLOOK:
params["response_mode"] = "query"
return f"{base_url}?{urlencode(params)}"
async def _exchange_code(
self,
provider: OAuthProvider,
code: str,
redirect_uri: str,
code_verifier: str,
) -> OAuthTokens:
"""Exchange authorization code for tokens."""
token_url = self._get_token_url(provider)
client_id, client_secret = self._get_credentials(provider)
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"client_id": client_id,
"code_verifier": code_verifier,
}
# Google requires client_secret even with PKCE
if provider == OAuthProvider.GOOGLE:
data["client_secret"] = client_secret
async with httpx.AsyncClient() as client:
response = await client.post(token_url, data=data)
if response.status_code != 200:
error_detail = response.text
logger.error(
"Token exchange failed for provider=%s: %s",
provider.value,
error_detail,
)
raise OAuthError(f"Token exchange failed: {error_detail}")
token_data = response.json()
return self._parse_token_response(token_data)
def _parse_token_response(
self,
data: dict[str, object],
existing_refresh_token: str | None = None,
) -> OAuthTokens:
"""Parse token response into OAuthTokens."""
access_token = str(data.get("access_token", ""))
if not access_token:
raise OAuthError("No access_token in response")
# Calculate expiry time
expires_in_raw = data.get("expires_in", 3600)
expires_in = int(expires_in_raw) if isinstance(expires_in_raw, (int, float, str)) else 3600
expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)
# Refresh token may not be returned on refresh
refresh_token = data.get("refresh_token")
if isinstance(refresh_token, str):
final_refresh_token: str | None = refresh_token
else:
final_refresh_token = existing_refresh_token
return OAuthTokens(
access_token=access_token,
refresh_token=final_refresh_token,
token_type=str(data.get("token_type", "Bearer")),
expires_at=expires_at,
scope=str(data.get("scope", "")),
)
@staticmethod
def _generate_code_verifier() -> str:
"""Generate a cryptographically random code verifier for PKCE."""
return secrets.token_urlsafe(64)
@staticmethod
def _generate_code_challenge(verifier: str) -> str:
"""Generate code challenge from verifier using S256 method."""
digest = hashlib.sha256(verifier.encode("ascii")).digest()
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
def _cleanup_expired_states(self) -> None:
"""Remove expired state tokens."""
now = datetime.now(UTC)
expired_keys = [
key
for key, state in self._pending_states.items()
if state.is_expired() or now > state.expires_at
]
for key in expired_keys:
del self._pending_states[key]
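
Putting the pieces together, a minimal usage sketch of the PKCE round trip. The console-input callback handling is purely illustrative, and `CalendarSettings()` is assumed to pick up client credentials from the environment:

```python
import asyncio

from noteflow.config.settings import CalendarSettings
from noteflow.domain.value_objects import OAuthProvider


async def run_flow() -> None:
    manager = OAuthManager(CalendarSettings())  # assumes env-provided credentials
    # Step 1: build the authorization URL; the PKCE verifier is stored
    # against the returned state token for use during the callback.
    auth_url, state = manager.initiate_auth(
        OAuthProvider.GOOGLE,
        redirect_uri="http://localhost:8765/oauth/callback",
    )
    print("Open in a browser:", auth_url)
    # Step 2: the provider redirects back with ?code=...&state=...;
    # a real app captures these in its callback handler.
    code = input("Paste the authorization code: ")
    tokens = await manager.complete_auth(OAuthProvider.GOOGLE, code=code, state=state)
    print("Access token valid until:", tokens.expires_at)


asyncio.run(run_flow())
```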

View File

@@ -0,0 +1,226 @@
"""Microsoft Outlook Calendar API adapter.
Implements CalendarPort for Outlook using Microsoft Graph API.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
import httpx
from noteflow.domain.ports.calendar import CalendarEventInfo, CalendarPort
logger = logging.getLogger(__name__)
class OutlookCalendarError(Exception):
"""Outlook Calendar API error."""
class OutlookCalendarAdapter(CalendarPort):
"""Microsoft Graph Calendar API adapter.
Fetches calendar events and user info using Microsoft Graph API.
Requires a valid OAuth access token with Calendars.Read scope.
"""
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
async def list_events(
self,
access_token: str,
hours_ahead: int = 24,
limit: int = 20,
) -> list[CalendarEventInfo]:
"""Fetch upcoming calendar events from Outlook Calendar.
Args:
access_token: Microsoft Graph OAuth token with Calendars.Read scope.
hours_ahead: Hours to look ahead from current time.
limit: Maximum events to return (capped by Graph API).
Returns:
List of Outlook calendar events ordered by start datetime.
Raises:
OutlookCalendarError: If API call fails.
"""
now = datetime.now(UTC)
start_time = now.strftime("%Y-%m-%dT%H:%M:%SZ")
end_time = (now + timedelta(hours=hours_ahead)).strftime("%Y-%m-%dT%H:%M:%SZ")
url = f"{self.GRAPH_API_BASE}/me/calendarView"
params: dict[str, str | int] = {
"startDateTime": start_time,
"endDateTime": end_time,
"$top": limit,
"$orderby": "start/dateTime",
"$select": (
"id,subject,start,end,location,bodyPreview,"
"attendees,isAllDay,seriesMasterId,onlineMeeting,onlineMeetingUrl"
),
}
headers = {
"Authorization": f"Bearer {access_token}",
"Prefer": 'outlook.timezone="UTC"',
}
async with httpx.AsyncClient() as client:
response = await client.get(url, params=params, headers=headers)
if response.status_code == 401:
raise OutlookCalendarError("Access token expired or invalid")
if response.status_code != 200:
error_msg = response.text
logger.error("Microsoft Graph API error: %s", error_msg)
raise OutlookCalendarError(f"API error: {error_msg}")
data = response.json()
items = data.get("value", [])
return [self._parse_event(item) for item in items]
async def get_user_email(self, access_token: str) -> str:
"""Get authenticated user's email address.
Args:
access_token: Valid OAuth access token.
Returns:
User's email address.
Raises:
OutlookCalendarError: If API call fails.
"""
url = f"{self.GRAPH_API_BASE}/me"
params = {"$select": "mail,userPrincipalName"}
headers = {"Authorization": f"Bearer {access_token}"}
async with httpx.AsyncClient() as client:
response = await client.get(url, params=params, headers=headers)
if response.status_code == 401:
raise OutlookCalendarError("Access token expired or invalid")
if response.status_code != 200:
error_msg = response.text
logger.error("Microsoft Graph API error: %s", error_msg)
raise OutlookCalendarError(f"API error: {error_msg}")
data = response.json()
# Prefer mail, fall back to userPrincipalName
email = data.get("mail") or data.get("userPrincipalName")
if not email:
raise OutlookCalendarError("No email in user profile response")
return str(email)
def _parse_event(self, item: dict[str, object]) -> CalendarEventInfo:
"""Parse Microsoft Graph event into CalendarEventInfo."""
event_id = str(item.get("id", ""))
title = str(item.get("subject", "Untitled"))
# Parse start/end times
start_data = item.get("start", {})
end_data = item.get("end", {})
start_time = self._parse_datetime(start_data)
end_time = self._parse_datetime(end_data)
# Check if all-day event
is_all_day = bool(item.get("isAllDay", False))
# Parse attendees
attendees_data = item.get("attendees", [])
attendees = self._parse_attendees(attendees_data)
# Extract meeting URL
meeting_url = self._extract_meeting_url(item)
# Check if recurring (has seriesMasterId)
is_recurring = bool(item.get("seriesMasterId"))
# Location
location_data = item.get("location", {})
location = (
str(location_data.get("displayName"))
if isinstance(location_data, dict) and location_data.get("displayName")
else None
)
# Description (bodyPreview)
description = item.get("bodyPreview")
return CalendarEventInfo(
id=event_id,
title=title,
start_time=start_time,
end_time=end_time,
attendees=attendees,
location=location,
description=str(description) if description else None,
meeting_url=meeting_url,
is_recurring=is_recurring,
is_all_day=is_all_day,
provider="outlook",
raw=dict(item) if isinstance(item, dict) else None,
)
def _parse_datetime(self, dt_data: object) -> datetime:
"""Parse datetime from Microsoft Graph format."""
if not isinstance(dt_data, dict):
return datetime.now(UTC)
dt_str = dt_data.get("dateTime")
timezone = dt_data.get("timeZone", "UTC")
if not dt_str or not isinstance(dt_str, str):
return datetime.now(UTC)
try:
# Graph API returns ISO format without timezone suffix
dt = datetime.fromisoformat(dt_str)
# If no timezone info, assume UTC (we requested UTC in Prefer header)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
return dt
except ValueError:
logger.warning("Failed to parse datetime: %s (tz: %s)", dt_str, timezone)
return datetime.now(UTC)
def _parse_attendees(self, attendees_data: object) -> tuple[str, ...]:
"""Parse attendees from Microsoft Graph format."""
if not isinstance(attendees_data, list):
return ()
emails: list[str] = []
for attendee in attendees_data:
if not isinstance(attendee, dict):
continue
email_address = attendee.get("emailAddress", {})
if isinstance(email_address, dict):
email = email_address.get("address")
if email and isinstance(email, str):
emails.append(email)
return tuple(emails)
def _extract_meeting_url(self, item: dict[str, object]) -> str | None:
"""Extract online meeting URL from event data."""
# Try onlineMeetingUrl first (Teams link)
online_url = item.get("onlineMeetingUrl")
if isinstance(online_url, str) and online_url:
return online_url
# Try onlineMeeting object
online_meeting = item.get("onlineMeeting")
if isinstance(online_meeting, dict):
join_url = online_meeting.get("joinUrl")
if isinstance(join_url, str) and join_url:
return join_url
return None
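
A small sketch of the nested attendee shape `_parse_attendees` expects from Graph; the values are illustrative:

```python
adapter = OutlookCalendarAdapter()
attendees = adapter._parse_attendees(
    [
        {"emailAddress": {"address": "carol@example.com", "name": "Carol"}},
        {"emailAddress": {"address": "dave@example.com"}},
        {"unexpected": "shape"},  # non-conforming entries are skipped silently
    ]
)
assert attendees == ("carol@example.com", "dave@example.com")
```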

View File

@@ -1,9 +1,11 @@
"""Infrastructure converters for data transformation between layers."""
from noteflow.infrastructure.converters.asr_converters import AsrConverter
from noteflow.infrastructure.converters.calendar_converters import CalendarEventConverter
from noteflow.infrastructure.converters.orm_converters import OrmConverter
__all__ = [
"AsrConverter",
"CalendarEventConverter",
"OrmConverter",
]

View File

@@ -0,0 +1,116 @@
"""Calendar data converters.
Convert between CalendarEventModel (ORM) and CalendarEventInfo (domain).
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID
from noteflow.domain.ports.calendar import CalendarEventInfo
from noteflow.infrastructure.triggers.calendar import CalendarEvent
if TYPE_CHECKING:
from noteflow.infrastructure.persistence.models.integrations.integration import (
CalendarEventModel,
)
class CalendarEventConverter:
"""Convert between CalendarEventModel and CalendarEventInfo."""
@staticmethod
def orm_to_info(model: CalendarEventModel, provider: str = "") -> CalendarEventInfo:
"""Convert ORM model to domain CalendarEventInfo.
Args:
model: SQLAlchemy CalendarEventModel instance.
provider: Provider name (google, outlook).
Returns:
Domain CalendarEventInfo value object.
"""
return CalendarEventInfo(
id=model.external_id,
title=model.title,
start_time=model.start_time,
end_time=model.end_time,
attendees=tuple(model.attendees) if model.attendees else (),
location=model.location,
description=model.description,
meeting_url=model.meeting_link,
is_recurring=False, # Not stored in current model
is_all_day=model.is_all_day,
provider=provider,
raw=dict(model.raw) if model.raw else None,
)
@staticmethod
def info_to_orm_kwargs(
event: CalendarEventInfo,
integration_id: UUID,
calendar_id: str = "primary",
calendar_name: str = "Primary",
) -> dict[str, object]:
"""Convert CalendarEventInfo to ORM model kwargs.
Args:
event: Domain CalendarEventInfo.
integration_id: Parent integration UUID.
calendar_id: Calendar identifier from provider.
calendar_name: Display name of the calendar.
Returns:
Kwargs dict for CalendarEventModel construction.
"""
return {
"integration_id": integration_id,
"external_id": event.id,
"calendar_id": calendar_id,
"calendar_name": calendar_name,
"title": event.title,
"description": event.description,
"start_time": event.start_time,
"end_time": event.end_time,
"location": event.location,
"attendees": list(event.attendees) if event.attendees else None,
"is_all_day": event.is_all_day,
"meeting_link": event.meeting_url,
"raw": event.raw or {},
}
@staticmethod
def info_to_trigger_event(event: CalendarEventInfo) -> CalendarEvent:
"""Convert CalendarEventInfo to trigger-layer CalendarEvent.
The trigger-layer CalendarEvent is a lightweight dataclass used
for signal detection in the trigger system.
Args:
event: Domain CalendarEventInfo.
Returns:
Trigger-layer CalendarEvent value object.
"""
return CalendarEvent(
start=event.start_time,
end=event.end_time,
title=event.title,
)
@staticmethod
def orm_to_trigger_event(model: CalendarEventModel) -> CalendarEvent:
"""Convert ORM model to trigger-layer CalendarEvent.
Args:
model: SQLAlchemy CalendarEventModel instance.
Returns:
Trigger-layer CalendarEvent value object.
"""
return CalendarEvent(
start=model.start_time,
end=model.end_time,
title=model.title,
)
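
A short sketch of the converter in use; the default `calendar_id`/`calendar_name` behavior follows the signature above:

```python
from datetime import UTC, datetime, timedelta
from uuid import uuid4

from noteflow.domain.ports.calendar import CalendarEventInfo

now = datetime.now(UTC)
event = CalendarEventInfo(
    id="evt_123",
    title="Standup",
    start_time=now,
    end_time=now + timedelta(minutes=15),
    attendees=("alice@example.com",),
    provider="google",
)

# Kwargs ready for CalendarEventModel(**kwargs) inside a repository.
kwargs = CalendarEventConverter.info_to_orm_kwargs(event, integration_id=uuid4())
assert kwargs["external_id"] == "evt_123"
assert kwargs["calendar_id"] == "primary"  # default from the signature

# The trigger-layer projection keeps only start/end/title.
trigger_event = CalendarEventConverter.info_to_trigger_event(event)
assert trigger_event.title == "Standup"
```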

View File

@@ -0,0 +1,68 @@
"""Integration domain ↔ ORM converters."""
from __future__ import annotations
from typing import TYPE_CHECKING
from noteflow.domain.entities.integration import (
Integration,
IntegrationStatus,
IntegrationType,
)
if TYPE_CHECKING:
from noteflow.infrastructure.persistence.models.integrations import IntegrationModel
class IntegrationConverter:
"""Convert between Integration domain objects and ORM models."""
@staticmethod
def orm_to_domain(model: IntegrationModel) -> Integration:
"""Convert ORM model to domain entity.
Args:
model: SQLAlchemy IntegrationModel instance.
Returns:
Domain Integration entity.
"""
return Integration(
id=model.id,
workspace_id=model.workspace_id,
name=model.name,
type=IntegrationType(model.type),
status=IntegrationStatus(model.status),
config=dict(model.config) if model.config else {},
last_sync=model.last_sync,
error_message=model.error_message,
created_at=model.created_at,
updated_at=model.updated_at,
)
@staticmethod
def to_orm_kwargs(entity: Integration) -> dict[str, object]:
"""Convert domain entity to ORM model kwargs.
Returns a dict of kwargs rather than instantiating IntegrationModel
directly to avoid circular imports and allow the repository to
handle ORM construction.
Args:
entity: Domain Integration.
Returns:
Kwargs dict for IntegrationModel construction.
"""
return {
"id": entity.id,
"workspace_id": entity.workspace_id,
"name": entity.name,
"type": entity.type.value,
"status": entity.status.value,
"config": entity.config,
"last_sync": entity.last_sync,
"error_message": entity.error_message,
"created_at": entity.created_at,
"updated_at": entity.updated_at,
}

View File

@@ -0,0 +1,69 @@
"""NER domain ↔ ORM converters."""
from __future__ import annotations
from typing import TYPE_CHECKING
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId
if TYPE_CHECKING:
from noteflow.infrastructure.persistence.models import NamedEntityModel
class NerConverter:
"""Convert between NamedEntity domain objects and ORM models."""
@staticmethod
def orm_to_domain(model: NamedEntityModel) -> NamedEntity:
"""Convert ORM model to domain entity.
Args:
model: SQLAlchemy NamedEntityModel instance.
Returns:
Domain NamedEntity with validated and normalized fields.
Raises:
ValueError: If model data is invalid (e.g., confidence out of range).
"""
# Validate and convert segment_ids to ensure it's a proper list
segment_ids = list(model.segment_ids) if model.segment_ids else []
# Create domain entity - validation happens in __post_init__
return NamedEntity(
id=model.id,
meeting_id=MeetingId(model.meeting_id),
text=model.text,
normalized_text=model.normalized_text,
category=EntityCategory(model.category),
segment_ids=segment_ids,
confidence=model.confidence,
is_pinned=model.is_pinned,
db_id=model.id,
)
@staticmethod
def to_orm_kwargs(entity: NamedEntity) -> dict[str, object]:
"""Convert domain entity to ORM model kwargs.
Returns a dict of kwargs rather than instantiating NamedEntityModel
directly to avoid circular imports and allow the repository to
handle ORM construction.
Args:
entity: Domain NamedEntity.
Returns:
Kwargs dict for NamedEntityModel construction.
"""
return {
"id": entity.id,
"meeting_id": entity.meeting_id,
"text": entity.text,
"normalized_text": entity.normalized_text,
"category": entity.category.value,
"segment_ids": entity.segment_ids,
"confidence": entity.confidence,
"is_pinned": entity.is_pinned,
}
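
A compact sketch of the kwargs path; `NamedEntity.create` is used the same way the NER engine below uses it:

```python
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity

entity = NamedEntity.create(
    text="Acme Corp",
    category=EntityCategory.COMPANY,
    segment_ids=[3, 7],
    confidence=0.8,
)
kwargs = NerConverter.to_orm_kwargs(entity)
assert kwargs["category"] == entity.category.value  # enum persisted as its value
assert kwargs["segment_ids"] == [3, 7]
```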

View File

@@ -6,7 +6,7 @@ Export meeting transcripts to PDF format.
from __future__ import annotations
import html
from typing import TYPE_CHECKING, Protocol
from typing import TYPE_CHECKING, Protocol, cast
from noteflow.infrastructure.export._formatting import format_datetime, format_timestamp
@@ -36,7 +36,10 @@ def _get_weasy_html() -> type[_WeasyHTMLProtocol] | None:
return None
weasyprint = importlib.import_module("weasyprint")
return weasyprint.HTML
html_class = getattr(weasyprint, "HTML", None)
if html_class is None:
return None
return cast(type[_WeasyHTMLProtocol], html_class)
def _escape(text: str) -> str:
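
The hunk replaces a direct `weasyprint.HTML` attribute access with a defensive lookup. As a standalone sketch of the pattern (names are hypothetical, not the exporter's actual protocol):

```python
import importlib
from typing import Protocol, cast


class _HTMLLike(Protocol):  # stand-in for the exporter's _WeasyHTMLProtocol
    def write_pdf(self, target: str) -> None: ...


def load_html_class() -> type[_HTMLLike] | None:
    """Return weasyprint.HTML when importable and present, else None."""
    try:
        weasyprint = importlib.import_module("weasyprint")
    except ImportError:
        return None
    html_class = getattr(weasyprint, "HTML", None)  # None instead of AttributeError
    if html_class is None:
        return None
    return cast("type[_HTMLLike]", html_class)
```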

View File

@@ -0,0 +1,5 @@
"""Named Entity Recognition infrastructure."""
from noteflow.infrastructure.ner.engine import NerEngine
__all__ = ["NerEngine"]

View File

@@ -0,0 +1,252 @@
"""NER engine implementation using spaCy.
Provides named entity extraction with lazy model loading and segment tracking.
"""
from __future__ import annotations
import asyncio
import logging
from functools import partial
from typing import TYPE_CHECKING, Final
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
if TYPE_CHECKING:
from spacy.language import Language
logger = logging.getLogger(__name__)
# Map spaCy entity types to our categories
_SPACY_CATEGORY_MAP: Final[dict[str, EntityCategory]] = {
# People
"PERSON": EntityCategory.PERSON,
# Organizations
"ORG": EntityCategory.COMPANY,
# Products and creative works
"PRODUCT": EntityCategory.PRODUCT,
"WORK_OF_ART": EntityCategory.PRODUCT,
# Locations
"GPE": EntityCategory.LOCATION, # Geo-political entity (countries, cities)
"LOC": EntityCategory.LOCATION, # Non-GPE locations (mountains, rivers)
"FAC": EntityCategory.LOCATION, # Facilities (buildings, airports)
# Dates and times
"DATE": EntityCategory.DATE,
"TIME": EntityCategory.DATE,
# Others (filtered out or mapped to OTHER)
"MONEY": EntityCategory.OTHER,
"PERCENT": EntityCategory.OTHER,
"CARDINAL": EntityCategory.OTHER,
"ORDINAL": EntityCategory.OTHER,
"QUANTITY": EntityCategory.OTHER,
"NORP": EntityCategory.OTHER, # Nationalities, religions
"EVENT": EntityCategory.OTHER,
"LAW": EntityCategory.OTHER,
"LANGUAGE": EntityCategory.OTHER,
}
# Entity types to skip (low value for meeting context)
_SKIP_ENTITY_TYPES: Final[frozenset[str]] = frozenset({
"CARDINAL",
"ORDINAL",
"QUANTITY",
"PERCENT",
"MONEY",
})
# Valid model names
VALID_SPACY_MODELS: Final[tuple[str, ...]] = (
"en_core_web_sm",
"en_core_web_md",
"en_core_web_lg",
"en_core_web_trf",
)
class NerEngine:
"""Named entity recognition engine using spaCy.
Lazy-loads the spaCy model on first use to avoid startup delay.
Implements the NerPort protocol for hexagonal architecture.
Uses chunking by segment (speaker turn/paragraph) to avoid OOM on
long transcripts while maintaining segment tracking.
"""
def __init__(self, model_name: str = "en_core_web_trf") -> None:
"""Initialize NER engine.
Args:
model_name: spaCy model to use. Defaults to en_core_web_trf
(transformer-based for higher accuracy).
"""
if model_name not in VALID_SPACY_MODELS:
raise ValueError(
f"Invalid model name: {model_name}. "
f"Valid models: {', '.join(VALID_SPACY_MODELS)}"
)
self._model_name = model_name
self._nlp: Language | None = None
def load_model(self) -> None:
"""Load the spaCy model.
Raises:
RuntimeError: If model loading fails.
"""
import spacy
logger.info("Loading spaCy model: %s", self._model_name)
try:
self._nlp = spacy.load(self._model_name)
logger.info("spaCy model loaded successfully")
except OSError as e:
msg = (
f"Failed to load spaCy model '{self._model_name}'. "
f"Run: python -m spacy download {self._model_name}"
)
raise RuntimeError(msg) from e
def _ensure_loaded(self) -> Language:
"""Ensure model is loaded, loading if necessary.
Returns:
The loaded spaCy Language model.
"""
if self._nlp is None:
self.load_model()
assert self._nlp is not None, "load_model() should set self._nlp"
return self._nlp
def is_ready(self) -> bool:
"""Check if model is loaded."""
return self._nlp is not None
def unload(self) -> None:
"""Unload the model to free memory."""
self._nlp = None
logger.info("spaCy model unloaded")
@property
def model_name(self) -> str:
"""Return the model name."""
return self._model_name
def extract(self, text: str) -> list[NamedEntity]:
"""Extract named entities from text.
Args:
text: Input text to analyze.
Returns:
List of extracted entities (deduplicated by normalized text).
"""
if not text or not text.strip():
return []
nlp = self._ensure_loaded()
doc = nlp(text)
entities: list[NamedEntity] = []
seen: set[str] = set()
for ent in doc.ents:
# Normalize for deduplication
normalized = ent.text.lower().strip()
if not normalized or normalized in seen:
continue
# Skip low-value entity types
if ent.label_ in _SKIP_ENTITY_TYPES:
continue
seen.add(normalized)
category = _SPACY_CATEGORY_MAP.get(ent.label_, EntityCategory.OTHER)
entities.append(
NamedEntity.create(
text=ent.text,
category=category,
segment_ids=[], # Filled by caller
confidence=0.8, # spaCy doesn't provide per-entity confidence
)
)
return entities
def extract_from_segments(
self,
segments: list[tuple[int, str]],
) -> list[NamedEntity]:
"""Extract entities from multiple segments with segment tracking.
Processes each segment individually (chunking by speaker turn/paragraph)
to avoid OOM on long transcripts. Entities appearing in multiple segments
are deduplicated with merged segment lists.
Args:
segments: List of (segment_id, text) tuples.
Returns:
Entities with segment_ids populated (deduplicated across segments).
"""
if not segments:
return []
# Track entities and their segment occurrences
# Key: normalized text, Value: NamedEntity
all_entities: dict[str, NamedEntity] = {}
for segment_id, text in segments:
if not text or not text.strip():
continue
segment_entities = self.extract(text)
for entity in segment_entities:
key = entity.normalized_text
if key in all_entities:
# Merge segment IDs
all_entities[key].merge_segments([segment_id])
else:
# New entity - set its segment_ids
entity.segment_ids = [segment_id]
all_entities[key] = entity
return list(all_entities.values())
async def extract_async(self, text: str) -> list[NamedEntity]:
"""Extract entities asynchronously using executor.
Offloads blocking extraction to a thread pool executor to avoid
blocking the asyncio event loop.
Args:
text: Input text to analyze.
Returns:
List of extracted entities.
"""
return await self._run_in_executor(partial(self.extract, text))
async def extract_from_segments_async(
self,
segments: list[tuple[int, str]],
) -> list[NamedEntity]:
"""Extract entities from segments asynchronously.
Args:
segments: List of (segment_id, text) tuples.
Returns:
Entities with segment_ids populated.
"""
return await self._run_in_executor(partial(self.extract_from_segments, segments))
async def _run_in_executor(
self,
func: partial[list[NamedEntity]],
) -> list[NamedEntity]:
"""Run sync extraction in thread pool executor."""
return await asyncio.get_running_loop().run_in_executor(None, func)
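
A usage sketch; `en_core_web_sm` keeps the example light and must be downloaded first (`python -m spacy download en_core_web_sm`):

```python
import asyncio

engine = NerEngine(model_name="en_core_web_sm")

segments = [
    (0, "Alice from Acme Corp joined the call."),
    (1, "Acme Corp will ship the beta in March."),
]

# "Acme Corp" appears in both segments, so it is deduplicated with
# merged segment_ids rather than returned twice.
for entity in engine.extract_from_segments(segments):
    print(entity.category, entity.text, entity.segment_ids)

# The async variant offloads the same work to the default executor.
entities = asyncio.run(engine.extract_from_segments_async(segments))
```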

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import asyncpg
from sqlalchemy import text
@@ -27,14 +27,40 @@ logger = logging.getLogger(__name__)
_original_asyncpg_connect = asyncpg.connect
async def _patched_asyncpg_connect(*args: Any, **kwargs: Any) -> Any:
async def _patched_asyncpg_connect(
dsn: str | None = None,
*,
host: str | None = None,
port: int | None = None,
user: str | None = None,
password: str | None = None,
passfile: str | None = None,
database: str | None = None,
timeout: float = 60,
ssl: object | None = None,
direct_tls: bool = False,
**kwargs: object,
) -> asyncpg.Connection:
"""Patched asyncpg.connect that filters out unsupported 'schema' parameter.
SQLAlchemy may try to pass 'schema' to asyncpg.connect(), but asyncpg
doesn't support this parameter. This wrapper filters it out.
"""
kwargs.pop("schema", None)
return await _original_asyncpg_connect(*args, **kwargs)
# Remove schema if present - SQLAlchemy passes this but asyncpg doesn't support it
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "schema"}
return await _original_asyncpg_connect(
dsn,
host=host,
port=port,
user=user,
password=password,
passfile=passfile,
database=database,
timeout=timeout,
ssl=ssl,
direct_tls=direct_tls,
**filtered_kwargs,
)
# Patch asyncpg.connect to filter out schema parameter
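
The hunk ends at the comment announcing the monkeypatch; the installation itself is one assignment (presumed to match the elided line):

```python
# Presumed installation. After this, SQLAlchemy's asyncpg dialect can
# pass connect_args containing "schema" without tripping asyncpg's
# keyword validation.
asyncpg.connect = _patched_asyncpg_connect  # type: ignore[assignment]
```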

View File

@@ -14,7 +14,11 @@ from noteflow.domain.entities import Meeting, Segment, Summary
from noteflow.domain.value_objects import MeetingState
if TYPE_CHECKING:
from uuid import UUID
from noteflow.domain.entities import Annotation
from noteflow.domain.entities.integration import Integration
from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.domain.value_objects import AnnotationId, MeetingId
from noteflow.grpc.meeting_store import MeetingStore
from noteflow.infrastructure.persistence.repositories import (
@@ -97,7 +101,7 @@ class MemorySegmentRepository:
include_words: bool = True,
) -> Sequence[Segment]:
"""Get all segments for a meeting."""
return self._store.get_segments(str(meeting_id))
return self._store.fetch_segments(str(meeting_id))
async def search_semantic(
self,
@@ -127,9 +131,14 @@ class MemorySegmentRepository:
This method exists for interface compatibility.
"""
async def get_next_segment_id(self, meeting_id: MeetingId) -> int:
"""Get next segment ID for a meeting."""
return self._store.get_next_segment_id(str(meeting_id))
async def compute_next_segment_id(self, meeting_id: MeetingId) -> int:
"""Compute next available segment ID for a meeting.
Returns:
Next sequential segment ID (0 if meeting has no segments).
"""
meeting_id_str = str(meeting_id)
return self._store.compute_next_segment_id(meeting_id_str)
class MemorySummaryRepository:
@@ -262,3 +271,109 @@ class UnsupportedPreferencesRepository:
async def delete(self, key: str) -> bool:
"""Not supported in memory mode."""
raise NotImplementedError("Preferences require database persistence")
class UnsupportedEntityRepository:
"""Entity repository that raises for unsupported operations.
Used in memory mode where NER entities require database persistence.
"""
async def save(self, entity: NamedEntity) -> NamedEntity:
"""Not supported in memory mode."""
raise NotImplementedError("NER entities require database persistence")
async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
"""Not supported in memory mode."""
raise NotImplementedError("NER entities require database persistence")
async def get(self, entity_id: UUID) -> NamedEntity | None:
"""Not supported in memory mode."""
raise NotImplementedError("NER entities require database persistence")
async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
"""Not supported in memory mode."""
raise NotImplementedError("NER entities require database persistence")
async def delete_by_meeting(self, meeting_id: MeetingId) -> int:
"""Not supported in memory mode."""
raise NotImplementedError("NER entities require database persistence")
async def update_pinned(self, entity_id: UUID, is_pinned: bool) -> bool:
"""Not supported in memory mode."""
raise NotImplementedError("NER entities require database persistence")
async def exists_for_meeting(self, meeting_id: MeetingId) -> bool:
"""Not supported in memory mode."""
raise NotImplementedError("NER entities require database persistence")
class InMemoryIntegrationRepository:
"""In-memory integration repository for testing.
Provides a functional integration repository backed by dictionaries
for use in tests and memory mode.
"""
def __init__(self) -> None:
"""Initialize with empty storage."""
self._integrations: dict[UUID, Integration] = {}
self._secrets: dict[UUID, dict[str, str]] = {}
async def get(self, integration_id: UUID) -> Integration | None:
"""Retrieve an integration by ID."""
return self._integrations.get(integration_id)
async def get_by_provider(
self,
provider: str,
integration_type: str | None = None,
) -> Integration | None:
"""Retrieve an integration by provider name."""
for integration in self._integrations.values():
provider_match = (
integration.config.get("provider") == provider
or provider.lower() in integration.name.lower()
)
type_match = integration_type is None or integration.type.value == integration_type
if provider_match and type_match:
return integration
return None
async def create(self, integration: Integration) -> Integration:
"""Persist a new integration."""
self._integrations[integration.id] = integration
return integration
async def update(self, integration: Integration) -> Integration:
"""Update an existing integration."""
if integration.id not in self._integrations:
msg = f"Integration {integration.id} not found"
raise ValueError(msg)
self._integrations[integration.id] = integration
return integration
async def delete(self, integration_id: UUID) -> bool:
"""Delete an integration and its secrets."""
if integration_id not in self._integrations:
return False
del self._integrations[integration_id]
self._secrets.pop(integration_id, None)
return True
async def get_secrets(self, integration_id: UUID) -> dict[str, str] | None:
"""Get secrets for an integration."""
if integration_id not in self._integrations:
return None
return self._secrets.get(integration_id, {})
async def set_secrets(self, integration_id: UUID, secrets: dict[str, str]) -> None:
"""Store secrets for an integration."""
self._secrets[integration_id] = secrets
async def list_by_type(self, integration_type: str) -> Sequence[Integration]:
"""List integrations by type."""
return [
i for i in self._integrations.values()
if i.type.value == integration_type
]
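
A test-style sketch of the in-memory repository. The `IntegrationType.CALENDAR` / `IntegrationStatus.CONNECTED` members and the keyword-only `Integration` construction are assumptions based on how the entity is used elsewhere in this commit:

```python
import asyncio
from datetime import UTC, datetime
from uuid import uuid4

from noteflow.domain.entities.integration import (
    Integration,
    IntegrationStatus,
    IntegrationType,
)


async def demo() -> None:
    repo = InMemoryIntegrationRepository()
    integration = Integration(  # field list mirrors IntegrationConverter
        id=uuid4(),
        workspace_id=uuid4(),
        name="Google Calendar",
        type=IntegrationType.CALENDAR,       # assumed member
        status=IntegrationStatus.CONNECTED,  # assumed member
        config={"provider": "google"},
        last_sync=None,
        error_message=None,
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    await repo.create(integration)
    await repo.set_secrets(integration.id, {"access_token": "tok"})

    found = await repo.get_by_provider("google")
    assert found is integration
    assert await repo.get_secrets(integration.id) == {"access_token": "tok"}
    assert await repo.delete(integration.id)


asyncio.run(demo())
```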

View File

@@ -12,6 +12,8 @@ from typing import TYPE_CHECKING, Self
from noteflow.domain.ports.repositories import (
AnnotationRepository,
DiarizationJobRepository,
EntityRepository,
IntegrationRepository,
MeetingRepository,
PreferencesRepository,
SegmentRepository,
@@ -19,11 +21,13 @@ from noteflow.domain.ports.repositories import (
)
from .repositories import (
InMemoryIntegrationRepository,
MemoryMeetingRepository,
MemorySegmentRepository,
MemorySummaryRepository,
UnsupportedAnnotationRepository,
UnsupportedDiarizationJobRepository,
UnsupportedEntityRepository,
UnsupportedPreferencesRepository,
)
@@ -60,6 +64,8 @@ class MemoryUnitOfWork:
self._annotations = UnsupportedAnnotationRepository()
self._diarization_jobs = UnsupportedDiarizationJobRepository()
self._preferences = UnsupportedPreferencesRepository()
self._entities = UnsupportedEntityRepository()
self._integrations = InMemoryIntegrationRepository()
# Core repositories
@property
@@ -93,6 +99,16 @@ class MemoryUnitOfWork:
"""Get preferences repository (unsupported)."""
return self._preferences
@property
def entities(self) -> EntityRepository:
"""Get entities repository (unsupported)."""
return self._entities
@property
def integrations(self) -> IntegrationRepository:
"""Get integrations repository."""
return self._integrations
# Feature flags - limited in memory mode
@property
def supports_annotations(self) -> bool:
@@ -109,6 +125,16 @@ class MemoryUnitOfWork:
"""User preferences not supported in memory mode."""
return False
@property
def supports_entities(self) -> bool:
"""Entity extraction not supported in memory mode."""
return False
@property
def supports_integrations(self) -> bool:
"""Integration persistence supported in memory mode."""
return True
async def __aenter__(self) -> Self:
"""Enter the unit of work context.

View File

@@ -7,6 +7,8 @@ from .diarization_job_repo import (
SqlAlchemyDiarizationJobRepository,
StreamingTurn,
)
from .entity_repo import SqlAlchemyEntityRepository
from .integration_repo import SqlAlchemyIntegrationRepository
from .meeting_repo import SqlAlchemyMeetingRepository
from .preferences_repo import SqlAlchemyPreferencesRepository
from .segment_repo import SqlAlchemySegmentRepository
@@ -17,6 +19,8 @@ __all__ = [
"DiarizationJob",
"SqlAlchemyAnnotationRepository",
"SqlAlchemyDiarizationJobRepository",
"SqlAlchemyEntityRepository",
"SqlAlchemyIntegrationRepository",
"SqlAlchemyMeetingRepository",
"SqlAlchemyPreferencesRepository",
"SqlAlchemySegmentRepository",

View File

@@ -0,0 +1,140 @@
"""SQLAlchemy implementation of entity repository for named entities."""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import delete, select
from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.infrastructure.converters.ner_converters import NerConverter
from noteflow.infrastructure.persistence.models import NamedEntityModel
from noteflow.infrastructure.persistence.repositories._base import BaseRepository
if TYPE_CHECKING:
from noteflow.domain.value_objects import MeetingId
class SqlAlchemyEntityRepository(BaseRepository):
"""SQLAlchemy implementation of entity repository for NER results."""
async def save(self, entity: NamedEntity) -> NamedEntity:
"""Save or update a named entity.
Uses merge to handle both insert and update cases.
Args:
entity: The entity to save.
Returns:
Saved entity with db_id populated.
"""
kwargs = NerConverter.to_orm_kwargs(entity)
model = NamedEntityModel(**kwargs)
merged = await self._session.merge(model)
await self._session.flush()
entity.db_id = merged.id
return entity
async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
"""Save multiple entities efficiently.
Uses individual merges to handle upsert semantics with the unique
constraint on (meeting_id, normalized_text).
Args:
entities: List of entities to save.
Returns:
Saved entities with db_ids populated.
"""
for entity in entities:
kwargs = NerConverter.to_orm_kwargs(entity)
model = NamedEntityModel(**kwargs)
merged = await self._session.merge(model)
entity.db_id = merged.id
await self._session.flush()
return entities
async def get(self, entity_id: UUID) -> NamedEntity | None:
"""Get entity by ID.
Args:
entity_id: The entity UUID.
Returns:
Entity if found, None otherwise.
"""
stmt = select(NamedEntityModel).where(NamedEntityModel.id == entity_id)
model = await self._execute_scalar(stmt)
return NerConverter.orm_to_domain(model) if model else None
async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
"""Get all entities for a meeting.
Args:
meeting_id: The meeting UUID.
Returns:
List of entities ordered by category and text.
"""
stmt = (
select(NamedEntityModel)
.where(NamedEntityModel.meeting_id == UUID(str(meeting_id)))
.order_by(NamedEntityModel.category, NamedEntityModel.text)
)
models = await self._execute_scalars(stmt)
return [NerConverter.orm_to_domain(m) for m in models]
async def delete_by_meeting(self, meeting_id: MeetingId) -> int:
"""Delete all entities for a meeting.
Args:
meeting_id: The meeting UUID.
Returns:
Number of deleted entities.
"""
stmt = delete(NamedEntityModel).where(
NamedEntityModel.meeting_id == UUID(str(meeting_id))
)
return await self._execute_delete(stmt)
async def update_pinned(self, entity_id: UUID, is_pinned: bool) -> bool:
"""Update the pinned status of an entity.
Args:
entity_id: The entity UUID.
is_pinned: New pinned status.
Returns:
True if entity was found and updated.
"""
stmt = select(NamedEntityModel).where(NamedEntityModel.id == entity_id)
model = await self._execute_scalar(stmt)
if model is None:
return False
model.is_pinned = is_pinned
await self._session.flush()
return True
async def exists_for_meeting(self, meeting_id: MeetingId) -> bool:
"""Check if any entities exist for a meeting.
More efficient than fetching all entities when only checking existence.
Args:
meeting_id: The meeting UUID.
Returns:
True if at least one entity exists.
"""
stmt = select(NamedEntityModel).where(
NamedEntityModel.meeting_id == UUID(str(meeting_id))
)
return await self._execute_exists(stmt)
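
A sketch of a typical write path through this repository, assuming the unit of work exposes `commit()`:

```python
async def replace_meeting_entities(uow_factory, meeting_id, entities) -> None:
    """Idempotently replace a meeting's NER results (sketch)."""
    async with uow_factory() as uow:
        await uow.entities.delete_by_meeting(meeting_id)  # clear the prior run
        await uow.entities.save_batch(entities)           # merge-based upserts
        await uow.commit()  # assumes the UoW exposes commit()
```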

View File

@@ -0,0 +1,215 @@
"""SQLAlchemy repository for Integration entities."""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import delete, select
from noteflow.domain.entities.integration import Integration
from noteflow.infrastructure.converters.integration_converters import IntegrationConverter
from noteflow.infrastructure.persistence.models.integrations import (
IntegrationModel,
IntegrationSecretModel,
)
from noteflow.infrastructure.persistence.repositories._base import BaseRepository
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
class SqlAlchemyIntegrationRepository(BaseRepository):
"""SQLAlchemy implementation of IntegrationRepository.
Manages external service integrations and their encrypted secrets.
"""
def __init__(self, session: AsyncSession) -> None:
"""Initialize repository with database session.
Args:
session: SQLAlchemy async session.
"""
super().__init__(session)
async def get(self, integration_id: UUID) -> Integration | None:
"""Retrieve an integration by ID.
Args:
integration_id: Integration UUID.
Returns:
Integration if found, None otherwise.
"""
stmt = select(IntegrationModel).where(IntegrationModel.id == integration_id)
model = await self._execute_scalar(stmt)
return IntegrationConverter.orm_to_domain(model) if model else None
async def get_by_provider(
self,
provider: str,
integration_type: str | None = None,
) -> Integration | None:
"""Retrieve an integration by provider name.
Args:
provider: Provider name (stored in config['provider'] or name).
integration_type: Optional type filter.
Returns:
Integration if found, None otherwise.
"""
stmt = select(IntegrationModel).where(
IntegrationModel.config["provider"].astext == provider,
)
if integration_type:
stmt = stmt.where(IntegrationModel.type == integration_type)
model = await self._execute_scalar(stmt)
if model:
return IntegrationConverter.orm_to_domain(model)
# Fallback: check by name if provider not in config
fallback_stmt = select(IntegrationModel).where(
IntegrationModel.name.ilike(f"%{provider}%"),
)
if integration_type:
fallback_stmt = fallback_stmt.where(IntegrationModel.type == integration_type)
fallback_model = await self._execute_scalar(fallback_stmt)
return IntegrationConverter.orm_to_domain(fallback_model) if fallback_model else None
async def create(self, integration: Integration) -> Integration:
"""Persist a new integration.
Args:
integration: Integration to create.
Returns:
Created integration.
"""
kwargs = IntegrationConverter.to_orm_kwargs(integration)
model = IntegrationModel(**kwargs)
await self._add_and_flush(model)
return IntegrationConverter.orm_to_domain(model)
async def update(self, integration: Integration) -> Integration:
"""Update an existing integration.
Args:
integration: Integration with updated fields.
Returns:
Updated integration.
Raises:
ValueError: If integration does not exist.
"""
stmt = select(IntegrationModel).where(IntegrationModel.id == integration.id)
model = await self._execute_scalar(stmt)
if not model:
msg = f"Integration {integration.id} not found"
raise ValueError(msg)
# Update fields
model.name = integration.name
model.type = integration.type.value
model.status = integration.status.value
model.config = integration.config
model.last_sync = integration.last_sync
model.error_message = integration.error_message
model.updated_at = integration.updated_at
await self._session.flush()
return IntegrationConverter.orm_to_domain(model)
async def delete(self, integration_id: UUID) -> bool:
"""Delete an integration and its secrets.
Args:
integration_id: Integration UUID.
Returns:
True if deleted, False if not found.
"""
stmt = select(IntegrationModel).where(IntegrationModel.id == integration_id)
model = await self._execute_scalar(stmt)
if not model:
return False
await self._delete_and_flush(model)
return True
async def get_secrets(self, integration_id: UUID) -> dict[str, str] | None:
"""Get secrets for an integration.
Note: Secrets are stored as bytes. The caller is responsible for
encryption/decryption. This method decodes bytes to UTF-8 strings.
Args:
integration_id: Integration UUID.
Returns:
Dictionary of secret key-value pairs, or None if integration not found.
"""
# Verify integration exists
exists_stmt = select(IntegrationModel.id).where(
IntegrationModel.id == integration_id,
)
if not await self._execute_exists(exists_stmt):
return None
# Fetch secrets
stmt = select(IntegrationSecretModel).where(
IntegrationSecretModel.integration_id == integration_id,
)
models = await self._execute_scalars(stmt)
return {model.secret_key: model.secret_value.decode("utf-8") for model in models}
async def set_secrets(self, integration_id: UUID, secrets: dict[str, str]) -> None:
"""Store secrets for an integration.
Note: Secrets are stored as bytes. The caller is responsible for
encryption/decryption. This method encodes strings to UTF-8 bytes.
Args:
integration_id: Integration UUID.
secrets: Dictionary of secret key-value pairs.
"""
# Delete existing secrets
delete_stmt = delete(IntegrationSecretModel).where(
IntegrationSecretModel.integration_id == integration_id,
)
await self._session.execute(delete_stmt)
# Insert new secrets
if secrets:
models = [
IntegrationSecretModel(
integration_id=integration_id,
secret_key=key,
secret_value=value.encode("utf-8"),
)
for key, value in secrets.items()
]
await self._add_all_and_flush(models)
async def list_by_type(self, integration_type: str) -> Sequence[Integration]:
"""List integrations by type.
Args:
integration_type: Integration type (e.g., 'calendar', 'email').
Returns:
List of integrations of the specified type.
"""
stmt = (
select(IntegrationModel)
.where(IntegrationModel.type == integration_type)
.order_by(IntegrationModel.created_at.desc())
)
models = await self._execute_scalars(stmt)
return [IntegrationConverter.orm_to_domain(m) for m in models]

View File

@@ -207,8 +207,8 @@ class SqlAlchemySegmentRepository(BaseRepository):
model.speaker_confidence = speaker_confidence
await self._session.flush()
async def get_next_segment_id(self, meeting_id: MeetingId) -> int:
"""Get the next segment_id for a meeting.
async def compute_next_segment_id(self, meeting_id: MeetingId) -> int:
"""Compute the next segment_id for a meeting.
Args:
meeting_id: Meeting identifier.

View File

@@ -16,6 +16,8 @@ from noteflow.infrastructure.persistence.database import (
from .repositories import (
SqlAlchemyAnnotationRepository,
SqlAlchemyDiarizationJobRepository,
SqlAlchemyEntityRepository,
SqlAlchemyIntegrationRepository,
SqlAlchemyMeetingRepository,
SqlAlchemyPreferencesRepository,
SqlAlchemySegmentRepository,
@@ -46,6 +48,8 @@ class SqlAlchemyUnitOfWork:
self._session: AsyncSession | None = None
self._annotations_repo: SqlAlchemyAnnotationRepository | None = None
self._diarization_jobs_repo: SqlAlchemyDiarizationJobRepository | None = None
self._entities_repo: SqlAlchemyEntityRepository | None = None
self._integrations_repo: SqlAlchemyIntegrationRepository | None = None
self._meetings_repo: SqlAlchemyMeetingRepository | None = None
self._preferences_repo: SqlAlchemyPreferencesRepository | None = None
self._segments_repo: SqlAlchemySegmentRepository | None = None
@@ -98,6 +102,20 @@ class SqlAlchemyUnitOfWork:
raise RuntimeError("UnitOfWork not in context")
return self._diarization_jobs_repo
@property
def entities(self) -> SqlAlchemyEntityRepository:
"""Get entities repository for NER results."""
if self._entities_repo is None:
raise RuntimeError("UnitOfWork not in context")
return self._entities_repo
@property
def integrations(self) -> SqlAlchemyIntegrationRepository:
"""Get integrations repository for OAuth connections."""
if self._integrations_repo is None:
raise RuntimeError("UnitOfWork not in context")
return self._integrations_repo
@property
def meetings(self) -> SqlAlchemyMeetingRepository:
"""Get meetings repository."""
@@ -142,6 +160,16 @@ class SqlAlchemyUnitOfWork:
"""User preferences persistence is fully supported with database."""
return True
@property
def supports_entities(self) -> bool:
"""NER entity persistence is fully supported with database."""
return True
@property
def supports_integrations(self) -> bool:
"""OAuth integration persistence is fully supported with database."""
return True
async def __aenter__(self) -> Self:
"""Enter the unit of work context.
@@ -153,6 +181,8 @@ class SqlAlchemyUnitOfWork:
self._session = self._session_factory()
self._annotations_repo = SqlAlchemyAnnotationRepository(self._session)
self._diarization_jobs_repo = SqlAlchemyDiarizationJobRepository(self._session)
self._entities_repo = SqlAlchemyEntityRepository(self._session)
self._integrations_repo = SqlAlchemyIntegrationRepository(self._session)
self._meetings_repo = SqlAlchemyMeetingRepository(self._session)
self._preferences_repo = SqlAlchemyPreferencesRepository(self._session)
self._segments_repo = SqlAlchemySegmentRepository(self._session)
@@ -184,6 +214,8 @@ class SqlAlchemyUnitOfWork:
self._session = None
self._annotations_repo = None
self._diarization_jobs_repo = None
self._entities_repo = None
self._integrations_repo = None
self._meetings_repo = None
self._preferences_repo = None
self._segments_repo = None

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, TypedDict, cast
from noteflow.domain.entities import ActionItem, KeyPoint, Summary
from noteflow.domain.summarization import InvalidResponseError
@@ -18,6 +18,30 @@ if TYPE_CHECKING:
_logger = logging.getLogger(__name__)
class _KeyPointData(TypedDict, total=False):
"""Expected structure for key point data from LLM response."""
text: str
segment_ids: list[int]
class _ActionItemData(TypedDict, total=False):
"""Expected structure for action item data from LLM response."""
text: str
assignee: str
priority: int
segment_ids: list[int]
class _LLMResponseData(TypedDict, total=False):
"""Expected structure for LLM response JSON."""
executive_summary: str
key_points: list[_KeyPointData]
action_items: list[_ActionItemData]
_TONE_INSTRUCTIONS: dict[str, str] = {
"professional": "Use formal, business-appropriate language.",
"casual": "Use conversational, approachable language.",
@@ -153,7 +177,7 @@ def _strip_markdown_fences(text: str) -> str:
def _parse_key_point(
data: dict[str, Any],
data: _KeyPointData,
valid_ids: set[int],
segments: Sequence[Segment],
) -> KeyPoint:
@@ -181,7 +205,7 @@ def _parse_key_point(
)
def _parse_action_item(data: dict[str, Any], valid_ids: set[int]) -> ActionItem:
def _parse_action_item(data: _ActionItemData, valid_ids: set[int]) -> ActionItem:
"""Parse a single action item from LLM response data.
Args:
@@ -219,7 +243,7 @@ def parse_llm_response(response_text: str, request: SummarizationRequest) -> Sum
text = _strip_markdown_fences(response_text)
try:
data = json.loads(text)
data = cast(_LLMResponseData, json.loads(text))
except json.JSONDecodeError as e:
raise InvalidResponseError(f"Invalid JSON response: {e}") from e
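
A standalone sketch of the typed-parsing pattern this hunk introduces; note that `cast` changes nothing at runtime:

```python
import json
from typing import TypedDict, cast


class _ResponseData(TypedDict, total=False):
    executive_summary: str


raw = '{"executive_summary": "Team aligned on Q1 scope."}'
data = cast(_ResponseData, json.loads(raw))
# cast() only informs the type checker; fields must still be
# validated at runtime, as the _parse_* helpers above do.
print(data.get("executive_summary", ""))
```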

View File

@@ -88,7 +88,7 @@ class CalendarProvider:
return event_start <= window_end and event_end >= window_start
def parse_calendar_events(raw_events: object) -> list[CalendarEvent]:
def parse_calendar_event_config(raw_events: object) -> list[CalendarEvent]:
"""Parse calendar events from config/env payloads."""
if raw_events is None:
return []
@@ -108,8 +108,8 @@ def parse_calendar_events(raw_events: object) -> list[CalendarEvent]:
events.append(item)
continue
if isinstance(item, dict):
start = _parse_datetime(item.get("start"))
end = _parse_datetime(item.get("end"))
start = _parse_event_datetime(item.get("start"))
end = _parse_event_datetime(item.get("end"))
if start and end:
events.append(CalendarEvent(start=start, end=end, title=item.get("title")))
return events
@@ -126,7 +126,7 @@ def _load_events_from_json(raw: str) -> list[dict[str, object]]:
return [parsed] if isinstance(parsed, dict) else []
def _parse_datetime(value: object) -> datetime | None:
def _parse_event_datetime(value: object) -> datetime | None:
if isinstance(value, datetime):
return value
if not isinstance(value, str) or not value:
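
A quick sketch of the renamed config parser with `datetime` inputs (the string-parsing branch is elided above, so the example sticks to `datetime` values):

```python
from datetime import UTC, datetime, timedelta

now = datetime.now(UTC)
events = parse_calendar_event_config(
    [{"start": now, "end": now + timedelta(minutes=30), "title": "Standup"}]
)
assert events and events[0].title == "Standup"
```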

View File

@@ -0,0 +1,456 @@
"""Tests for calendar service."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from noteflow.domain.entities import Integration, IntegrationStatus, IntegrationType
from noteflow.domain.ports.calendar import CalendarEventInfo, OAuthConnectionInfo
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
if TYPE_CHECKING:
from noteflow.config.settings import CalendarSettings
@pytest.fixture
def calendar_settings() -> CalendarSettings:
"""Create test calendar settings."""
from noteflow.config.settings import CalendarSettings
return CalendarSettings(
google_client_id="test-google-client-id",
google_client_secret="test-google-client-secret",
outlook_client_id="test-outlook-client-id",
outlook_client_secret="test-outlook-client-secret",
)
@pytest.fixture
def mock_oauth_manager() -> MagicMock:
"""Create mock OAuth manager."""
manager = MagicMock()
manager.initiate_auth.return_value = ("https://auth.example.com", "state-123")
manager.complete_auth = AsyncMock(
return_value=OAuthTokens(
access_token="access-token",
refresh_token="refresh-token",
token_type="Bearer",
expires_at=datetime.now(UTC) + timedelta(hours=1),
scope="email calendar",
)
)
manager.refresh_tokens = AsyncMock(
return_value=OAuthTokens(
access_token="new-access-token",
refresh_token="refresh-token",
token_type="Bearer",
expires_at=datetime.now(UTC) + timedelta(hours=1),
scope="email calendar",
)
)
manager.revoke_tokens = AsyncMock()
return manager
@pytest.fixture
def mock_google_adapter() -> MagicMock:
"""Create mock Google Calendar adapter."""
adapter = MagicMock()
adapter.list_events = AsyncMock(
return_value=[
CalendarEventInfo(
id="event-1",
title="Test Meeting",
start_time=datetime.now(UTC) + timedelta(hours=1),
end_time=datetime.now(UTC) + timedelta(hours=2),
attendees=("alice@example.com",),
provider="google",
)
]
)
adapter.get_user_email = AsyncMock(return_value="user@gmail.com")
return adapter
@pytest.fixture
def mock_outlook_adapter() -> MagicMock:
"""Create mock Outlook Calendar adapter."""
adapter = MagicMock()
adapter.list_events = AsyncMock(return_value=[])
adapter.get_user_email = AsyncMock(return_value="user@outlook.com")
return adapter
@pytest.fixture
def mock_uow() -> MagicMock:
"""Create mock unit of work."""
uow = MagicMock()
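# MagicMock does not implement the async context manager protocol by default, so wire it up: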
uow.__aenter__ = AsyncMock(return_value=uow)
uow.__aexit__ = AsyncMock(return_value=None)
uow.integrations = MagicMock()
uow.integrations.get_by_type_and_provider = AsyncMock(return_value=None)
uow.integrations.add = AsyncMock()
uow.integrations.get_secrets = AsyncMock(return_value=None)
uow.integrations.set_secrets = AsyncMock()
uow.commit = AsyncMock()
return uow
class TestCalendarServiceInitiateOAuth:
"""Tests for CalendarService.initiate_oauth."""
@pytest.mark.asyncio
async def test_initiate_oauth_returns_auth_url_and_state(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""initiate_oauth should return auth URL and state."""
from noteflow.application.services.calendar_service import CalendarService
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
auth_url, state = await service.initiate_oauth("google", "http://localhost/callback")
assert auth_url == "https://auth.example.com"
assert state == "state-123"
mock_oauth_manager.initiate_auth.assert_called_once()
class TestCalendarServiceCompleteOAuth:
"""Tests for CalendarService.complete_oauth."""
@pytest.mark.asyncio
async def test_complete_oauth_stores_tokens(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""complete_oauth should store tokens in integration secrets."""
from noteflow.application.services.calendar_service import CalendarService
mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
mock_uow.integrations.create = AsyncMock()
mock_uow.integrations.update = AsyncMock()
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
result = await service.complete_oauth("google", "auth-code", "state-123")
assert result is True
mock_oauth_manager.complete_auth.assert_called_once()
mock_uow.integrations.set_secrets.assert_called_once()
mock_uow.commit.assert_called()
@pytest.mark.asyncio
async def test_complete_oauth_creates_integration_if_not_exists(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""complete_oauth should create new integration if none exists."""
from noteflow.application.services.calendar_service import CalendarService
mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
mock_uow.integrations.create = AsyncMock()
mock_uow.integrations.update = AsyncMock()
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
await service.complete_oauth("google", "auth-code", "state-123")
mock_uow.integrations.create.assert_called_once()
@pytest.mark.asyncio
async def test_complete_oauth_updates_existing_integration(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""complete_oauth should update existing integration."""
from noteflow.application.services.calendar_service import CalendarService
existing_integration = Integration.create(
workspace_id=uuid4(),
name="Google Calendar",
integration_type=IntegrationType.CALENDAR,
config={"provider": "google"},
)
mock_uow.integrations.get_by_provider = AsyncMock(return_value=existing_integration)
mock_uow.integrations.create = AsyncMock()
mock_uow.integrations.update = AsyncMock()
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
await service.complete_oauth("google", "auth-code", "state-123")
mock_uow.integrations.create.assert_not_called()
assert existing_integration.status == IntegrationStatus.CONNECTED
class TestCalendarServiceGetConnectionStatus:
"""Tests for CalendarService.get_connection_status."""
@pytest.mark.asyncio
async def test_get_connection_status_returns_connected_info(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""get_connection_status should return connection info for connected provider."""
from noteflow.application.services.calendar_service import CalendarService
integration = Integration.create(
workspace_id=uuid4(),
name="Google Calendar",
integration_type=IntegrationType.CALENDAR,
config={"provider": "google"},
)
integration.connect(provider_email="user@gmail.com")
mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
mock_uow.integrations.get_secrets = AsyncMock(return_value={
"access_token": "token",
"refresh_token": "refresh",
"token_type": "Bearer",
"expires_at": (datetime.now(UTC) + timedelta(hours=1)).isoformat(),
"scope": "calendar",
})
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
status = await service.get_connection_status("google")
assert status.status == "connected"
assert status.provider == "google"
@pytest.mark.asyncio
async def test_get_connection_status_returns_disconnected_when_no_integration(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""get_connection_status should return disconnected when no integration."""
from noteflow.application.services.calendar_service import CalendarService
mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
status = await service.get_connection_status("google")
assert status.status == "disconnected"
class TestCalendarServiceDisconnect:
"""Tests for CalendarService.disconnect."""
@pytest.mark.asyncio
async def test_disconnect_revokes_tokens_and_deletes_integration(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""disconnect should revoke tokens and delete integration."""
from noteflow.application.services.calendar_service import CalendarService
integration = Integration.create(
workspace_id=uuid4(),
name="Google Calendar",
integration_type=IntegrationType.CALENDAR,
config={"provider": "google"},
)
integration.connect(provider_email="user@gmail.com")
mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
mock_uow.integrations.get_secrets = AsyncMock(return_value={"access_token": "token"})
mock_uow.integrations.delete = AsyncMock()
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
result = await service.disconnect("google")
assert result is True
mock_oauth_manager.revoke_tokens.assert_called_once()
mock_uow.integrations.delete.assert_called_once()
mock_uow.commit.assert_called()
class TestCalendarServiceListEvents:
"""Tests for CalendarService.list_calendar_events."""
@pytest.mark.asyncio
async def test_list_events_fetches_from_connected_provider(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""list_calendar_events should fetch events from connected provider."""
from noteflow.application.services.calendar_service import CalendarService
integration = Integration.create(
workspace_id=uuid4(),
name="Google Calendar",
integration_type=IntegrationType.CALENDAR,
config={"provider": "google"},
)
integration.connect(provider_email="user@gmail.com")
mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
mock_uow.integrations.get_secrets = AsyncMock(return_value={
"access_token": "token",
"refresh_token": "refresh",
"token_type": "Bearer",
"expires_at": (datetime.now(UTC) + timedelta(hours=1)).isoformat(),
"scope": "calendar",
})
mock_uow.integrations.update = AsyncMock()
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
events = await service.list_calendar_events(provider="google")
assert len(events) == 1
assert events[0].title == "Test Meeting"
mock_google_adapter.list_events.assert_called_once()
@pytest.mark.asyncio
async def test_list_events_refreshes_expired_token(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""list_calendar_events should refresh expired token before fetching."""
from noteflow.application.services.calendar_service import CalendarService
integration = Integration.create(
workspace_id=uuid4(),
name="Google Calendar",
integration_type=IntegrationType.CALENDAR,
config={"provider": "google"},
)
integration.connect(provider_email="user@gmail.com")
mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
mock_uow.integrations.get_secrets = AsyncMock(return_value={
"access_token": "expired-token",
"refresh_token": "refresh-token",
"token_type": "Bearer",
"expires_at": (datetime.now(UTC) - timedelta(hours=1)).isoformat(),
"scope": "calendar",
})
mock_uow.integrations.update = AsyncMock()
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
await service.list_calendar_events(provider="google")
mock_oauth_manager.refresh_tokens.assert_called_once()
mock_uow.integrations.set_secrets.assert_called()
@pytest.mark.asyncio
async def test_list_events_raises_when_not_connected(
self,
calendar_settings: CalendarSettings,
mock_oauth_manager: MagicMock,
mock_google_adapter: MagicMock,
mock_outlook_adapter: MagicMock,
mock_uow: MagicMock,
) -> None:
"""list_calendar_events should raise error when provider not connected."""
from noteflow.application.services.calendar_service import CalendarService, CalendarServiceError
mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
service = CalendarService(
uow_factory=lambda: mock_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
google_adapter=mock_google_adapter,
outlook_adapter=mock_outlook_adapter,
)
with pytest.raises(CalendarServiceError, match="not connected"):
await service.list_calendar_events(provider="google")
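The expired-token test implies an expiry check before each fetch; a minimal sketch of that guard (the helper name and skew window are assumptions):

```python
from datetime import UTC, datetime

def needs_refresh(expires_at: datetime, *, skew_seconds: int = 60) -> bool:
    """Treat tokens expiring within the skew window as already expired."""
    return (expires_at - datetime.now(UTC)).total_seconds() <= skew_seconds
```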

View File

@@ -83,5 +83,6 @@ def test_get_supported_formats_returns_names_and_extensions() -> None:
formats = {name.lower(): ext for name, ext in service.get_supported_formats()}
assert formats["markdown"] == ".md"
assert formats["html"] == ".html"
assert formats["markdown"] == ".md", "Markdown format should have .md extension"
assert formats["html"] == ".html", "HTML format should have .html extension"
assert formats["pdf"] == ".pdf", "PDF format should have .pdf extension"

View File

@@ -0,0 +1,364 @@
"""Tests for NER application service."""
from __future__ import annotations
from collections.abc import Sequence
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from noteflow.application.services.ner_service import ExtractionResult, NerService
from noteflow.domain.entities import Meeting, Segment
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId
class MockNerEngine:
"""Mock NER engine for testing."""
def __init__(self, entities: list[NamedEntity] | None = None) -> None:
self._entities = entities or []
self._ready = False
self.extract_call_count = 0
def extract(self, text: str) -> list[NamedEntity]:
"""Extract entities from text (mock)."""
self._ready = True
self.extract_call_count += 1
return self._entities
def extract_from_segments(
self, segments: list[tuple[int, str]]
) -> list[NamedEntity]:
"""Extract entities from segments (mock)."""
self._ready = True
self.extract_call_count += 1
return self._entities
def is_ready(self) -> bool:
"""Check if engine is ready."""
return self._ready
def _create_entity(
text: str,
category: EntityCategory = EntityCategory.PERSON,
segment_ids: list[int] | None = None,
) -> NamedEntity:
"""Create a test entity."""
return NamedEntity(
text=text,
normalized_text=text.lower(),
category=category,
segment_ids=segment_ids or [0],
confidence=0.9,
)
def _create_segment(segment_id: int, text: str) -> Segment:
"""Create a test segment."""
return Segment(
segment_id=segment_id,
text=text,
start_time=segment_id * 5.0,
end_time=(segment_id + 1) * 5.0,
)
class TestNerServiceExtraction:
"""Tests for entity extraction."""
@pytest.fixture
def mock_uow_factory(self, mock_uow: MagicMock) -> MagicMock:
"""Create mock UoW factory."""
mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])
mock_uow.entities.save_batch = AsyncMock(return_value=[])
mock_uow.entities.delete_by_meeting = AsyncMock(return_value=0)
mock_uow.segments.get_by_meeting = AsyncMock(return_value=[])
return MagicMock(return_value=mock_uow)
@pytest.fixture
def sample_meeting(self) -> Meeting:
"""Create a sample meeting with segments."""
meeting = Meeting.create(title="Test Meeting")
meeting.segments.extend([
_create_segment(0, "John talked to Mary."),
_create_segment(1, "They discussed the project."),
])
return meeting
@pytest.fixture
def mock_ner_engine(self) -> MockNerEngine:
"""Create mock NER engine with sample entities."""
return MockNerEngine(
entities=[
_create_entity("John", EntityCategory.PERSON, [0]),
_create_entity("Mary", EntityCategory.PERSON, [0]),
]
)
async def test_extract_entities_returns_result(
self,
mock_uow: MagicMock,
mock_uow_factory: MagicMock,
mock_ner_engine: MockNerEngine,
sample_meeting: Meeting,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Extract entities returns ExtractionResult."""
mock_uow.meetings.get = AsyncMock(return_value=sample_meeting)
mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])
mock_uow.entities.save_batch = AsyncMock(return_value=[])
mock_uow.segments.get_by_meeting = AsyncMock(return_value=sample_meeting.segments)
# Mock feature flag
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
)
service = NerService(mock_ner_engine, mock_uow_factory)
result = await service.extract_entities(sample_meeting.id)
assert isinstance(result, ExtractionResult), "Should return ExtractionResult"
assert result.total_count == 2, "Should have 2 entities"
assert not result.cached, "Should not be cached on first extraction"
assert len(result.entities) == 2, "Entities list should match total_count"
async def test_extract_entities_uses_cache(
self,
mock_uow: MagicMock,
mock_uow_factory: MagicMock,
mock_ner_engine: MockNerEngine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Returns cached entities without re-extraction."""
cached_entities = [_create_entity("Cached", EntityCategory.PERSON)]
mock_uow.entities.get_by_meeting = AsyncMock(return_value=cached_entities)
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
)
service = NerService(mock_ner_engine, mock_uow_factory)
result = await service.extract_entities(MeetingId(uuid4()))
assert result.cached
assert result.total_count == 1
assert mock_ner_engine.extract_call_count == 0
async def test_extract_entities_force_refresh_bypasses_cache(
self,
mock_uow: MagicMock,
mock_uow_factory: MagicMock,
mock_ner_engine: MockNerEngine,
sample_meeting: Meeting,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Force refresh re-extracts even with cached data."""
cached_entities = [_create_entity("Cached", EntityCategory.PERSON)]
mock_uow.meetings.get = AsyncMock(return_value=sample_meeting)
mock_uow.entities.get_by_meeting = AsyncMock(return_value=cached_entities)
mock_uow.entities.save_batch = AsyncMock(return_value=[])
mock_uow.entities.delete_by_meeting = AsyncMock(return_value=1)
mock_uow.segments.get_by_meeting = AsyncMock(return_value=sample_meeting.segments)
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
)
# Mark engine as ready to skip warmup extraction
mock_ner_engine._ready = True
initial_count = mock_ner_engine.extract_call_count
service = NerService(mock_ner_engine, mock_uow_factory)
result = await service.extract_entities(sample_meeting.id, force_refresh=True)
assert not result.cached
assert mock_ner_engine.extract_call_count == initial_count + 1
mock_uow.entities.delete_by_meeting.assert_called_once()
async def test_extract_entities_meeting_not_found(
self,
mock_uow: MagicMock,
mock_uow_factory: MagicMock,
mock_ner_engine: MockNerEngine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Raises ValueError when meeting not found."""
mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])
mock_uow.meetings.get = AsyncMock(return_value=None)
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
)
service = NerService(mock_ner_engine, mock_uow_factory)
meeting_id = MeetingId(uuid4())
with pytest.raises(ValueError, match=str(meeting_id)):
await service.extract_entities(meeting_id)
async def test_extract_entities_feature_disabled(
self,
mock_uow_factory: MagicMock,
mock_ner_engine: MockNerEngine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Raises RuntimeError when feature flag is disabled."""
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=False)),
)
service = NerService(mock_ner_engine, mock_uow_factory)
with pytest.raises(RuntimeError, match="disabled"):
await service.extract_entities(MeetingId(uuid4()))
async def test_extract_entities_no_segments_returns_empty(
self,
mock_uow: MagicMock,
mock_uow_factory: MagicMock,
mock_ner_engine: MockNerEngine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Returns empty result for meeting with no segments."""
meeting = Meeting.create(title="Empty Meeting")
mock_uow.meetings.get = AsyncMock(return_value=meeting)
mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
)
service = NerService(mock_ner_engine, mock_uow_factory)
result = await service.extract_entities(meeting.id)
assert result.total_count == 0
assert result.entities == []
assert not result.cached
class TestNerServicePinning:
"""Tests for entity pinning."""
async def test_pin_entity_success(self, mock_uow: MagicMock) -> None:
"""Pin entity updates database and returns True."""
mock_uow.entities.update_pinned = AsyncMock(return_value=True)
mock_uow_factory = MagicMock(return_value=mock_uow)
service = NerService(MockNerEngine(), mock_uow_factory)
entity_id = uuid4()
result = await service.pin_entity(entity_id, is_pinned=True)
assert result is True
mock_uow.entities.update_pinned.assert_called_once_with(entity_id, True)
mock_uow.commit.assert_called_once()
async def test_pin_entity_not_found(self, mock_uow: MagicMock) -> None:
"""Pin entity returns False when entity not found."""
mock_uow.entities.update_pinned = AsyncMock(return_value=False)
mock_uow_factory = MagicMock(return_value=mock_uow)
service = NerService(MockNerEngine(), mock_uow_factory)
entity_id = uuid4()
result = await service.pin_entity(entity_id)
assert result is False
mock_uow.commit.assert_not_called()
class TestNerServiceClear:
"""Tests for clearing entities."""
async def test_clear_entities(self, mock_uow: MagicMock) -> None:
"""Clear entities deletes all for meeting."""
mock_uow.entities.delete_by_meeting = AsyncMock(return_value=5)
mock_uow_factory = MagicMock(return_value=mock_uow)
service = NerService(MockNerEngine(), mock_uow_factory)
meeting_id = MeetingId(uuid4())
count = await service.clear_entities(meeting_id)
assert count == 5
mock_uow.entities.delete_by_meeting.assert_called_once_with(meeting_id)
mock_uow.commit.assert_called_once()
class TestNerServiceHelpers:
"""Tests for helper methods."""
async def test_get_entities(self, mock_uow: MagicMock) -> None:
"""Get entities returns cached entities without extraction."""
entities = [_create_entity("Test")]
mock_uow.entities.get_by_meeting = AsyncMock(return_value=entities)
mock_uow_factory = MagicMock(return_value=mock_uow)
service = NerService(MockNerEngine(), mock_uow_factory)
meeting_id = MeetingId(uuid4())
result = await service.get_entities(meeting_id)
assert result == entities
mock_uow.entities.get_by_meeting.assert_called_once_with(meeting_id)
async def test_has_entities_true(self, mock_uow: MagicMock) -> None:
"""Has entities returns True when entities exist."""
mock_uow.entities.exists_for_meeting = AsyncMock(return_value=True)
mock_uow_factory = MagicMock(return_value=mock_uow)
service = NerService(MockNerEngine(), mock_uow_factory)
meeting_id = MeetingId(uuid4())
result = await service.has_entities(meeting_id)
assert result is True
async def test_has_entities_false(self, mock_uow: MagicMock) -> None:
"""Has entities returns False when no entities exist."""
mock_uow.entities.exists_for_meeting = AsyncMock(return_value=False)
mock_uow_factory = MagicMock(return_value=mock_uow)
service = NerService(MockNerEngine(), mock_uow_factory)
meeting_id = MeetingId(uuid4())
result = await service.has_entities(meeting_id)
assert result is False
def test_is_engine_ready_checks_both_flag_and_engine(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Is engine ready checks both feature flag and engine state."""
engine = MockNerEngine()
service = NerService(engine, MagicMock())
# NER disabled - should be False regardless of engine state
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=False)),
)
engine._ready = True
assert not service.is_engine_ready(), "Should be False when NER disabled"
# NER enabled but engine not ready
monkeypatch.setattr(
"noteflow.application.services.ner_service.get_settings",
lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
)
engine._ready = False
assert not service.is_engine_ready(), "Should be False when engine not ready"
# NER enabled and engine ready
engine._ready = True
assert service.is_engine_ready(), "Should be True when both conditions met"

View File

@@ -136,6 +136,7 @@ def mock_uow() -> MagicMock:
uow.annotations = MagicMock()
uow.preferences = MagicMock()
uow.diarization_jobs = MagicMock()
+uow.entities = MagicMock()
return uow

View File

@@ -211,31 +211,31 @@ class TestMeetingProperties:
meeting.started_at = utc_now() - timedelta(seconds=5)
assert meeting.duration_seconds >= 5.0
-def test_is_active_created(self) -> None:
-"""Test is_active returns True for CREATED state."""
+def test_is_in_active_state_created(self) -> None:
+"""Test is_in_active_state returns True for CREATED state."""
meeting = Meeting.create()
-assert meeting.is_active() is True
+assert meeting.is_in_active_state() is True
-def test_is_active_recording(self) -> None:
-"""Test is_active returns True for RECORDING state."""
+def test_is_in_active_state_recording(self) -> None:
+"""Test is_in_active_state returns True for RECORDING state."""
meeting = Meeting.create()
meeting.start_recording()
-assert meeting.is_active() is True
+assert meeting.is_in_active_state() is True
-def test_is_active_stopping(self) -> None:
-"""Test is_active returns False for STOPPING state."""
+def test_is_in_active_state_stopping(self) -> None:
+"""Test is_in_active_state returns False for STOPPING state."""
meeting = Meeting.create()
meeting.start_recording()
meeting.begin_stopping()
-assert meeting.is_active() is False
+assert meeting.is_in_active_state() is False
-def test_is_active_stopped(self) -> None:
-"""Test is_active returns False for STOPPED state."""
+def test_is_in_active_state_stopped(self) -> None:
+"""Test is_in_active_state returns False for STOPPED state."""
meeting = Meeting.create()
meeting.start_recording()
meeting.begin_stopping()
meeting.stop_recording()
-assert meeting.is_active() is False
+assert meeting.is_in_active_state() is False
def test_has_summary_false(self) -> None:
"""Test has_summary returns False when no summary."""

View File

@@ -0,0 +1,271 @@
"""Tests for NamedEntity domain entity."""
from __future__ import annotations
from uuid import uuid4
import pytest
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId
class TestEntityCategory:
"""Tests for EntityCategory enum."""
@pytest.mark.parametrize(
("value", "expected"),
[
("person", EntityCategory.PERSON),
("company", EntityCategory.COMPANY),
("product", EntityCategory.PRODUCT),
("technical", EntityCategory.TECHNICAL),
("acronym", EntityCategory.ACRONYM),
("location", EntityCategory.LOCATION),
("date", EntityCategory.DATE),
("other", EntityCategory.OTHER),
],
)
def test_from_string_valid_values(
self, value: str, expected: EntityCategory
) -> None:
"""Convert lowercase string to EntityCategory."""
assert EntityCategory.from_string(value) == expected
@pytest.mark.parametrize("value", ["PERSON", "Person", "COMPANY"])
def test_from_string_case_insensitive(self, value: str) -> None:
"""Convert mixed case string to EntityCategory."""
result = EntityCategory.from_string(value)
assert result in EntityCategory
def test_from_string_invalid_raises(self) -> None:
"""Invalid category string raises ValueError."""
with pytest.raises(ValueError, match="Invalid entity category"):
EntityCategory.from_string("invalid_category")
class TestNamedEntityValidation:
"""Tests for NamedEntity validation in __post_init__."""
@pytest.mark.parametrize("confidence", [-0.1, 1.1, 2.0, -1.0])
def test_invalid_confidence_raises(self, confidence: float) -> None:
"""Confidence outside 0-1 range raises ValueError."""
with pytest.raises(ValueError, match="confidence must be between 0 and 1"):
NamedEntity(
text="John",
category=EntityCategory.PERSON,
confidence=confidence,
)
@pytest.mark.parametrize("confidence", [0.0, 0.5, 1.0, 0.95])
def test_valid_confidence_boundaries(self, confidence: float) -> None:
"""Confidence at valid boundaries is accepted."""
entity = NamedEntity(
text="John",
category=EntityCategory.PERSON,
confidence=confidence,
)
assert entity.confidence == confidence
def test_auto_computes_normalized_text(self) -> None:
"""Normalized text is auto-computed from text when not provided."""
entity = NamedEntity(
text="John SMITH",
category=EntityCategory.PERSON,
confidence=0.9,
)
assert entity.normalized_text == "john smith"
def test_preserves_explicit_normalized_text(self) -> None:
"""Explicit normalized_text is preserved."""
entity = NamedEntity(
text="John Smith",
normalized_text="custom_normalization",
category=EntityCategory.PERSON,
confidence=0.9,
)
assert entity.normalized_text == "custom_normalization"
class TestNamedEntityCreate:
"""Tests for NamedEntity.create() factory method."""
def test_create_with_valid_input(self) -> None:
"""Create entity with valid input returns properly initialized entity."""
meeting_id = MeetingId(uuid4())
entity = NamedEntity.create(
text="Acme Corporation",
category=EntityCategory.COMPANY,
segment_ids=[0, 2, 1],
confidence=0.85,
meeting_id=meeting_id,
)
assert entity.text == "Acme Corporation", "Text should be preserved"
assert entity.normalized_text == "acme corporation", "Normalized text should be lowercase"
assert entity.category == EntityCategory.COMPANY, "Category should be preserved"
assert entity.segment_ids == [0, 1, 2], "Segment IDs should be sorted"
assert entity.confidence == 0.85, "Confidence should be preserved"
assert entity.meeting_id == meeting_id, "Meeting ID should be preserved"
def test_create_strips_whitespace(self) -> None:
"""Create strips leading/trailing whitespace from text."""
entity = NamedEntity.create(
text=" John Smith ",
category=EntityCategory.PERSON,
segment_ids=[0],
confidence=0.9,
)
assert entity.text == "John Smith"
assert entity.normalized_text == "john smith"
def test_create_deduplicates_segment_ids(self) -> None:
"""Create deduplicates and sorts segment IDs."""
entity = NamedEntity.create(
text="Test",
category=EntityCategory.OTHER,
segment_ids=[3, 1, 1, 3, 2],
confidence=0.8,
)
assert entity.segment_ids == [1, 2, 3]
def test_create_empty_text_raises(self) -> None:
"""Create with empty text raises ValueError."""
with pytest.raises(ValueError, match="Entity text cannot be empty"):
NamedEntity.create(
text="",
category=EntityCategory.PERSON,
segment_ids=[0],
confidence=0.9,
)
def test_create_whitespace_only_text_raises(self) -> None:
"""Create with whitespace-only text raises ValueError."""
with pytest.raises(ValueError, match="Entity text cannot be empty"):
NamedEntity.create(
text=" ",
category=EntityCategory.PERSON,
segment_ids=[0],
confidence=0.9,
)
def test_create_invalid_confidence_raises(self) -> None:
"""Create with invalid confidence raises ValueError."""
with pytest.raises(ValueError, match="confidence must be between 0 and 1"):
NamedEntity.create(
text="John",
category=EntityCategory.PERSON,
segment_ids=[0],
confidence=1.5,
)
class TestNamedEntityOccurrenceCount:
"""Tests for occurrence_count property."""
def test_occurrence_count_with_segments(self) -> None:
"""Occurrence count returns number of unique segment IDs."""
entity = NamedEntity(
text="Test",
category=EntityCategory.OTHER,
segment_ids=[0, 1, 2],
confidence=0.8,
)
assert entity.occurrence_count == 3
def test_occurrence_count_empty_segments(self) -> None:
"""Occurrence count returns 0 for empty segment_ids."""
entity = NamedEntity(
text="Test",
category=EntityCategory.OTHER,
segment_ids=[],
confidence=0.8,
)
assert entity.occurrence_count == 0
def test_occurrence_count_single_segment(self) -> None:
"""Occurrence count returns 1 for single segment."""
entity = NamedEntity(
text="Test",
category=EntityCategory.OTHER,
segment_ids=[5],
confidence=0.8,
)
assert entity.occurrence_count == 1
class TestNamedEntityMergeSegments:
"""Tests for merge_segments method."""
def test_merge_segments_adds_new(self) -> None:
"""Merge segments adds new segment IDs."""
entity = NamedEntity(
text="John",
category=EntityCategory.PERSON,
segment_ids=[0, 1],
confidence=0.9,
)
entity.merge_segments([3, 4])
assert entity.segment_ids == [0, 1, 3, 4]
def test_merge_segments_deduplicates(self) -> None:
"""Merge segments deduplicates overlapping IDs."""
entity = NamedEntity(
text="John",
category=EntityCategory.PERSON,
segment_ids=[0, 1, 2],
confidence=0.9,
)
entity.merge_segments([1, 2, 3])
assert entity.segment_ids == [0, 1, 2, 3]
def test_merge_segments_sorts(self) -> None:
"""Merge segments keeps result sorted."""
entity = NamedEntity(
text="John",
category=EntityCategory.PERSON,
segment_ids=[5, 10],
confidence=0.9,
)
entity.merge_segments([1, 3])
assert entity.segment_ids == [1, 3, 5, 10]
def test_merge_empty_segments(self) -> None:
"""Merge with empty list preserves original segments."""
entity = NamedEntity(
text="John",
category=EntityCategory.PERSON,
segment_ids=[0, 1],
confidence=0.9,
)
entity.merge_segments([])
assert entity.segment_ids == [0, 1]
class TestNamedEntityDefaults:
"""Tests for NamedEntity default values."""
def test_default_meeting_id_is_none(self) -> None:
"""Default meeting_id is None."""
entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
assert entity.meeting_id is None
def test_default_segment_ids_is_empty(self) -> None:
"""Default segment_ids is empty list."""
entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
assert entity.segment_ids == []
def test_default_is_pinned_is_false(self) -> None:
"""Default is_pinned is False."""
entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
assert entity.is_pinned is False
def test_default_db_id_is_none(self) -> None:
"""Default db_id is None."""
entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
assert entity.db_id is None
def test_id_is_auto_generated(self) -> None:
"""UUID id is auto-generated."""
entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
assert entity.id is not None
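The create() tests above jointly specify the input normalization; a compact sketch of what the factory must do to pass them (not the actual implementation):

```python
def normalize_entity_input(text: str, segment_ids: list[int]) -> tuple[str, str, list[int]]:
    """Strip, lowercase, and dedupe/sort, raising on empty text (per the tests)."""
    stripped = text.strip()
    if not stripped:
        raise ValueError("Entity text cannot be empty")
    return stripped, stripped.lower(), sorted(set(segment_ids))

assert normalize_entity_input(" John Smith ", [3, 1, 1]) == ("John Smith", "john smith", [1, 3])
```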

View File

@@ -0,0 +1 @@
"""Tests for calendar infrastructure."""

View File

@@ -0,0 +1,269 @@
"""Tests for Google Calendar adapter."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestGoogleCalendarAdapterListEvents:
"""Tests for GoogleCalendarAdapter.list_events."""
@pytest.mark.asyncio
async def test_list_events_returns_calendar_events(self) -> None:
"""list_events should return parsed calendar events."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
adapter = GoogleCalendarAdapter()
now = datetime.now(UTC)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"items": [
{
"id": "event-1",
"summary": "Team Standup",
"start": {"dateTime": (now + timedelta(hours=1)).isoformat()},
"end": {"dateTime": (now + timedelta(hours=2)).isoformat()},
"attendees": [
{"email": "alice@example.com"},
{"email": "bob@example.com"},
],
"hangoutLink": "https://meet.google.com/abc-defg-hij",
},
{
"id": "event-2",
"summary": "All-Day Planning",
"start": {"date": now.strftime("%Y-%m-%d")},
"end": {"date": (now + timedelta(days=1)).strftime("%Y-%m-%d")},
},
]
}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
events = await adapter.list_events("access-token", hours_ahead=24, limit=10)
assert len(events) == 2
assert events[0].title == "Team Standup"
assert events[0].attendees == ("alice@example.com", "bob@example.com")
assert events[0].meeting_url == "https://meet.google.com/abc-defg-hij"
assert events[0].provider == "google"
assert events[1].title == "All-Day Planning"
assert events[1].is_all_day is True
@pytest.mark.asyncio
async def test_list_events_handles_empty_response(self) -> None:
"""list_events should return empty list when no events."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
adapter = GoogleCalendarAdapter()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"items": []}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
events = await adapter.list_events("access-token")
assert events == []
@pytest.mark.asyncio
async def test_list_events_raises_on_expired_token(self) -> None:
"""list_events should raise error on 401 response."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
adapter = GoogleCalendarAdapter()
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.text = "Token expired"
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
with pytest.raises(GoogleCalendarError, match="expired or invalid"):
await adapter.list_events("expired-token")
@pytest.mark.asyncio
async def test_list_events_raises_on_api_error(self) -> None:
"""list_events should raise error on non-200 response."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
adapter = GoogleCalendarAdapter()
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "Internal server error"
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
with pytest.raises(GoogleCalendarError, match="API error"):
await adapter.list_events("access-token")
@pytest.mark.asyncio
async def test_list_events_parses_conference_data(self) -> None:
"""list_events should extract meeting URL from conferenceData."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
adapter = GoogleCalendarAdapter()
now = datetime.now(UTC)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"items": [
{
"id": "event-zoom",
"summary": "Zoom Meeting",
"start": {"dateTime": now.isoformat()},
"end": {"dateTime": (now + timedelta(hours=1)).isoformat()},
"conferenceData": {
"entryPoints": [
{"entryPointType": "video", "uri": "https://zoom.us/j/123456"},
{"entryPointType": "phone", "uri": "tel:+1234567890"},
]
},
}
]
}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
events = await adapter.list_events("access-token")
assert events[0].meeting_url == "https://zoom.us/j/123456"
class TestGoogleCalendarAdapterGetUserEmail:
"""Tests for GoogleCalendarAdapter.get_user_email."""
@pytest.mark.asyncio
async def test_get_user_email_returns_email(self) -> None:
"""get_user_email should return user's email address."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
adapter = GoogleCalendarAdapter()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"email": "user@example.com",
"name": "Test User",
}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
email = await adapter.get_user_email("access-token")
assert email == "user@example.com"
@pytest.mark.asyncio
async def test_get_user_email_raises_on_missing_email(self) -> None:
"""get_user_email should raise when email not in response."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
adapter = GoogleCalendarAdapter()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"name": "No Email User"}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
with pytest.raises(GoogleCalendarError, match="No email"):
await adapter.get_user_email("access-token")
class TestGoogleCalendarAdapterDateParsing:
"""Tests for date/time parsing in GoogleCalendarAdapter."""
@pytest.mark.asyncio
async def test_parses_utc_datetime_with_z_suffix(self) -> None:
"""Should parse datetime with Z suffix correctly."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
adapter = GoogleCalendarAdapter()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"items": [
{
"id": "event-utc",
"summary": "UTC Event",
"start": {"dateTime": "2024-03-15T10:00:00Z"},
"end": {"dateTime": "2024-03-15T11:00:00Z"},
}
]
}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
events = await adapter.list_events("access-token")
assert events[0].start_time.tzinfo is not None
assert events[0].start_time.hour == 10
@pytest.mark.asyncio
async def test_parses_datetime_with_offset(self) -> None:
"""Should parse datetime with timezone offset correctly."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
adapter = GoogleCalendarAdapter()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"items": [
{
"id": "event-offset",
"summary": "Offset Event",
"start": {"dateTime": "2024-03-15T10:00:00-08:00"},
"end": {"dateTime": "2024-03-15T11:00:00-08:00"},
}
]
}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
events = await adapter.list_events("access-token")
assert events[0].start_time.tzinfo is not None
@pytest.mark.asyncio
async def test_identifies_recurring_events(self) -> None:
"""Should identify recurring events via recurringEventId."""
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
adapter = GoogleCalendarAdapter()
now = datetime.now(UTC)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"items": [
{
"id": "event-instance",
"summary": "Weekly Meeting",
"start": {"dateTime": now.isoformat()},
"end": {"dateTime": (now + timedelta(hours=1)).isoformat()},
"recurringEventId": "event-master",
}
]
}
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_response
events = await adapter.list_events("access-token")
assert events[0].is_recurring is True
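The error-path tests fix the messages the adapter must raise; a sketch of the status handling they imply (only `GoogleCalendarError` comes from the source, the rest is assumed):

```python
class GoogleCalendarError(Exception):
    """Imported by the tests; the raising logic below is our sketch."""

def raise_for_calendar_status(status_code: int, body: str) -> None:
    # Mirrors what the tests match on: 401 -> "expired or invalid",
    # any other non-200 -> "API error".
    if status_code == 401:
        raise GoogleCalendarError(f"Access token expired or invalid: {body}")
    if status_code != 200:
        raise GoogleCalendarError(f"Google Calendar API error {status_code}: {body}")
```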

View File

@@ -0,0 +1,266 @@
"""Tests for OAuth manager."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from noteflow.config.settings import CalendarSettings
from noteflow.domain.value_objects import OAuthProvider
from noteflow.infrastructure.calendar.oauth_manager import OAuthError
@pytest.fixture
def calendar_settings() -> CalendarSettings:
"""Create test calendar settings."""
return CalendarSettings(
google_client_id="test-google-client-id",
google_client_secret="test-google-client-secret",
outlook_client_id="test-outlook-client-id",
outlook_client_secret="test-outlook-client-secret",
)
class TestOAuthManagerInitiateAuth:
"""Tests for OAuthManager.initiate_auth."""
def test_initiate_google_auth_returns_url_and_state(
self, calendar_settings: CalendarSettings
) -> None:
"""initiate_auth should return auth URL and state for Google."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
auth_url, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")
assert auth_url.startswith("https://accounts.google.com/o/oauth2/v2/auth")
assert "client_id=test-google-client-id" in auth_url
assert "redirect_uri=http" in auth_url
assert "code_challenge=" in auth_url
assert "code_challenge_method=S256" in auth_url
assert len(state) > 0
def test_initiate_outlook_auth_returns_url_and_state(
self, calendar_settings: CalendarSettings
) -> None:
"""initiate_auth should return auth URL and state for Outlook."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
auth_url, state = manager.initiate_auth(OAuthProvider.OUTLOOK, "http://localhost:8080/callback")
assert auth_url.startswith("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
assert "client_id=test-outlook-client-id" in auth_url
assert len(state) > 0
def test_initiate_auth_stores_pending_state(
self, calendar_settings: CalendarSettings
) -> None:
"""initiate_auth should store pending state for later validation."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
_, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")
assert state in manager._pending_states
assert manager._pending_states[state].provider == OAuthProvider.GOOGLE
def test_initiate_auth_missing_credentials_raises(
self, calendar_settings: CalendarSettings
) -> None:
"""initiate_auth should raise for missing OAuth credentials."""
from noteflow.infrastructure.calendar import OAuthManager
# Create settings with missing Google credentials
settings = CalendarSettings(
google_client_id="",
google_client_secret="",
outlook_client_id="test-id",
outlook_client_secret="test-secret",
)
manager = OAuthManager(settings)
with pytest.raises(OAuthError, match="not configured"):
manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")
class TestOAuthManagerCompleteAuth:
"""Tests for OAuthManager.complete_auth."""
@pytest.mark.asyncio
async def test_complete_auth_exchanges_code_for_tokens(
self, calendar_settings: CalendarSettings
) -> None:
"""complete_auth should exchange authorization code for tokens."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
_, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "access-token-123",
"refresh_token": "refresh-token-456",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "email calendar.readonly",
}
with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
tokens = await manager.complete_auth(OAuthProvider.GOOGLE, "auth-code-xyz", state)
assert tokens.access_token == "access-token-123"
assert tokens.refresh_token == "refresh-token-456"
assert tokens.token_type == "Bearer"
@pytest.mark.asyncio
async def test_complete_auth_invalid_state_raises(
self, calendar_settings: CalendarSettings
) -> None:
"""complete_auth should raise for invalid state."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
with pytest.raises(OAuthError, match="Invalid or expired"):
await manager.complete_auth(OAuthProvider.GOOGLE, "code", "invalid-state")
@pytest.mark.asyncio
async def test_complete_auth_expired_state_raises(
self, calendar_settings: CalendarSettings
) -> None:
"""complete_auth should raise for expired state."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
_, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")
# Expire the state manually by replacing the OAuthState
old_state = manager._pending_states[state]
from noteflow.domain.value_objects import OAuthState
expired_state = OAuthState(
state=old_state.state,
provider=old_state.provider,
redirect_uri=old_state.redirect_uri,
code_verifier=old_state.code_verifier,
created_at=old_state.created_at,
expires_at=datetime.now(UTC) - timedelta(minutes=1),
)
manager._pending_states[state] = expired_state
with pytest.raises(OAuthError, match="expired"):
await manager.complete_auth(OAuthProvider.GOOGLE, "code", state)
@pytest.mark.asyncio
async def test_complete_auth_removes_used_state(
self, calendar_settings: CalendarSettings
) -> None:
"""complete_auth should remove state after successful use."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
_, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "token",
"token_type": "Bearer",
"expires_in": 3600,
}
with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await manager.complete_auth(OAuthProvider.GOOGLE, "code", state)
assert state not in manager._pending_states
class TestOAuthManagerRefreshTokens:
"""Tests for OAuthManager.refresh_tokens."""
@pytest.mark.asyncio
async def test_refresh_tokens_returns_new_tokens(
self, calendar_settings: CalendarSettings
) -> None:
"""refresh_tokens should return new access token."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
}
with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
tokens = await manager.refresh_tokens(OAuthProvider.GOOGLE, "old-refresh-token")
assert tokens.access_token == "new-access-token"
@pytest.mark.asyncio
async def test_refresh_tokens_preserves_refresh_token_if_not_returned(
self, calendar_settings: CalendarSettings
) -> None:
"""refresh_tokens should preserve old refresh token if not in response."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
}
with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
tokens = await manager.refresh_tokens(OAuthProvider.GOOGLE, "old-refresh-token")
assert tokens.refresh_token == "old-refresh-token"
class TestOAuthManagerRevokeTokens:
"""Tests for OAuthManager.revoke_tokens."""
@pytest.mark.asyncio
async def test_revoke_google_tokens_calls_revocation_endpoint(
self, calendar_settings: CalendarSettings
) -> None:
"""revoke_tokens should call Google revocation endpoint."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await manager.revoke_tokens(OAuthProvider.GOOGLE, "access-token")
mock_post.assert_called_once()
call_args = mock_post.call_args
assert "oauth2.googleapis.com/revoke" in call_args[0][0]
@pytest.mark.asyncio
async def test_revoke_outlook_tokens_handles_no_revocation_endpoint(
self, calendar_settings: CalendarSettings
) -> None:
"""revoke_tokens should succeed silently for Outlook (no revocation endpoint)."""
from noteflow.infrastructure.calendar import OAuthManager
manager = OAuthManager(calendar_settings)
# Should not raise - Outlook doesn't have a revocation endpoint
await manager.revoke_tokens(OAuthProvider.OUTLOOK, "access-token")
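The `code_challenge=` / `code_challenge_method=S256` assertions earlier in this file follow standard PKCE; the derivation is fixed by RFC 7636, though the helper below is ours, not NoteFlow's:

```python
import base64
import hashlib
import secrets

def make_pkce_pair() -> tuple[str, str]:
    """Generate a PKCE verifier and its S256 challenge (RFC 7636)."""
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode("ascii")
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return verifier, challenge
```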

View File

@@ -0,0 +1 @@
"""NER infrastructure tests."""

View File

@@ -0,0 +1,127 @@
"""Tests for NER engine (spaCy wrapper)."""
from __future__ import annotations
import pytest
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.infrastructure.ner import NerEngine
@pytest.fixture(scope="module")
def ner_engine() -> NerEngine:
"""Create NER engine (module-scoped to avoid repeated model loads)."""
engine = NerEngine(model_name="en_core_web_sm")
return engine
class TestNerEngineBasics:
"""Basic NER engine functionality tests."""
def test_is_ready_before_load(self) -> None:
"""Engine is not ready before first use."""
engine = NerEngine()
assert not engine.is_ready()
def test_is_ready_after_extract(self, ner_engine: NerEngine) -> None:
"""Engine is ready after extraction triggers lazy load."""
ner_engine.extract("Hello, John.")
assert ner_engine.is_ready()
class TestEntityExtraction:
"""Tests for entity extraction from text."""
@pytest.mark.parametrize(
("text", "expected_category", "expected_text"),
[
("I met John Smith yesterday.", EntityCategory.PERSON, "John Smith"),
("Apple Inc. announced new products.", EntityCategory.COMPANY, "Apple Inc."),
("We visited New York City.", EntityCategory.LOCATION, "New York City"),
("The meeting is scheduled for Monday.", EntityCategory.DATE, "Monday"),
],
)
def test_extract_entity_types(
self,
ner_engine: NerEngine,
text: str,
expected_category: EntityCategory,
expected_text: str,
) -> None:
"""Extract entities of various types."""
entities = ner_engine.extract(text)
matching = [e for e in entities if e.category == expected_category]
assert matching, f"No {expected_category.value} entities found in: {text}"
texts = [e.text for e in matching]
assert expected_text in texts, f"Expected '{expected_text}' not found in {texts}"
def test_extract_returns_list(self, ner_engine: NerEngine) -> None:
"""Extract returns a list of NamedEntity objects."""
entities = ner_engine.extract("John works at Google.")
assert isinstance(entities, list)
assert all(isinstance(e, NamedEntity) for e in entities)
def test_extract_empty_text_returns_empty(self, ner_engine: NerEngine) -> None:
"""Empty text returns empty list."""
entities = ner_engine.extract("")
assert entities == []
def test_extract_no_entities_returns_empty(self, ner_engine: NerEngine) -> None:
"""Text with no entities returns empty list."""
entities = ner_engine.extract("The quick brown fox.")
# May still find entities depending on model, but should not crash
assert isinstance(entities, list)
class TestSegmentExtraction:
"""Tests for extraction from multiple segments."""
def test_extract_from_segments_tracks_segment_ids(self, ner_engine: NerEngine) -> None:
"""Segment extraction tracks which segments contain each entity."""
segments = [
(0, "John talked to Mary."),
(1, "Mary went to the store."),
(2, "Then John called Mary again."),
]
entities = ner_engine.extract_from_segments(segments)
# Mary should appear in multiple segments
mary_entities = [e for e in entities if "mary" in e.normalized_text]
assert mary_entities, "Mary entity not found"
mary = mary_entities[0]
assert len(mary.segment_ids) >= 2, "Mary should appear in multiple segments"
def test_extract_from_segments_deduplicates(self, ner_engine: NerEngine) -> None:
"""Same entity in multiple segments is deduplicated."""
segments = [
(0, "John Smith is here."),
(1, "John Smith left."),
]
entities = ner_engine.extract_from_segments(segments)
# Should have one John Smith entity (deduplicated by normalized text)
john_entities = [e for e in entities if "john" in e.normalized_text]
assert len(john_entities) == 1, "John Smith should be deduplicated"
assert len(john_entities[0].segment_ids) == 2, "Should track both segments"
def test_extract_from_segments_empty_returns_empty(self, ner_engine: NerEngine) -> None:
"""Empty segments list returns empty entities."""
entities = ner_engine.extract_from_segments([])
assert entities == []
class TestEntityNormalization:
"""Tests for entity text normalization."""
def test_normalized_text_is_lowercase(self, ner_engine: NerEngine) -> None:
"""Normalized text should be lowercase."""
entities = ner_engine.extract("John SMITH went to NYC.")
for entity in entities:
assert entity.normalized_text == entity.normalized_text.lower()
def test_confidence_is_set(self, ner_engine: NerEngine) -> None:
"""Entities should have confidence score."""
entities = ner_engine.extract("Microsoft Corporation is based in Seattle.")
assert entities, "Should find entities"
for entity in entities:
assert 0.0 <= entity.confidence <= 1.0, "Confidence should be between 0 and 1"

View File

@@ -2,11 +2,17 @@
from __future__ import annotations
+from unittest.mock import MagicMock
+from uuid import uuid4
import pytest
from noteflow.domain import entities
+from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
+from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.asr import dto
from noteflow.infrastructure.converters import AsrConverter, OrmConverter
+from noteflow.infrastructure.converters.ner_converters import NerConverter
class TestAsrConverter:
@@ -125,3 +131,161 @@ class TestOrmConverterToOrmKwargs:
assert result["end_time"] == 0.987654321
assert result["probability"] == 0.111111
assert result["word_index"] == 5
class TestNerConverterToOrmKwargs:
"""Tests for NerConverter.to_orm_kwargs."""
def test_converts_domain_to_orm_kwargs(self) -> None:
"""Convert domain NamedEntity to ORM kwargs dict."""
meeting_id = MeetingId(uuid4())
entity = NamedEntity(
id=uuid4(),
meeting_id=meeting_id,
text="Acme Corp",
normalized_text="acme corp",
category=EntityCategory.COMPANY,
segment_ids=[0, 1, 2],
confidence=0.95,
is_pinned=True,
)
result = NerConverter.to_orm_kwargs(entity)
assert result["id"] == entity.id, "ID should be preserved"
assert result["meeting_id"] == meeting_id, "Meeting ID should be preserved"
assert result["text"] == "Acme Corp", "Text should be preserved"
assert result["normalized_text"] == "acme corp", "Normalized text should be preserved"
assert result["category"] == "company", "Category should be string value"
assert result["segment_ids"] == [0, 1, 2], "Segment IDs should be preserved"
assert result["confidence"] == 0.95, "Confidence should be preserved"
assert result["is_pinned"] is True, "is_pinned should be preserved"
def test_converts_category_to_string_value(self) -> None:
"""Category enum is converted to its string value."""
entity = NamedEntity(
text="John",
category=EntityCategory.PERSON,
confidence=0.8,
)
result = NerConverter.to_orm_kwargs(entity)
assert result["category"] == "person"
assert isinstance(result["category"], str)
@pytest.mark.parametrize(
("category", "expected_value"),
[
(EntityCategory.PERSON, "person"),
(EntityCategory.COMPANY, "company"),
(EntityCategory.PRODUCT, "product"),
(EntityCategory.LOCATION, "location"),
(EntityCategory.DATE, "date"),
(EntityCategory.OTHER, "other"),
],
)
def test_all_category_values_convert(
self, category: EntityCategory, expected_value: str
) -> None:
"""All category enum values convert to correct string."""
entity = NamedEntity(text="Test", category=category, confidence=0.5)
result = NerConverter.to_orm_kwargs(entity)
assert result["category"] == expected_value
class TestNerConverterOrmToDomain:
"""Tests for NerConverter.orm_to_domain."""
@pytest.fixture
def mock_orm_model(self) -> MagicMock:
"""Create mock ORM model with typical values."""
model = MagicMock()
model.id = uuid4()
model.meeting_id = uuid4()
model.text = "Google Inc."
model.normalized_text = "google inc."
model.category = "company"
model.segment_ids = [1, 3, 5]
model.confidence = 0.92
model.is_pinned = False
return model
def test_converts_orm_to_domain_entity(self, mock_orm_model: MagicMock) -> None:
"""Convert ORM model to domain NamedEntity."""
result = NerConverter.orm_to_domain(mock_orm_model)
assert isinstance(result, NamedEntity), "Should return NamedEntity instance"
assert result.id == mock_orm_model.id, "ID should match ORM model"
assert result.meeting_id == MeetingId(mock_orm_model.meeting_id), "Meeting ID should match"
assert result.text == "Google Inc.", "Text should match ORM model"
assert result.normalized_text == "google inc.", "Normalized text should match"
assert result.category == EntityCategory.COMPANY, "Category should convert to enum"
assert result.segment_ids == [1, 3, 5], "Segment IDs should match"
assert result.confidence == 0.92, "Confidence should match"
assert result.is_pinned is False, "is_pinned should match"
def test_sets_db_id_from_orm_id(self, mock_orm_model: MagicMock) -> None:
"""Domain entity db_id is set from ORM id."""
result = NerConverter.orm_to_domain(mock_orm_model)
assert result.db_id == mock_orm_model.id
def test_converts_category_string_to_enum(self, mock_orm_model: MagicMock) -> None:
"""Category string from ORM is converted to EntityCategory enum."""
mock_orm_model.category = "person"
result = NerConverter.orm_to_domain(mock_orm_model)
assert result.category == EntityCategory.PERSON
assert isinstance(result.category, EntityCategory)
def test_handles_none_segment_ids(self, mock_orm_model: MagicMock) -> None:
"""Null segment_ids in ORM becomes empty list in domain."""
mock_orm_model.segment_ids = None
result = NerConverter.orm_to_domain(mock_orm_model)
assert result.segment_ids == []
def test_handles_empty_segment_ids(self, mock_orm_model: MagicMock) -> None:
"""Empty segment_ids in ORM becomes empty list in domain."""
mock_orm_model.segment_ids = []
result = NerConverter.orm_to_domain(mock_orm_model)
assert result.segment_ids == []
class TestNerConverterRoundTrip:
"""Tests for round-trip conversion fidelity."""
def test_domain_to_orm_to_domain_preserves_values(self) -> None:
"""Round-trip conversion preserves all field values."""
original = NamedEntity(
id=uuid4(),
meeting_id=MeetingId(uuid4()),
text="Microsoft Corporation",
normalized_text="microsoft corporation",
category=EntityCategory.COMPANY,
segment_ids=[0, 5, 10],
confidence=0.88,
is_pinned=True,
)
# Convert to ORM kwargs and simulate ORM model
orm_kwargs = NerConverter.to_orm_kwargs(original)
mock_orm = MagicMock()
mock_orm.id = orm_kwargs["id"]
mock_orm.meeting_id = orm_kwargs["meeting_id"]
mock_orm.text = orm_kwargs["text"]
mock_orm.normalized_text = orm_kwargs["normalized_text"]
mock_orm.category = orm_kwargs["category"]
mock_orm.segment_ids = orm_kwargs["segment_ids"]
mock_orm.confidence = orm_kwargs["confidence"]
mock_orm.is_pinned = orm_kwargs["is_pinned"]
# Convert back to domain
result = NerConverter.orm_to_domain(mock_orm)
assert result.id == original.id, "ID preserved through round-trip"
assert result.meeting_id == original.meeting_id, "Meeting ID preserved"
assert result.text == original.text, "Text preserved"
assert result.normalized_text == original.normalized_text, "Normalized text preserved"
assert result.category == original.category, "Category preserved"
assert result.segment_ids == original.segment_ids, "Segment IDs preserved"
assert result.confidence == original.confidence, "Confidence preserved"
assert result.is_pinned == original.is_pinned, "is_pinned preserved"
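For reference, the assertions above imply a converter of roughly this shape: the category enum is serialized via `.value`, a `NULL` `segment_ids` column is coerced to `[]`, and `db_id` mirrors the ORM primary key. This is a sketch inferred from the tests, not the actual `NerConverter` source, and it assumes `NamedEntity` accepts these constructor arguments and exposes a writable `db_id`:

```python
from typing import Any

from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId


def orm_to_domain(model: Any) -> NamedEntity:
    """Rebuild a domain NamedEntity from an ORM row (inferred shape)."""
    entity = NamedEntity(
        id=model.id,
        meeting_id=MeetingId(model.meeting_id),
        text=model.text,
        normalized_text=model.normalized_text,
        category=EntityCategory(model.category),  # "company" -> EntityCategory.COMPANY
        segment_ids=list(model.segment_ids or []),  # NULL column becomes empty list
        confidence=model.confidence,
        is_pinned=model.is_pinned,
    )
    entity.db_id = model.id  # tests assert db_id mirrors the ORM id
    return entity
```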

View File

@@ -5,6 +5,7 @@ Tests the complete export workflow with database:
- gRPC ExportTranscript with database
- Markdown export format
- HTML export format
- PDF export format
- File export with correct extensions
- Error handling for export operations
"""
@@ -25,6 +26,22 @@ from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
def _weasyprint_available() -> bool:
"""Check if weasyprint is available (may fail due to missing system libraries)."""
try:
import weasyprint as _ # noqa: F401
return True
except (ImportError, OSError):
return False
requires_weasyprint = pytest.mark.skipif(
not _weasyprint_available(),
reason="weasyprint not available (missing system libraries: pango, gobject)",
)
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -114,6 +131,85 @@ class TestExportServiceDatabase:
assert "<html" in content.lower() or "<!doctype" in content.lower()
assert "HTML content test" in content
@pytest.mark.slow
@requires_weasyprint
async def test_export_pdf_from_database(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Test exporting meeting as PDF from database with full content verification."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Export PDF Test")
meeting.start_recording()
meeting.begin_stopping()
meeting.stop_recording()
await uow.meetings.create(meeting)
segments = [
Segment(
segment_id=0,
text="First speaker talks here.",
start_time=0.0,
end_time=3.0,
speaker_id="Alice",
),
Segment(
segment_id=1,
text="Second speaker responds.",
start_time=3.0,
end_time=6.0,
speaker_id="Bob",
),
]
for segment in segments:
await uow.segments.add(meeting.id, segment)
await uow.commit()
export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
content = await export_service.export_transcript(meeting.id, ExportFormat.PDF)
assert isinstance(content, bytes), "PDF export should return bytes"
assert content.startswith(b"%PDF-"), "PDF should have valid magic bytes"
assert len(content) > 1000, "PDF should have substantial content"
@pytest.mark.slow
@requires_weasyprint
async def test_export_to_file_creates_pdf_file(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Test export_to_file writes valid binary PDF file to disk."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="PDF File Export Test")
await uow.meetings.create(meeting)
segment = Segment(
segment_id=0,
text="Content for PDF file.",
start_time=0.0,
end_time=2.0,
speaker_id="Speaker",
)
await uow.segments.add(meeting.id, segment)
await uow.commit()
with tempfile.TemporaryDirectory() as tmpdir:
output_path = Path(tmpdir) / "transcript.pdf"
export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
result_path = await export_service.export_to_file(
meeting.id,
output_path,
ExportFormat.PDF,
)
assert result_path.exists(), "PDF file should be created"
assert result_path.suffix == ".pdf", "File should have .pdf extension"
file_bytes = result_path.read_bytes()
assert file_bytes.startswith(b"%PDF-"), "File should contain valid PDF"
assert len(file_bytes) > 500, "PDF file should have content"
async def test_export_to_file_creates_file(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
@@ -254,6 +350,44 @@ class TestExportGrpcServicer:
assert "HTML export content" in result.content
assert result.file_extension == ".html"
@pytest.mark.slow
@requires_weasyprint
async def test_export_transcript_pdf_via_grpc(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Test ExportTranscript RPC with PDF format returns valid base64-encoded PDF."""
import base64
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="gRPC PDF Export Test")
await uow.meetings.create(meeting)
segment = Segment(
segment_id=0,
text="PDF via gRPC content.",
start_time=0.0,
end_time=5.0,
speaker_id="TestSpeaker",
)
await uow.segments.add(meeting.id, segment)
await uow.commit()
servicer = NoteFlowServicer(session_factory=session_factory)
request = noteflow_pb2.ExportTranscriptRequest(
meeting_id=str(meeting.id),
format=noteflow_pb2.EXPORT_FORMAT_PDF,
)
result = await servicer.ExportTranscript(request, MockContext())
assert result.content, "PDF export should return content"
assert result.file_extension == ".pdf", "File extension should be .pdf"
# gRPC returns base64-encoded PDF; verify it decodes to valid PDF
pdf_bytes = base64.b64decode(result.content)
assert pdf_bytes.startswith(b"%PDF-"), "Decoded content should be valid PDF"
assert len(pdf_bytes) > 500, "PDF should have substantial content"
async def test_export_transcript_nonexistent_meeting(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
@@ -406,13 +540,13 @@ class TestExportFormats:
export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
formats = export_service.get_supported_formats()
assert len(formats) >= 3, "Should support at least markdown, html, and pdf"
format_names = [f[0] for f in formats]
extensions = [f[1] for f in formats]
assert ".md" in extensions, "Markdown format should be supported"
assert ".html" in extensions, "HTML format should be supported"
assert ".pdf" in extensions, "PDF format should be supported"
async def test_infer_format_markdown(
self, session_factory: async_sessionmaker[AsyncSession]
@@ -438,6 +572,15 @@ class TestExportFormats:
fmt = export_service._infer_format_from_extension(".htm")
assert fmt == ExportFormat.HTML
async def test_infer_format_pdf(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Test format inference from .pdf extension."""
export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
fmt = export_service._infer_format_from_extension(".pdf")
assert fmt == ExportFormat.PDF, "Should infer PDF format from .pdf extension"
async def test_infer_format_unknown_raises(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
@@ -446,3 +589,29 @@ class TestExportFormats:
with pytest.raises(ValueError, match="Cannot infer format"):
export_service._infer_format_from_extension(".txt")
@pytest.mark.slow
@requires_weasyprint
async def test_export_pdf_returns_bytes(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Test PDF export returns bytes with valid PDF magic bytes."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="PDF Export Test")
await uow.meetings.create(meeting)
segment = Segment(
segment_id=0,
text="PDF content test.",
start_time=0.0,
end_time=2.0,
)
await uow.segments.add(meeting.id, segment)
await uow.commit()
export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
content = await export_service.export_transcript(meeting.id, ExportFormat.PDF)
assert isinstance(content, bytes), "PDF export should return bytes"
assert content.startswith(b"%PDF-"), "PDF should have valid magic bytes"
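Two details in this diff are worth calling out. First, the `requires_weasyprint` guard catches `OSError` in addition to `ImportError` because weasyprint binds native libraries (pango, gobject) at import time, so a missing shared library fails the import with `OSError` rather than `ImportError`. Second, the transport encoding differs by layer: `ExportService.export_transcript` returns raw PDF bytes, while the gRPC response carries a base64 string. A hypothetical client-side helper (not NoteFlow code) that reverses the encoding:

```python
import base64
from pathlib import Path


def save_pdf_response(content_b64: str, destination: Path) -> Path:
    """Decode the base64 payload from ExportTranscript and write a binary PDF."""
    pdf_bytes = base64.b64decode(content_b64)
    if not pdf_bytes.startswith(b"%PDF-"):  # same magic-byte check the tests use
        raise ValueError("Decoded content is not a PDF")
    destination.write_bytes(pdf_bytes)
    return destination
```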

View File

@@ -0,0 +1,407 @@
"""End-to-end integration tests for NER extraction.
Tests the complete NER workflow with database persistence:
- Entity extraction from meeting segments
- Persistence and retrieval
- Caching behavior
- Force refresh functionality
- Pin entity operations
"""
from __future__ import annotations
from collections.abc import Iterator
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from uuid import uuid4
import grpc
import pytest
from noteflow.application.services.ner_service import ExtractionResult, NerService
from noteflow.domain.entities import Meeting, Segment
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId
from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@pytest.fixture(autouse=True)
def mock_feature_flags() -> Iterator[MagicMock]:
"""Mock feature flags to enable NER for all tests."""
mock_settings = MagicMock()
mock_settings.feature_flags = MagicMock(ner_enabled=True)
with patch(
"noteflow.application.services.ner_service.get_settings",
return_value=mock_settings,
):
yield mock_settings
class MockContext:
"""Mock gRPC context for testing."""
def __init__(self) -> None:
"""Initialize mock context."""
self.aborted = False
self.abort_code: grpc.StatusCode | None = None
self.abort_details: str | None = None
async def abort(self, code: grpc.StatusCode, details: str) -> None:
"""Record abort and raise to simulate gRPC behavior."""
self.aborted = True
self.abort_code = code
self.abort_details = details
raise grpc.RpcError()
class MockNerEngine:
"""Mock NER engine that returns controlled entities."""
def __init__(self, entities: list[NamedEntity] | None = None) -> None:
"""Initialize with predefined entities."""
self._entities = entities or []
self._ready = False
self.extract_call_count = 0
self.extract_from_segments_call_count = 0
def extract(self, text: str) -> list[NamedEntity]:
"""Extract entities from text (mock)."""
self._ready = True
self.extract_call_count += 1
return self._entities
def extract_from_segments(
self, segments: list[tuple[int, str]]
) -> list[NamedEntity]:
"""Extract entities from segments (mock)."""
self._ready = True
self.extract_from_segments_call_count += 1
return self._entities
def is_ready(self) -> bool:
"""Check if engine is ready."""
return self._ready
def _create_test_entity(
text: str,
category: EntityCategory = EntityCategory.PERSON,
segment_ids: list[int] | None = None,
confidence: float = 0.9,
) -> NamedEntity:
"""Create a test entity."""
return NamedEntity.create(
text=text,
category=category,
segment_ids=segment_ids or [0],
confidence=confidence,
)
@pytest.mark.integration
class TestNerExtractionFlow:
"""Integration tests for entity extraction workflow."""
async def test_extract_entities_persists_to_database(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Extracted entities are persisted to database."""
# Create meeting with segments
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="NER Test Meeting")
await uow.meetings.create(meeting)
for i in range(3):
segment = Segment(
segment_id=i,
text=f"Segment {i} mentioning John Smith.",
start_time=float(i * 10),
end_time=float((i + 1) * 10),
)
await uow.segments.add(meeting.id, segment)
await uow.commit()
meeting_id = meeting.id
# Create mock engine that returns test entities
mock_engine = MockNerEngine(
entities=[
_create_test_entity("John Smith", EntityCategory.PERSON, [0, 1, 2]),
]
)
mock_engine._ready = True
# Create service and extract
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
result = await service.extract_entities(meeting_id)
assert result.total_count == 1, "Should extract exactly one entity"
assert not result.cached, "First extraction should not be cached"
# Verify persistence
async with SqlAlchemyUnitOfWork(session_factory) as uow:
entities = await uow.entities.get_by_meeting(meeting_id)
assert len(entities) == 1, "Should persist exactly one entity"
assert entities[0].text == "John Smith", "Entity text should match"
assert entities[0].category == EntityCategory.PERSON, "Category should be PERSON"
assert entities[0].segment_ids == [0, 1, 2], "Segment IDs should match"
async def test_extract_entities_returns_cached_on_second_call(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Second extraction returns cached entities without re-extraction."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Cache Test")
await uow.meetings.create(meeting)
await uow.segments.add(
meeting.id, Segment(0, "John mentioned Acme Corp.", 0.0, 5.0)
)
await uow.commit()
meeting_id = meeting.id
mock_engine = MockNerEngine(
entities=[
_create_test_entity("John", EntityCategory.PERSON, [0]),
_create_test_entity("Acme Corp", EntityCategory.COMPANY, [0]),
]
)
mock_engine._ready = True
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
# First extraction
result1 = await service.extract_entities(meeting_id)
assert result1.total_count == 2, "First extraction should find 2 entities"
assert not result1.cached, "First extraction should not be cached"
initial_count = mock_engine.extract_from_segments_call_count
# Second extraction should use cache
result2 = await service.extract_entities(meeting_id)
assert result2.total_count == 2, "Second extraction should return same count"
assert result2.cached, "Second extraction should come from cache"
assert mock_engine.extract_from_segments_call_count == initial_count, (
"Engine should not be called again when using cache"
)
async def test_extract_entities_force_refresh_re_extracts(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Force refresh re-extracts and replaces cached entities."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Force Refresh Test")
await uow.meetings.create(meeting)
await uow.segments.add(
meeting.id, Segment(0, "Testing force refresh.", 0.0, 5.0)
)
await uow.commit()
meeting_id = meeting.id
# First extraction
mock_engine = MockNerEngine(
entities=[_create_test_entity("Test", EntityCategory.OTHER, [0])]
)
mock_engine._ready = True
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
await service.extract_entities(meeting_id)
initial_count = mock_engine.extract_from_segments_call_count
# Force refresh should re-extract
result = await service.extract_entities(meeting_id, force_refresh=True)
assert not result.cached
assert mock_engine.extract_from_segments_call_count == initial_count + 1
@pytest.mark.integration
class TestNerPersistence:
"""Integration tests for entity persistence operations."""
async def test_entities_persist_across_service_instances(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Entities persist in database across service instances."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Persistence Test")
await uow.meetings.create(meeting)
await uow.segments.add(meeting.id, Segment(0, "Hello world.", 0.0, 5.0))
await uow.commit()
meeting_id = meeting.id
# Extract with first service instance
mock_engine = MockNerEngine(
entities=[_create_test_entity("World", EntityCategory.OTHER, [0])]
)
mock_engine._ready = True
service1 = NerService(
mock_engine, lambda: SqlAlchemyUnitOfWork(session_factory)
)
await service1.extract_entities(meeting_id)
# Create new service instance (simulating server restart)
service2 = NerService(
MockNerEngine(), lambda: SqlAlchemyUnitOfWork(session_factory)
)
# Should get cached result without extraction
result = await service2.get_entities(meeting_id)
assert len(result) == 1
assert result[0].text == "World"
async def test_clear_entities_removes_all_for_meeting(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Clear entities removes all entities for meeting."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Clear Test")
await uow.meetings.create(meeting)
await uow.segments.add(meeting.id, Segment(0, "Test content.", 0.0, 5.0))
await uow.commit()
meeting_id = meeting.id
mock_engine = MockNerEngine(
entities=[
_create_test_entity("Entity1", EntityCategory.PERSON, [0]),
_create_test_entity("Entity2", EntityCategory.COMPANY, [0]),
]
)
mock_engine._ready = True
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
await service.extract_entities(meeting_id)
deleted_count = await service.clear_entities(meeting_id)
assert deleted_count == 2
# Verify entities are gone
entities = await service.get_entities(meeting_id)
assert len(entities) == 0
@pytest.mark.integration
class TestNerPinning:
"""Integration tests for entity pinning operations."""
async def test_pin_entity_persists_pinned_state(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Pin entity updates and persists is_pinned flag."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Pin Test")
await uow.meetings.create(meeting)
await uow.segments.add(meeting.id, Segment(0, "John Doe test.", 0.0, 5.0))
await uow.commit()
meeting_id = meeting.id
mock_engine = MockNerEngine(
entities=[_create_test_entity("John Doe", EntityCategory.PERSON, [0])]
)
mock_engine._ready = True
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
await service.extract_entities(meeting_id)
# Get entity ID
entities = await service.get_entities(meeting_id)
entity_id = entities[0].id
# Pin entity
result = await service.pin_entity(entity_id, is_pinned=True)
assert result is True
# Verify persistence
entities = await service.get_entities(meeting_id)
assert entities[0].is_pinned is True
# Unpin
await service.pin_entity(entity_id, is_pinned=False)
entities = await service.get_entities(meeting_id)
assert entities[0].is_pinned is False
async def test_pin_entity_nonexistent_returns_false(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Pin entity returns False for nonexistent entity."""
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(MockNerEngine(), uow_factory)
result = await service.pin_entity(uuid4(), is_pinned=True)
assert result is False
@pytest.mark.integration
class TestNerEdgeCases:
"""Integration tests for edge cases."""
async def test_extract_from_meeting_with_no_segments(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Extract from meeting with no segments returns empty result."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Empty Meeting")
await uow.meetings.create(meeting)
await uow.commit()
meeting_id = meeting.id
mock_engine = MockNerEngine()
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
result = await service.extract_entities(meeting_id)
assert result.total_count == 0
assert result.entities == []
assert not result.cached
async def test_extract_from_nonexistent_meeting_raises(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""Extract from nonexistent meeting raises ValueError."""
mock_engine = MockNerEngine()
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
nonexistent_id = MeetingId(uuid4())
with pytest.raises(ValueError, match=str(nonexistent_id)):
await service.extract_entities(nonexistent_id)
async def test_has_entities_reflects_extraction_state(
self, session_factory: async_sessionmaker[AsyncSession]
) -> None:
"""has_entities returns correct state before and after extraction."""
async with SqlAlchemyUnitOfWork(session_factory) as uow:
meeting = Meeting.create(title="Has Entities Test")
await uow.meetings.create(meeting)
await uow.segments.add(meeting.id, Segment(0, "Test.", 0.0, 5.0))
await uow.commit()
meeting_id = meeting.id
mock_engine = MockNerEngine(
entities=[_create_test_entity("Test", EntityCategory.OTHER, [0])]
)
mock_engine._ready = True
uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
service = NerService(mock_engine, uow_factory)
# Before extraction
assert await service.has_entities(meeting_id) is False
# After extraction
await service.extract_entities(meeting_id)
assert await service.has_entities(meeting_id) is True
# After clearing
await service.clear_entities(meeting_id)
assert await service.has_entities(meeting_id) is False
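The caching tests above fix a cache-first contract: return persisted entities (with `cached=True`) whenever any exist, and only invoke the engine on a miss or when `force_refresh=True`. A generic sketch of that control flow with injected callables — `CacheResult`, `load_cached`, and `compute` are illustrative names, not NoteFlow APIs:

```python
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class CacheResult(Generic[T]):
    """Mirrors ExtractionResult's cached flag in a generic form."""

    items: list[T]
    cached: bool


async def cache_first(
    load_cached: Callable[[], Awaitable[list[T]]],
    compute: Callable[[], Awaitable[list[T]]],
    force_refresh: bool = False,
) -> CacheResult[T]:
    """Serve from the store unless it is empty or a refresh is forced."""
    if not force_refresh:
        cached = await load_cached()
        if cached:  # second extraction lands here; the engine is never called
            return CacheResult(items=cached, cached=True)
    return CacheResult(items=await compute(), cached=False)
```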

uv.lock (generated)
View File

@@ -237,6 +237,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" },
]
[[package]]
name = "authlib"
version = "1.6.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/9b/b1661026ff24bc641b76b78c5222d614776b0c085bcfdac9bd15a1cb4b35/authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e", size = 164894, upload-time = "2025-12-12T08:01:41.464Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" },
]
[[package]]
name = "av"
version = "16.0.1"
@@ -2212,6 +2224,7 @@ source = { editable = "." }
dependencies = [
{ name = "alembic" },
{ name = "asyncpg" },
{ name = "authlib" },
{ name = "cryptography" },
{ name = "diart" },
{ name = "faster-whisper" },
@@ -2298,6 +2311,7 @@ requires-dist = [
{ name = "alembic", specifier = ">=1.13" },
{ name = "anthropic", marker = "extra == 'summarization'", specifier = ">=0.75.0" },
{ name = "asyncpg", specifier = ">=0.29" },
{ name = "authlib", specifier = ">=1.6.6" },
{ name = "basedpyright", marker = "extra == 'dev'", specifier = ">=1.18" },
{ name = "cryptography", specifier = ">=42.0" },
{ name = "diart", specifier = ">=0.9.2" },