Update dependencies and enhance calendar integration features
- Updated `pyproject.toml` to include `authlib` as a dependency for OAuth integration.
- Modified `uv.lock` to reflect the addition of `authlib` and its versioning.
- Updated documentation to clarify existing components and their statuses in the calendar sync and webhook integration sprints.
- Refactored various methods and properties for improved clarity and consistency in the meeting and export services.
- Enhanced test coverage for export functionality, including assertions for PDF format.
- Updated client submodule reference to ensure alignment with the latest changes.
.serena/memories/sprint_4_grpc_validation.md (new file, 519 lines)
@@ -0,0 +1,519 @@
# Sprint 4 gRPC Layer Validation Report

## Executive Summary

Sprint 4's proposed approach of implementing `EntitiesMixin` for Named Entity Recognition aligns well with NoteFlow's established gRPC architecture. The codebase already has:

- Proto definitions for entity extraction (`ExtractEntitiesRequest/Response`, lines 593-630)
- An ORM model (`NamedEntityModel`) and its migration
- No NER service yet, but the pattern is well established

**Key Finding**: The mixin pattern is mature and battle-tested across 7 existing mixins. Adding `EntitiesMixin` follows proven conventions with minimal surprises.

---

## 1. EXISTING MIXINS ARCHITECTURE

### Mixins Implemented

Located in `src/noteflow/grpc/_mixins/`:

1. **StreamingMixin** - Audio bidirectional streaming + ASR
2. **DiarizationMixin** - Speaker identification
3. **DiarizationJobMixin** - Background diarization job tracking
4. **MeetingMixin** - Meeting CRUD (create, get, list, delete, stop)
5. **SummarizationMixin** - Summary generation
6. **AnnotationMixin** - Annotation management (CRUD)
7. **ExportMixin** - Transcript export (Markdown/HTML/PDF)

All are exported from `src/noteflow/grpc/_mixins/__init__.py`.

### NoteFlowServicer Composition
```python
class NoteFlowServicer(
    StreamingMixin,
    DiarizationMixin,
    DiarizationJobMixin,
    MeetingMixin,
    SummarizationMixin,
    AnnotationMixin,
    ExportMixin,
    noteflow_pb2_grpc.NoteFlowServiceServicer,
):
    ...
```
**Pattern**: Multiple inheritance ordering matters. Python's MRO resolves left to right, so on a method-name conflict the leftmost mixin wins; the generated `NoteFlowServiceServicer` base is listed last so mixin implementations override its default stubs.

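The precedence rule can be verified with a minimal sketch (hypothetical stub classes, not from NoteFlow): the leftmost base in the class statement supplies a conflicting method.

```python
class StubA:
    """Stands in for a mixin that implements a method."""

    def ping(self) -> str:
        return "A"


class StubB:
    """A later base with a conflicting implementation."""

    def ping(self) -> str:
        return "B"


class Composed(StubA, StubB):
    """Leftmost base wins under Python's C3 linearization."""


print(Composed().ping())  # -> "A"
print([c.__name__ for c in Composed.__mro__])  # Composed, StubA, StubB, object
```

This is why placing the generated servicer base last is safe: every mixin listed before it shadows the base's default stubs.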
---

## 2. MIXIN STRUCTURE PATTERN

### Standard Mixin Template (from AnnotationMixin)

**Class Definition**:
```python
class AnnotationMixin:
    """Docstring explaining what it does.

    Requires host to implement ServicerHost protocol.
    Works with both database and memory backends via RepositoryProvider.
    """
```

**Key Characteristics**:
- No `__init__` required (state lives in NoteFlowServicer)
- Type hint `self: ServicerHost` to access host attributes
- All service dependencies injected through the protocol
- Works with both database and in-memory backends

### Method Signature Pattern
```python
async def RPCMethod(
    self: ServicerHost,
    request: noteflow_pb2.SomeRequest,
    context: grpc.aio.ServicerContext,
) -> noteflow_pb2.SomeResponse:
    """RPC implementation."""
    # Get meeting ID and validate
    meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)

    # Use unified repository provider
    async with self._create_repository_provider() as repo:
        # Load data
        item = await repo.entities.get(item_id)
        if item is None:
            await abort_not_found(context, "Entity", request.id)

        # Process
        result = await repo.entities.update(item)
        await repo.commit()
        return converted_to_proto(result)
```

---

## 3. SERVICERHOST PROTOCOL (src/noteflow/grpc/_mixins/protocols.py)

The `ServicerHost` protocol defines the contract mixins expect from NoteFlowServicer:

### Configuration Attributes
- `_session_factory: async_sessionmaker[AsyncSession] | None`
- `_memory_store: MeetingStore | None`
- `_meetings_dir: Path`
- `_crypto: AesGcmCryptoBox`

### Service Engines
- `_asr_engine: FasterWhisperEngine | None`
- `_diarization_engine: DiarizationEngine | None`
- `_summarization_service: object | None` ← **Generic type for flexibility**

**For EntitiesMixin**, the protocol would need:
```python
_ner_service: NerService | None  # Type will be added to protocol
```

### Key Methods Mixins Use
```python
def _use_database(self) -> bool
def _get_memory_store(self) -> MeetingStore
def _create_uow(self) -> SqlAlchemyUnitOfWork
def _create_repository_provider(self) -> UnitOfWork  # Handles both DB and memory
def _next_segment_id(self, meeting_id: str, fallback: int = 0) -> int
def _init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None
def _cleanup_streaming_state(self, meeting_id: str) -> None
def _ensure_meeting_dek(self, meeting: Meeting) -> tuple[bytes, bytes, bool]
def _start_meeting_if_needed(self, meeting: Meeting) -> tuple[bool, str | None]
def _open_meeting_audio_writer(...) -> None
def _close_audio_writer(self, meeting_id: str) -> None
```

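As a minimal illustration of the protocol approach (simplified, hypothetical names rather than the real `ServicerHost`), `typing.Protocol` lets a mixin type-hint `self` against a structural contract without inheriting from it:

```python
from typing import Protocol


class Host(Protocol):
    """Structural contract a mixin expects its host to satisfy."""

    _label: str

    def _lookup(self, key: str) -> str: ...


class GreetMixin:
    """Mixin with no __init__; it reaches host state via the protocol hint."""

    def greet(self: Host, key: str) -> str:
        return f"{self._label}: {self._lookup(key)}"


class Servicer(GreetMixin):
    """Concrete host supplying the protocol members."""

    _label = "svc"

    def _lookup(self, key: str) -> str:
        return key.upper()


print(Servicer().greet("ner"))  # -> "svc: NER"
```

The type checker verifies that any class composing the mixin actually provides `_label` and `_lookup`, which is exactly how `ServicerHost` keeps the mixin contracts honest.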
---

## 4. PROTO DEFINITIONS (noteflow.proto, lines 45-46, 593-630)

**Already defined for Sprint 4**:

### Service RPC
```protobuf
rpc ExtractEntities(ExtractEntitiesRequest) returns (ExtractEntitiesResponse);
```

### Request Message
```protobuf
message ExtractEntitiesRequest {
  string meeting_id = 1;
  bool force_refresh = 2;
}
```

### Response Messages
```protobuf
message ExtractedEntity {
  string id = 1;
  string text = 2;
  string category = 3;  // person, company, product, technical, acronym, location, date, other
  repeated int32 segment_ids = 4;
  float confidence = 5;
  bool is_pinned = 6;
}

message ExtractEntitiesResponse {
  repeated ExtractedEntity entities = 1;
  int32 total_count = 2;
  bool cached = 3;
}
```

**Status**: Proto defs complete, no regeneration needed unless modifying messages.

---

## 5. CONVERTER PATTERN (src/noteflow/grpc/_mixins/converters.py)

### Pattern Overview
All converters are **standalone functions** in a single module, not class methods:

```python
def parse_meeting_id_or_abort(meeting_id_str: str, context: object) -> MeetingId
def parse_annotation_id(annotation_id_str: str) -> AnnotationId
def word_to_proto(word: WordTiming) -> noteflow_pb2.WordTiming
def segment_to_final_segment_proto(segment: Segment) -> noteflow_pb2.FinalSegment
def meeting_to_proto(meeting: Meeting, ...) -> noteflow_pb2.Meeting
def summary_to_proto(summary: Summary) -> noteflow_pb2.Summary
def annotation_to_proto(annotation: Annotation) -> noteflow_pb2.Annotation
def create_segment_from_asr(meeting_id: MeetingId, ...) -> Segment
def proto_to_export_format(proto_format: int) -> ExportFormat
```

### For EntitiesMixin, Need to Add:
```python
def entity_to_proto(entity: Entity) -> noteflow_pb2.ExtractedEntity
def proto_to_entity(request: noteflow_pb2.ExtractEntitiesRequest, ...) -> Entity | None
def proto_entity_list_to_domain(protos: list[...]) -> list[Entity]
```

### Key Principles
1. **Stateless functions** - no dependencies on mixins
2. **Handle both directions** - proto→domain and domain→proto
3. **Consolidate patterns** - e.g., `parse_meeting_id_or_abort` replaces 11+ duplicated patterns
4. **Use domain converters** - import from infrastructure when needed (e.g., `AsrConverter`)

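The stateless, bidirectional principle can be sketched as follows (a hypothetical dataclass and dict standing in for the domain `Entity` and the proto message; the real types differ):

```python
from dataclasses import dataclass


@dataclass
class EntityStub:
    """Hypothetical stand-in for the domain Entity."""

    text: str
    category: str
    confidence: float


def entity_to_wire(entity: EntityStub) -> dict:
    """Domain -> wire direction (stateless, no mixin dependencies)."""
    return {
        "text": entity.text,
        "category": entity.category,
        "confidence": entity.confidence,
    }


def wire_to_entity(payload: dict) -> EntityStub:
    """Wire -> domain direction; the pair round-trips losslessly."""
    return EntityStub(
        text=payload["text"],
        category=payload["category"],
        confidence=payload["confidence"],
    )


entity = EntityStub("ACME Corp", "company", 0.92)
assert wire_to_entity(entity_to_wire(entity)) == entity
```

Because the functions carry no state, they can be unit-tested in isolation and reused from any mixin.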
---

## 6. DEPENDENCY INJECTION PATTERN

### How Services Get Injected

**Example: SummarizationMixin**
```python
class SummarizationMixin:
    _summarization_service: SummarizationService | None

    async def GenerateSummary(self: ServicerHost, ...):
        # Access through self
        summary = await self._summarize_or_placeholder(meeting_id, segments, style_prompt)
```

**Pattern for NerService**:
1. Add a parameter to NoteFlowServicer's `__init__` and store it:
```python
def __init__(
    self,
    ...,
    ner_service: NerService | None = None,
) -> None:
    self._ner_service = ner_service
```

2. Add the field to the ServicerHost protocol:
```python
class ServicerHost(Protocol):
    _ner_service: NerService | None
```

3. Use it in EntitiesMixin:
```python
async def ExtractEntities(self: ServicerHost, request, context):
    # Service may be None; handle gracefully or abort
    if self._ner_service is None:
        await abort_service_unavailable(context, "NER")
```

### Repository Provider Pattern
All mixins use `self._create_repository_provider()` to abstract DB vs. in-memory storage:
```python
async with self._create_repository_provider() as repo:
    entities = await repo.entities.get_by_meeting(meeting_id)  # Need to implement
```

---

## 7. HOW TO ADD ENTITIESMIXIN

### Step-by-Step Integration Points

**A. Create `src/noteflow/grpc/_mixins/entities.py`**
```python
class EntitiesMixin:
    """Mixin providing named entity extraction functionality."""

    async def ExtractEntities(
        self: ServicerHost,
        request: noteflow_pb2.ExtractEntitiesRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.ExtractEntitiesResponse:
        ...  # Implementation
```

**B. Update `src/noteflow/grpc/_mixins/__init__.py`**
```python
from .entities import EntitiesMixin

__all__ = [..., "EntitiesMixin"]
```

**C. Update `src/noteflow/grpc/service.py`**
```python
from ._mixins import (
    ...,
    EntitiesMixin,  # Add
)

class NoteFlowServicer(
    ...,
    EntitiesMixin,  # Add before noteflow_pb2_grpc.NoteFlowServiceServicer
    noteflow_pb2_grpc.NoteFlowServiceServicer,
):
    def __init__(
        self,
        ...,
        ner_service: NerService | None = None,
    ):
        ...
        self._ner_service = ner_service
```

**D. Update `src/noteflow/grpc/_mixins/protocols.py`**
```python
class ServicerHost(Protocol):
    ...
    _ner_service: object | None  # or import NerService under TYPE_CHECKING
```

**E. Update `src/noteflow/grpc/_mixins/converters.py`**

Add entity conversion functions:
```python
def entity_to_proto(entity: Entity) -> noteflow_pb2.ExtractedEntity:
    """Convert domain Entity to protobuf."""
    return noteflow_pb2.ExtractedEntity(
        id=str(entity.id),
        text=entity.text,
        category=entity.category,
        segment_ids=entity.segment_ids,
        confidence=entity.confidence,
        is_pinned=entity.is_pinned,
    )
```

---

## 8. REPOSITORY INTEGRATION

### Current Repositories Available
Located in `src/noteflow/infrastructure/persistence/repositories/`:
- `annotation_repo.py` → `SqlAlchemyAnnotationRepository`
- `diarization_job_repo.py` → `SqlAlchemyDiarizationJobRepository`
- `meeting_repo.py` → `SqlAlchemyMeetingRepository`
- `preferences_repo.py` → `SqlAlchemyPreferencesRepository`
- `segment_repo.py` → `SqlAlchemySegmentRepository`
- `summary_repo.py` → `SqlAlchemySummaryRepository`

**Missing**: `entity_repo.py` (Sprint 4 task)

### UnitOfWork Integration
Current `SqlAlchemyUnitOfWork` (lines 87-127):
```python
@property
def annotations(self) -> SqlAlchemyAnnotationRepository: ...

@property
def meetings(self) -> SqlAlchemyMeetingRepository: ...
```

**Needed for Sprint 4**:
```python
@property
def entities(self) -> SqlAlchemyEntityRepository: ...

# And in __aenter__/__aexit__:
self._entities_repo = SqlAlchemyEntityRepository(self._session)
```

### Database vs. In-Memory
NoteFlowServicer uses `_create_repository_provider()`, which returns:
- `SqlAlchemyUnitOfWork` if a database is configured
- `MemoryUnitOfWork` as the in-memory fallback

EntitiesMixin must handle both seamlessly (as AnnotationMixin does).

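A minimal in-memory repository sketch for the fallback path (hypothetical `EntityRecord` and method names; the real `MemoryUnitOfWork` repositories differ):

```python
import asyncio
from dataclasses import dataclass, field
from uuid import UUID, uuid4


@dataclass
class EntityRecord:
    """Hypothetical stand-in for the domain Entity."""

    meeting_id: str
    text: str
    category: str
    id: UUID = field(default_factory=uuid4)


class MemoryEntityRepository:
    """In-memory CRUD backing, mirroring what the SQL repository would expose."""

    def __init__(self) -> None:
        self._by_id: dict[UUID, EntityRecord] = {}

    async def save_batch(self, entities: list[EntityRecord]) -> list[EntityRecord]:
        for entity in entities:
            self._by_id[entity.id] = entity
        return entities

    async def get_by_meeting(self, meeting_id: str) -> list[EntityRecord]:
        return [e for e in self._by_id.values() if e.meeting_id == meeting_id]


async def demo() -> int:
    repo = MemoryEntityRepository()
    await repo.save_batch([EntityRecord("m1", "ACME Corp", "company")])
    return len(await repo.get_by_meeting("m1"))


print(asyncio.run(demo()))  # -> 1
```

Keeping the method names identical across the SQL and memory variants is what lets the mixin stay backend-agnostic.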
---

## 9. ERROR HANDLING PATTERN

All mixins use helper functions from `src/noteflow/grpc/_mixins/errors.py`:

```python
async def abort_not_found(context, entity_type, entity_id) -> None
async def abort_invalid_argument(context, message) -> None
async def abort_database_required(context, feature_name) -> None
```

**Pattern in EntitiesMixin**:
```python
async def ExtractEntities(self: ServicerHost, request, context):
    meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)

    async with self._create_repository_provider() as repo:
        meeting = await repo.meetings.get(meeting_id)
        if meeting is None:
            await abort_not_found(context, "Meeting", request.meeting_id)

        entities = await repo.entities.get_by_meeting(meeting_id)
        # Handle not found, validation errors, etc.
```

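As an illustration of how such helpers behave (using a stub context instead of `grpc.aio.ServicerContext`; in real gRPC, `context.abort` raises and never returns):

```python
import asyncio


class AbortError(Exception):
    """Raised by the stub context, mimicking gRPC's abort behavior."""


class StubContext:
    """Minimal stand-in for grpc.aio.ServicerContext."""

    async def abort(self, code: str, details: str) -> None:
        raise AbortError(f"{code}: {details}")


async def abort_not_found(context: StubContext, entity_type: str, entity_id: str) -> None:
    # Centralizes the NOT_FOUND message format used across mixins
    await context.abort("NOT_FOUND", f"{entity_type} {entity_id} not found")


async def demo() -> str:
    try:
        await abort_not_found(StubContext(), "Meeting", "m1")
    except AbortError as exc:
        return str(exc)
    return "no abort"


print(asyncio.run(demo()))  # -> NOT_FOUND: Meeting m1 not found
```

Centralizing the message format keeps status details consistent across all RPCs, which is the point of `errors.py`.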
---

## 10. CACHING & REFRESH PATTERN

From the proto: `force_refresh` flag in `ExtractEntitiesRequest`.

**SummarizationMixin pattern** (similar):
```python
existing = await repo.summaries.get_by_meeting(meeting_id)
if existing and not request.force_regenerate:
    return summary_to_proto(existing)

# Otherwise, generate fresh
```

**EntitiesMixin should follow**:
```python
async def ExtractEntities(self: ServicerHost, request, context):
    meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)

    async with self._create_repository_provider() as repo:
        meeting = await repo.meetings.get(meeting_id)
        if meeting is None:
            await abort_not_found(context, "Meeting", request.meeting_id)

        # Check cache first
        cached_entities = await repo.entities.get_by_meeting(meeting_id)
        if cached_entities and not request.force_refresh:
            return noteflow_pb2.ExtractEntitiesResponse(
                entities=[entity_to_proto(e) for e in cached_entities],
                total_count=len(cached_entities),
                cached=True,
            )

        # Load segments while the repository is still open
        segments = await repo.segments.get_by_meeting(meeting_id)

    # Run the NER service outside the repo context (it may be slow)
    entities = await self._extract_entities(meeting_id, segments)

    # Persist in a fresh transaction
    async with self._create_repository_provider() as repo:
        saved = await repo.entities.save_batch(entities)
        await repo.commit()

    return noteflow_pb2.ExtractEntitiesResponse(
        entities=[entity_to_proto(e) for e in saved],
        total_count=len(saved),
        cached=False,
    )
```

---

## 11. KEY INSIGHTS & RISKS

### Strengths of Current Pattern
✅ **Modular**: Each mixin is self-contained and easy to test
✅ **Composable**: NoteFlowServicer inherits from all mixins without added complexity
✅ **Database-agnostic**: `_create_repository_provider()` abstracts DB vs. memory
✅ **Proven**: 7 existing mixins provide battle-tested templates
✅ **Asyncio-native**: All methods are async, with no blocking calls
✅ **Type-safe**: The ServicerHost protocol enforces mixin contracts

### Potential Issues Sprint 4 Should Watch For

1. **Missing Domain Entity**: No `Entity` class in `src/noteflow/domain/entities/` yet
   - Proto has `ExtractedEntity` (protobuf message)
   - ORM has `NamedEntityModel` (persistence)
   - A domain `Entity` class is needed to bridge them

2. **Missing Repository**: `entity_repo.py` doesn't exist
   - Must implement the CRUD methods the mixin uses
   - Must handle both database and in-memory backends

3. **NER Service Not Started**: The service layer (`NerService`) needs implementation
   - How are segments fetched for extraction?
   - What model/provider does NER use?
   - Is it async? Does it block?

4. **Proto Regeneration**: After any proto changes:
   ```bash
   python -m grpc_tools.protoc -I src/noteflow/grpc/proto \
       --python_out=src/noteflow/grpc/proto \
       --grpc_python_out=src/noteflow/grpc/proto \
       src/noteflow/grpc/proto/noteflow.proto
   ```
   Proto files are already generated, so verify they match lines 45-46 and 593-630.

5. **Memory Store Fallback**: The annotation repository supports in-memory storage; the entity repo must too
   - See `src/noteflow/infrastructure/persistence/memory.py` for the fallback pattern

---

## 12. PROTO SYNC CHECKLIST FOR SPRINT 4

- [x] `ExtractEntities` RPC defined (line 46)
- [x] `ExtractEntitiesRequest` message (lines 593-599)
- [x] `ExtractedEntity` message (lines 601-619)
- [x] `ExtractEntitiesResponse` message (lines 621-630)
- [ ] Regenerate Python stubs if the proto is modified
- [ ] Verify `noteflow_pb2.ExtractedEntity` exists in the generated code
- [ ] Verify `noteflow_pb2_grpc.NoteFlowService` includes the `ExtractEntities` method stub

---

## Summary Table

| Component | Status | File | Notes |
|-----------|--------|------|-------|
| **Mixin Class** | READY | `_mixins/entities.py` (new) | Follow AnnotationMixin pattern |
| **Service RPC** | ✅ EXISTS | `proto/noteflow.proto:46` | Already defined |
| **Request/Response** | ✅ EXISTS | `proto/noteflow.proto:593-630` | Complete messages |
| **ORM Model** | ✅ EXISTS | `persistence/models/entities/named_entity.py` | `NamedEntityModel` ready |
| **Migration** | ✅ EXISTS | `persistence/migrations/.../add_named_entities_table.py` | Table created |
| **Domain Entity** | ❌ MISSING | `domain/entities/` | Need `Entity` class |
| **Repository** | ❌ MISSING | `persistence/repositories/entity_repo.py` | Need CRUD implementation |
| **NER Service** | ❌ MISSING | `application/services/` | Need extraction logic |
| **Converter Functions** | PARTIAL | `_mixins/converters.py` | Need `entity_to_proto()` |
| **UnitOfWork** | READY | `persistence/unit_of_work.py` | Just add `entities` property |
| **ServicerHost** | READY | `_mixins/protocols.py` | Just add `_ner_service` field |

---

## VALIDATION CONCLUSION

**Status**: ✅ **SPRINT 4 APPROACH IS SOUND**

The proposed mixin-based architecture for EntitiesMixin:
1. Follows established patterns from 7 existing mixins
2. Leverages existing proto definitions
3. Has ORM and migration infrastructure ready
4. Fits cleanly into the NoteFlowServicer composition

**Next Steps**:
1. Create domain `Entity` class (mirrors proto `ExtractedEntity`)
2. Implement `SqlAlchemyEntityRepository` with CRUD operations
3. Create `NerService` application service
4. Implement `EntitiesMixin.ExtractEntities()` RPC handler
5. Add converter functions (`entity_to_proto()` etc.)
6. Add `entities` property to `SqlAlchemyUnitOfWork`
7. Add `_ner_service` to `ServicerHost` protocol

No architectural surprises or deviations from established patterns.

@@ -36,44 +36,56 @@ Automatically extract named entities (people, companies, products, locations) fr

 ## Current State Analysis

-### What Exists
+### What Already Exists

+#### Frontend Components (Complete)

 | Component | Location | Status |
 |-----------|----------|--------|
-| Entity UI | `client/src/components/EntityHighlightText.tsx` | Renders highlighted entities |
-| Entity Panel | `client/src/components/EntityManagementPanel.tsx` | Manual CRUD for entities |
-| Annotation RPC | `noteflow.proto` | `AddAnnotation` for manual entries |
+| Entity Highlight | `client/src/components/entity-highlight.tsx` | Renders inline entity highlights with tooltips (190 lines) |
+| Entity Panel | `client/src/components/entity-management-panel.tsx` | Manual CRUD with Sheet UI (388 lines) |
+| Entity Store | `client/src/lib/entity-store.ts` | Client-side state with observer pattern (169 lines) |
+| Entity Types | `client/src/types/entity.ts` | TypeScript types + color mappings (49 lines) |

-### Existing Persistence Infrastructure
+#### Backend Infrastructure (Complete)

-**Location**: `src/noteflow/infrastructure/persistence/`
+| Component | Location | Status |
+|-----------|----------|--------|
+| **ORM Model** | `infrastructure/persistence/models/entities/named_entity.py` | Complete - `NamedEntityModel` with all fields |
+| **Migration** | `migrations/versions/h2c3d4e5f6g7_add_named_entities_table.py` | Complete - table created with indices |
+| **Meeting Relationship** | `models/core/meeting.py:35-37` | Configured - `named_entities` with cascade delete |
+| **Proto RPC** | `noteflow.proto:46` | Defined - `rpc ExtractEntities(...)` |
+| **Proto Messages** | `noteflow.proto:593-630` | Defined - Request, Response, ExtractedEntity |
+| **Feature Flag** | `config/settings.py:223-226` | Defined - `NOTEFLOW_FEATURE_NER_ENABLED` |
+| **Dependency** | `pyproject.toml:63` | Declared - `spacy>=3.7` |

-**Decision Point**: Named entities can be stored in:
+> **Note**: The `named_entities` table and proto definitions are already implemented.
+> Sprint 0 completed this foundation work. No schema changes needed.

-1. **Existing `annotations` table** (simpler, less work):
-   - `annotation_type` = 'ENTITY_PERSON', 'ENTITY_COMPANY', etc.
-   - Reuses existing `AnnotationModel` and `SqlAlchemyAnnotationRepository`
-   - Schema: `docker/db/schema.sql:182-196`
+### Gap: What Needs Implementation

-2. **New `named_entities` table** (recommended for richer features):
-   - Dedicated table with `category`, `confidence`, `is_pinned` columns
-   - Better for entity deduplication and cross-meeting entity linking
-   - Requires new Alembic migration
+Backend NER processing pipeline:
+- No spaCy engine wrapper (`infrastructure/ner/`)
+- No `NamedEntity` domain entity class
+- No `NerPort` domain protocol
+- No `NerService` application layer
+- No `SqlAlchemyEntityRepository`
+- No `EntitiesMixin` gRPC handler
+- No Rust/Tauri command for extraction
+- Frontend not wired to backend extraction

-**Recommendation**: Use option 2 (new table) for:
-- Entity confidence scores
-- Pinned/verified entity status
-- Future: cross-meeting entity linking and knowledge graph
+### Field Naming Alignment Warning

-### Gap
-
-No backend NER capability:
-- No spaCy or other NER engine
-- No `ExtractEntities` RPC
-- No `NamedEntity` domain entity
-- **No `NerService` application layer** (violates hexagonal architecture)
-- Entities are purely client-side manual entries
-- No persistence for extracted entities
+> **⚠️ CRITICAL**: Frontend and backend use different field names for entity text:
+>
+> | Layer | Field | Location |
+> |-------|-------|----------|
+> | Frontend TS | `term` | `client/src/types/entity.ts:Entity.term` |
+> | Backend Proto | `text` | `noteflow.proto:ExtractedEntity.text` |
+> | ORM Model | `text` | `models/entities/named_entity.py:NamedEntityModel.text` |
+>
+> **Resolution required**: Either align the frontend to use `text` or add mapping in the Tauri commands.
+> The existing `entity-store.ts` uses `term` throughout, so a rename may be needed for consistency.

 ---

@@ -120,29 +132,33 @@ All components use **Protocol-based dependency injection** for testability:

 | `src/noteflow/application/services/ner_service.py` | **Application service layer** | ~150 |
 | `src/noteflow/infrastructure/ner/__init__.py` | Module init | ~10 |
 | `src/noteflow/infrastructure/ner/engine.py` | spaCy NER engine | ~120 |
-| `src/noteflow/infrastructure/persistence/models/core/named_entity.py` | ORM model | ~60 |
 | `src/noteflow/infrastructure/persistence/repositories/entity_repo.py` | Entity repository | ~100 |
 | `src/noteflow/infrastructure/converters/ner_converters.py` | ORM ↔ domain converters | ~50 |
 | `src/noteflow/grpc/_mixins/entities.py` | gRPC mixin (calls NerService) | ~100 |
 | `tests/application/test_ner_service.py` | Application layer tests | ~200 |
 | `tests/infrastructure/ner/test_engine.py` | Engine unit tests | ~150 |
-| `client/src/hooks/useEntityExtraction.ts` | React hook with polling | ~80 |
-| `client/src/components/EntityExtractionPanel.tsx` | UI component | ~120 |
+| `client/src/hooks/use-entity-extraction.ts` | React hook with polling | ~80 |
 | `client/src-tauri/src/commands/entities.rs` | Rust command | ~60 |

+> **Note**: ORM model (`NamedEntityModel`) already exists at `models/entities/named_entity.py`.
+> Frontend extraction panel can extend existing `entity-management-panel.tsx`.

 ### Files to Modify

 | File | Change Type | Lines Est. |
 |------|-------------|------------|
 | `src/noteflow/grpc/service.py` | Add mixin + NerService initialization | +15 |
 | `src/noteflow/grpc/server.py` | Initialize NerService | +10 |
 | `src/noteflow/infrastructure/persistence/models/__init__.py` | Export entity model | +2 |
 | `src/noteflow/infrastructure/persistence/models/core/__init__.py` | Export entity model | +2 |
 | `src/noteflow/grpc/_mixins/protocols.py` | Add `_ner_service` to ServicerHost | +5 |
 | `src/noteflow/infrastructure/persistence/repositories/__init__.py` | Export entity repo | +2 |
-| `src/noteflow/infrastructure/persistence/unit_of_work.py` | Add entity repository | +10 |
-| `pyproject.toml` | Add spacy dependency | +2 |
+| `src/noteflow/infrastructure/persistence/unit_of_work.py` | Add entity repository property | +15 |
 | `client/src/lib/tauri.ts` | Add extractEntities wrapper | +15 |
-| `client/src/pages/MeetingDetail.tsx` | Integrate EntityExtractionPanel | +20 |
+| `client/src/components/entity-management-panel.tsx` | Add extraction trigger button | +30 |
 | `client/src-tauri/src/commands/mod.rs` | Export entities module | +2 |
 | `client/src-tauri/src/lib.rs` | Register entity commands | +3 |
 | `client/src-tauri/src/grpc/client.rs` | Add extract_entities method | +25 |

+> **Note**: `pyproject.toml` already has `spacy>=3.7`. Proto regeneration not needed.

 ---

---
@@ -493,59 +509,16 @@ class NerEngine:

 ### Task 5: Create Entity Persistence

-**File**: `src/noteflow/infrastructure/persistence/models/core/named_entity.py`
+> **✅ ORM Model Already Exists**: The `NamedEntityModel` is already implemented at
+> `infrastructure/persistence/models/entities/named_entity.py`. Do **not** recreate it.
+> Only the repository and converters need to be implemented.

-```python
-"""Named entity ORM model."""
+**Existing ORM** (reference only - do not modify):
+- Location: `src/noteflow/infrastructure/persistence/models/entities/named_entity.py`
+- Fields: `id`, `meeting_id`, `text`, `category`, `segment_ids`, `confidence`, `is_pinned`, `created_at`
+- Relationship: `meeting` → `MeetingModel.named_entities` (cascade delete configured)

-from __future__ import annotations
-
-from datetime import datetime
-from uuid import UUID
-
-from sqlalchemy import ARRAY, Boolean, DateTime, Float, ForeignKey, Integer, String
-from sqlalchemy.dialects.postgresql import UUID as PG_UUID
-from sqlalchemy.orm import Mapped, mapped_column, relationship
-
-from noteflow.infrastructure.persistence.models.base import Base
-
-
-class NamedEntityModel(Base):
-    """SQLAlchemy model for named entities."""
-
-    __tablename__ = "named_entities"
-    __table_args__ = {"schema": "noteflow"}
-
-    id: Mapped[UUID] = mapped_column(
-        PG_UUID(as_uuid=True),
-        primary_key=True,
-    )
-    meeting_id: Mapped[UUID] = mapped_column(
-        PG_UUID(as_uuid=True),
-        ForeignKey("noteflow.meetings.id", ondelete="CASCADE"),
-        nullable=False,
-        index=True,
-    )
-    text: Mapped[str] = mapped_column(String, nullable=False)
-    category: Mapped[str] = mapped_column(String(50), nullable=False)
-    segment_ids: Mapped[list[int]] = mapped_column(
-        ARRAY(Integer),
-        nullable=False,
-        default=list,
-    )
-    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
-    is_pinned: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True),
-        nullable=False,
-        default=datetime.utcnow,
-    )
-
-    # Relationship
-    meeting = relationship("MeetingModel", back_populates="named_entities")
-```
-
-**File**: `src/noteflow/infrastructure/persistence/repositories/entity_repo.py`
+**File to Create**: `src/noteflow/infrastructure/persistence/repositories/entity_repo.py`

 ```python
 """Named entity repository."""

@@ -1510,30 +1483,53 @@ pub async fn clear_entities(
|
||||
|
||||
| Pattern | Reference Location | Usage |
|
||||
|---------|-------------------|-------|
|
||||
| ORM Model | `models/core/annotation.py` | Structure for `NamedEntityModel` |
|
||||
| Repository | `repositories/annotation_repo.py` | CRUD operations |
|
||||
| Unit of Work | `unit_of_work.py` | Add entity repository |
|
||||
| Converters | `converters/orm_converters.py` | ORM ↔ domain entity |
|
||||
| ORM Model | `models/entities/named_entity.py` | **Already exists** - no creation needed |
|
||||
| Repository | `repositories/annotation_repo.py` | Template for CRUD operations |
|
||||
| Base Repository | `repositories/_base.py` | Extend for helper methods |
|
||||
| Unit of Work | `unit_of_work.py` | Add entity repository property |
|
||||
| Converters | `converters/orm_converters.py` | Add `entity_to_domain()` method |
|
||||
|
||||
### Existing Entity UI Components
|
||||
|
||||
**Location**: `client/src/components/EntityHighlightText.tsx`
|
||||
**Location**: `client/src/components/entity-highlight.tsx`
|
||||
|
||||
Already renders entity highlights - connect to extracted entities.
|
||||
Already renders entity highlights with tooltips - connect to extracted entities.
|
||||
|
||||
**Location**: `client/src/components/EntityManagementPanel.tsx`
|
||||
**Location**: `client/src/components/entity-management-panel.tsx`
|
||||
|
||||
CRUD panel - extend to display auto-extracted entities.
|
||||
CRUD panel with Sheet UI - extend to display auto-extracted entities.
|
||||
|
||||
**Location**: `client/src/lib/entity-store.ts`
|
||||
|
||||
Client-side observer pattern store - wire to backend extraction results.
|
||||
|
||||
> **Warning**: Color definitions are duplicated between `entity-highlight.tsx` (inline `categoryColors`)
|
||||
> and `types/entity.ts` (`ENTITY_CATEGORY_COLORS`). Use the shared constant from `types/entity.ts`.
|
||||
|
||||
### Application Service Pattern
|
||||
|
||||
**Location**: `src/noteflow/application/services/diarization_service.py`
|
||||
**Location**: `src/noteflow/application/services/summarization_service.py`
|
||||
|
||||
Pattern for application service with:
|
||||
- Async locks for concurrency control
|
||||
- UoW factory for database operations
|
||||
- `run_in_executor` for CPU-bound work
|
||||
- Structured logging
|
||||
- Dataclass-based service with settings
|
||||
- Provider registration pattern (multi-backend support)
|
||||
- Lazy model loading via property getters
|
||||
- Callback-based persistence (not UoW-injected)

**Location**: `src/noteflow/infrastructure/summarization/ollama_provider.py`

Pattern for lazy model loading:
- `self._client: T | None = None` initial state
- `_get_client()` method for lazy initialization
- `is_available` property for runtime checks
- `asyncio.to_thread()` for CPU-bound inference
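A self-contained sketch of those four bullets, assuming a hypothetical `FakeModel` in place of the real Ollama client (the method names below are illustrative, not the provider's actual API):

```python
import asyncio


class FakeModel:
    """Stand-in for an expensive-to-load inference client."""

    def generate(self, prompt: str) -> str:
        return f"summary of {prompt}"


class LazyProvider:
    def __init__(self) -> None:
        self._client: FakeModel | None = None  # not loaded until first use

    def _get_client(self) -> FakeModel:
        if self._client is None:
            self._client = FakeModel()  # expensive load happens here, once
        return self._client

    @property
    def is_available(self) -> bool:
        # Runtime check without forcing a load
        return self._client is not None

    async def summarize(self, prompt: str) -> str:
        client = self._get_client()
        # CPU-bound inference runs off the event loop thread
        return await asyncio.to_thread(client.generate, prompt)


provider = LazyProvider()
print(provider.is_available)  # False until first call
print(asyncio.run(provider.summarize("meeting")))
print(provider.is_available)  # True after lazy init
```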

**Location**: `src/noteflow/grpc/_mixins/diarization.py`

Pattern for CPU-bound gRPC handlers:
- `asyncio.Lock` for concurrency control
- `loop.run_in_executor()` for blocking operations
- Structured logging with meeting context

---

@@ -180,7 +180,7 @@ The persistence layer is complete, but there's no application/domain layer to:
2. **Feature flag** verified:
```python
from noteflow.config.settings import get_settings
assert get_settings().feature_flags.calendar_sync_enabled
assert get_settings().feature_flags.calendar_enabled
```

**If any verification fails**: Complete Sprint 0 first.
@@ -1057,7 +1057,7 @@ class CalendarService:
            Authorization URL to redirect user to.
        """
        settings = get_settings()
        if not settings.feature_flags.calendar_sync_enabled:
        if not settings.feature_flags.calendar_enabled:
            msg = "Calendar sync is disabled by feature flag"
            raise RuntimeError(msg)

File diff suppressed because it is too large
@@ -31,6 +31,7 @@ dependencies = [
    # HTTP client for webhooks and integrations
    "httpx>=0.27",
    "weasyprint>=67.0",
    "authlib>=1.6.6",
]

[project.optional-dependencies]

421
src/noteflow/application/services/calendar_service.py
Normal file
@@ -0,0 +1,421 @@
"""Calendar integration service.

Orchestrates OAuth flow, token management, and calendar event fetching.
Uses existing Integration entity and IntegrationRepository for persistence.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from uuid import UUID

from noteflow.domain.entities.integration import Integration, IntegrationStatus, IntegrationType
from noteflow.domain.ports.calendar import CalendarEventInfo, OAuthConnectionInfo
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar import (
    GoogleCalendarAdapter,
    OAuthManager,
    OutlookCalendarAdapter,
)
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
from noteflow.infrastructure.calendar.oauth_manager import OAuthError
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError

if TYPE_CHECKING:
    from collections.abc import Callable

    from noteflow.config.settings import CalendarSettings
    from noteflow.domain.ports.unit_of_work import UnitOfWork

logger = logging.getLogger(__name__)


class CalendarServiceError(Exception):
    """Calendar service operation failed."""


class CalendarService:
    """Calendar integration service.

    Orchestrates OAuth flow and calendar event fetching. Uses:
    - IntegrationRepository for Integration entity CRUD
    - IntegrationRepository.get_secrets/set_secrets for encrypted token storage
    - OAuthManager for PKCE OAuth flow
    - GoogleCalendarAdapter/OutlookCalendarAdapter for provider APIs
    """

    # Default workspace ID for single-user mode
    DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000000")

    def __init__(
        self,
        uow_factory: Callable[[], UnitOfWork],
        settings: CalendarSettings,
        oauth_manager: OAuthManager | None = None,
        google_adapter: GoogleCalendarAdapter | None = None,
        outlook_adapter: OutlookCalendarAdapter | None = None,
    ) -> None:
        """Initialize calendar service.

        Args:
            uow_factory: Factory function returning UnitOfWork instances.
            settings: Calendar settings with OAuth credentials.
            oauth_manager: Optional OAuth manager (created from settings if not provided).
            google_adapter: Optional Google adapter (created if not provided).
            outlook_adapter: Optional Outlook adapter (created if not provided).
        """
        self._uow_factory = uow_factory
        self._settings = settings
        self._oauth_manager = oauth_manager or OAuthManager(settings)
        self._google_adapter = google_adapter or GoogleCalendarAdapter()
        self._outlook_adapter = outlook_adapter or OutlookCalendarAdapter()

    async def initiate_oauth(
        self,
        provider: str,
        redirect_uri: str | None = None,
    ) -> tuple[str, str]:
        """Start OAuth flow for a calendar provider.

        Args:
            provider: Provider name ('google' or 'outlook').
            redirect_uri: Optional override for OAuth callback URI.

        Returns:
            Tuple of (authorization_url, state_token).

        Raises:
            CalendarServiceError: If provider is invalid or credentials not configured.
        """
        oauth_provider = self._parse_provider(provider)
        effective_redirect = redirect_uri or self._settings.redirect_uri

        try:
            auth_url, state = self._oauth_manager.initiate_auth(
                provider=oauth_provider,
                redirect_uri=effective_redirect,
            )
            logger.info("Initiated OAuth flow for provider=%s", provider)
            return auth_url, state
        except OAuthError as e:
            raise CalendarServiceError(str(e)) from e

    async def complete_oauth(
        self,
        provider: str,
        code: str,
        state: str,
    ) -> bool:
        """Complete OAuth flow and store tokens.

        Creates or updates Integration entity with CALENDAR type and
        stores encrypted tokens via IntegrationRepository.set_secrets.

        Args:
            provider: Provider name ('google' or 'outlook').
            code: Authorization code from OAuth callback.
            state: State parameter from OAuth callback.

        Returns:
            True if OAuth completed successfully.

        Raises:
            CalendarServiceError: If OAuth exchange fails.
        """
        oauth_provider = self._parse_provider(provider)

        try:
            tokens = await self._oauth_manager.complete_auth(
                provider=oauth_provider,
                code=code,
                state=state,
            )
        except OAuthError as e:
            raise CalendarServiceError(f"OAuth failed: {e}") from e

        # Get user email from provider
        try:
            email = await self._get_user_email(oauth_provider, tokens.access_token)
        except (GoogleCalendarError, OutlookCalendarError) as e:
            raise CalendarServiceError(f"Failed to get user email: {e}") from e

        # Persist integration and tokens
        async with self._uow_factory() as uow:
            integration = await uow.integrations.get_by_provider(
                provider=provider,
                integration_type=IntegrationType.CALENDAR.value,
            )

            if integration is None:
                integration = Integration.create(
                    workspace_id=self.DEFAULT_WORKSPACE_ID,
                    name=f"{provider.title()} Calendar",
                    integration_type=IntegrationType.CALENDAR,
                    config={"provider": provider},
                )
                await uow.integrations.create(integration)
            else:
                integration.config["provider"] = provider

            integration.connect(provider_email=email)
            await uow.integrations.update(integration)

            # Store encrypted tokens
            await uow.integrations.set_secrets(
                integration_id=integration.id,
                secrets=tokens.to_secrets_dict(),
            )
            await uow.commit()

        logger.info("Completed OAuth for provider=%s, email=%s", provider, email)
        return True

    async def get_connection_status(self, provider: str) -> OAuthConnectionInfo:
        """Get OAuth connection status for a provider.

        Args:
            provider: Provider name ('google' or 'outlook').

        Returns:
            OAuthConnectionInfo with status and details.
        """
        async with self._uow_factory() as uow:
            integration = await uow.integrations.get_by_provider(
                provider=provider,
                integration_type=IntegrationType.CALENDAR.value,
            )

            if integration is None:
                return OAuthConnectionInfo(
                    provider=provider,
                    status="disconnected",
                )

            # Check token expiry
            secrets = await uow.integrations.get_secrets(integration.id)
            expires_at = None
            status = self._map_integration_status(integration.status)

            if secrets and integration.is_connected:
                try:
                    tokens = OAuthTokens.from_secrets_dict(secrets)
                    expires_at = tokens.expires_at
                    if tokens.is_expired():
                        status = "expired"
                except (KeyError, ValueError):
                    status = "error"

            return OAuthConnectionInfo(
                provider=provider,
                status=status,
                email=integration.provider_email,
                expires_at=expires_at,
                error_message=integration.error_message,
            )

    async def disconnect(self, provider: str) -> bool:
        """Disconnect OAuth integration and revoke tokens.

        Args:
            provider: Provider name ('google' or 'outlook').

        Returns:
            True if disconnected successfully.
        """
        oauth_provider = self._parse_provider(provider)

        async with self._uow_factory() as uow:
            integration = await uow.integrations.get_by_provider(
                provider=provider,
                integration_type=IntegrationType.CALENDAR.value,
            )

            if integration is None:
                return False

            # Get tokens before deletion for revocation
            secrets = await uow.integrations.get_secrets(integration.id)
            access_token = secrets.get("access_token") if secrets else None

            # Delete integration (cascades to secrets)
            await uow.integrations.delete(integration.id)
            await uow.commit()

        # Revoke tokens with provider (best-effort)
        if access_token:
            try:
                await self._oauth_manager.revoke_tokens(oauth_provider, access_token)
            except OAuthError as e:
                logger.warning(
                    "Failed to revoke tokens for provider=%s: %s",
                    provider,
                    e,
                )

        logger.info("Disconnected provider=%s", provider)
        return True

    async def list_calendar_events(
        self,
        provider: str | None = None,
        hours_ahead: int | None = None,
        limit: int | None = None,
    ) -> list[CalendarEventInfo]:
        """Fetch calendar events from connected providers.

        Args:
            provider: Optional provider to fetch from (fetches all if None).
            hours_ahead: Hours to look ahead (defaults to settings).
            limit: Maximum events per provider (defaults to settings).

        Returns:
            List of calendar events sorted by start time.

        Raises:
            CalendarServiceError: If no providers connected or fetch fails.
        """
        effective_hours = hours_ahead or self._settings.sync_hours_ahead
        effective_limit = limit or self._settings.max_events

        events: list[CalendarEventInfo] = []

        if provider:
            provider_events = await self._fetch_provider_events(
                provider=provider,
                hours_ahead=effective_hours,
                limit=effective_limit,
            )
            events.extend(provider_events)
        else:
            # Fetch from all connected providers
            for p in ["google", "outlook"]:
                try:
                    provider_events = await self._fetch_provider_events(
                        provider=p,
                        hours_ahead=effective_hours,
                        limit=effective_limit,
                    )
                    events.extend(provider_events)
                except CalendarServiceError:
                    continue  # Skip disconnected providers

        # Sort by start time
        events.sort(key=lambda e: e.start_time)
        return events

    async def _fetch_provider_events(
        self,
        provider: str,
        hours_ahead: int,
        limit: int,
    ) -> list[CalendarEventInfo]:
        """Fetch events from a specific provider with token refresh."""
        oauth_provider = self._parse_provider(provider)

        async with self._uow_factory() as uow:
            integration = await uow.integrations.get_by_provider(
                provider=provider,
                integration_type=IntegrationType.CALENDAR.value,
            )

            if integration is None or not integration.is_connected:
                raise CalendarServiceError(f"Provider {provider} not connected")

            secrets = await uow.integrations.get_secrets(integration.id)
            if not secrets:
                raise CalendarServiceError(f"No tokens for provider {provider}")

            try:
                tokens = OAuthTokens.from_secrets_dict(secrets)
            except (KeyError, ValueError) as e:
                raise CalendarServiceError(f"Invalid tokens: {e}") from e

            # Refresh if expired
            if tokens.is_expired() and tokens.refresh_token:
                try:
                    tokens = await self._oauth_manager.refresh_tokens(
                        provider=oauth_provider,
                        refresh_token=tokens.refresh_token,
                    )
                    await uow.integrations.set_secrets(
                        integration_id=integration.id,
                        secrets=tokens.to_secrets_dict(),
                    )
                    await uow.commit()
                except OAuthError as e:
                    integration.mark_error(f"Token refresh failed: {e}")
                    await uow.integrations.update(integration)
                    await uow.commit()
                    raise CalendarServiceError(f"Token refresh failed: {e}") from e

            # Fetch events
            try:
                events = await self._fetch_events(
                    oauth_provider,
                    tokens.access_token,
                    hours_ahead,
                    limit,
                )
                integration.record_sync()
                await uow.integrations.update(integration)
                await uow.commit()
                return events
            except (GoogleCalendarError, OutlookCalendarError) as e:
                integration.mark_error(str(e))
                await uow.integrations.update(integration)
                await uow.commit()
                raise CalendarServiceError(str(e)) from e

    async def _fetch_events(
        self,
        provider: OAuthProvider,
        access_token: str,
        hours_ahead: int,
        limit: int,
    ) -> list[CalendarEventInfo]:
        """Fetch events from provider API."""
        adapter = self._get_adapter(provider)
        return await adapter.list_events(
            access_token=access_token,
            hours_ahead=hours_ahead,
            limit=limit,
        )

    async def _get_user_email(
        self,
        provider: OAuthProvider,
        access_token: str,
    ) -> str:
        """Get user email from provider API."""
        adapter = self._get_adapter(provider)
        return await adapter.get_user_email(access_token)

    def _get_adapter(
        self,
        provider: OAuthProvider,
    ) -> GoogleCalendarAdapter | OutlookCalendarAdapter:
        """Get calendar adapter for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self._google_adapter
        return self._outlook_adapter

    @staticmethod
    def _parse_provider(provider: str) -> OAuthProvider:
        """Parse and validate provider string."""
        try:
            return OAuthProvider(provider.lower())
        except ValueError as e:
            raise CalendarServiceError(
                f"Invalid provider: {provider}. Must be 'google' or 'outlook'."
            ) from e

    @staticmethod
    def _map_integration_status(status: IntegrationStatus) -> str:
        """Map IntegrationStatus to connection status string."""
        mapping = {
            IntegrationStatus.CONNECTED: "connected",
            IntegrationStatus.DISCONNECTED: "disconnected",
            IntegrationStatus.ERROR: "error",
        }
        return mapping.get(status, "disconnected")
@@ -84,13 +84,14 @@ class ExportService:
            ValueError: If meeting not found.
        """
        async with self._uow:
            meeting = await self._uow.meetings.get(meeting_id)
            if meeting is None:
                raise ValueError(f"Meeting {meeting_id} not found")
            found_meeting = await self._uow.meetings.get(meeting_id)
            if not found_meeting:
                msg = f"Meeting {meeting_id} not found"
                raise ValueError(msg)

            segments = await self._uow.segments.get_by_meeting(meeting_id)
            exporter = self._get_exporter(fmt)
            return exporter.export(meeting, segments)
            return exporter.export(found_meeting, segments)

    async def export_to_file(
        self,
@@ -123,7 +124,10 @@ class ExportService:
            output_path = output_path.with_suffix(exporter.file_extension)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content, encoding="utf-8")
        if isinstance(content, bytes):
            output_path.write_bytes(content)
        else:
            output_path.write_text(content, encoding="utf-8")
        return output_path

    def _infer_format_from_extension(self, extension: str) -> ExportFormat:

@@ -90,13 +90,13 @@ class MeetingService:
        """List meetings with optional filtering.

        Args:
            states: Optional list of states to filter by.
            limit: Maximum number of meetings to return.
            offset: Number of meetings to skip.
            sort_desc: Sort by created_at descending if True.
            states: Filter to specific meeting states (None = all).
            limit: Maximum results to return.
            offset: Number of results to skip for pagination.
            sort_desc: If True, newest meetings first.

        Returns:
            Tuple of (meetings list, total count).
            Tuple of (meeting sequence, total matching count).
        """
        async with self._uow:
            return await self._uow.meetings.list_all(

263
src/noteflow/application/services/ner_service.py
Normal file
@@ -0,0 +1,263 @@
"""Named Entity Recognition application service.

Orchestrates NER extraction, caching, and persistence following hexagonal architecture:
gRPC mixin → NerService (application) → NerEngine (infrastructure)
"""

from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING

from noteflow.config.settings import get_settings
from noteflow.domain.entities.named_entity import NamedEntity

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from uuid import UUID

    from noteflow.domain.ports.ner import NerPort
    from noteflow.domain.value_objects import MeetingId
    from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork

UoWFactory = Callable[[], SqlAlchemyUnitOfWork]

logger = logging.getLogger(__name__)


@dataclass
class ExtractionResult:
    """Result of entity extraction.

    Attributes:
        entities: List of extracted named entities.
        cached: Whether the result came from cache.
        total_count: Total number of entities.
    """

    entities: Sequence[NamedEntity]
    cached: bool
    total_count: int


class NerService:
    """Application service for Named Entity Recognition.

    Provides a clean interface for NER operations, abstracting away
    the infrastructure details (spaCy engine, database persistence).

    Orchestrates:
    - Feature flag checking
    - Cache lookup (return existing entities if available)
    - Extraction via NerPort adapter
    - Persistence via UnitOfWork
    """

    def __init__(
        self,
        ner_engine: NerPort,
        uow_factory: UoWFactory,
    ) -> None:
        """Initialize NER service.

        Args:
            ner_engine: NER engine implementation (infrastructure adapter).
            uow_factory: Factory for creating Unit of Work instances.
        """
        self._ner_engine = ner_engine
        self._uow_factory = uow_factory
        self._extraction_lock = asyncio.Lock()
        self._model_load_lock = asyncio.Lock()

    async def extract_entities(
        self,
        meeting_id: MeetingId,
        force_refresh: bool = False,
    ) -> ExtractionResult:
        """Extract named entities from a meeting's transcript.

        Checks for cached results first, unless force_refresh is True.
        Persists new extractions to the database.

        Args:
            meeting_id: Meeting to extract entities from.
            force_refresh: If True, re-extract even if cached results exist.

        Returns:
            ExtractionResult with entities and metadata.

        Raises:
            ValueError: If meeting not found or has no segments.
            RuntimeError: If NER feature is disabled.
        """
        self._check_feature_enabled()

        # Check cache and get segments in one transaction
        cached_or_segments = await self._get_cached_or_segments(meeting_id, force_refresh)
        if isinstance(cached_or_segments, ExtractionResult):
            return cached_or_segments

        segments = cached_or_segments

        # Extract and persist
        entities = await self._extract_with_lock(segments)
        for entity in entities:
            entity.meeting_id = meeting_id

        await self._persist_entities(meeting_id, entities, force_refresh)

        logger.info(
            "Extracted %d entities from meeting %s (%d segments)",
            len(entities),
            meeting_id,
            len(segments),
        )
        return ExtractionResult(entities=entities, cached=False, total_count=len(entities))

    def _check_feature_enabled(self) -> None:
        """Raise if NER feature is disabled."""
        settings = get_settings()
        if not settings.feature_flags.ner_enabled:
            raise RuntimeError("NER extraction is disabled by feature flag")

    async def _get_cached_or_segments(
        self,
        meeting_id: MeetingId,
        force_refresh: bool,
    ) -> ExtractionResult | list[tuple[int, str]]:
        """Check cache and return cached result or segments for extraction."""
        async with self._uow_factory() as uow:
            if not force_refresh:
                cached = await uow.entities.get_by_meeting(meeting_id)
                if cached:
                    logger.debug("Returning %d cached entities for meeting %s", len(cached), meeting_id)
                    return ExtractionResult(entities=cached, cached=True, total_count=len(cached))

            meeting = await uow.meetings.get(meeting_id)
            if not meeting:
                raise ValueError(f"Meeting {meeting_id} not found")

            # Load segments separately (not eagerly loaded on meeting)
            segments = await uow.segments.get_by_meeting(meeting_id)
            if not segments:
                logger.debug("Meeting %s has no segments", meeting_id)
                return ExtractionResult(entities=[], cached=False, total_count=0)

            return [(s.segment_id, s.text) for s in segments]

    async def _persist_entities(
        self,
        meeting_id: MeetingId,
        entities: list[NamedEntity],
        force_refresh: bool,
    ) -> None:
        """Persist extracted entities to database."""
        async with self._uow_factory() as uow:
            if force_refresh:
                await uow.entities.delete_by_meeting(meeting_id)
            await uow.entities.save_batch(entities)
            await uow.commit()

    async def _extract_with_lock(
        self,
        segments: list[tuple[int, str]],
    ) -> list[NamedEntity]:
        """Extract entities with concurrency control.

        Ensures only one extraction runs at a time and handles
        lazy model loading safely.

        Args:
            segments: List of (segment_id, text) tuples.

        Returns:
            List of extracted entities.
        """
        async with self._extraction_lock:
            # Ensure model is loaded (thread-safe)
            if not self._ner_engine.is_ready():
                async with self._model_load_lock:
                    if not self._ner_engine.is_ready():
                        # Warm up model with a simple extraction
                        loop = asyncio.get_running_loop()
                        await loop.run_in_executor(
                            None,
                            lambda: self._ner_engine.extract("warm up"),
                        )

            # Extract entities in executor (CPU-bound)
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None,
                self._ner_engine.extract_from_segments,
                segments,
            )

    async def get_entities(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
        """Get cached entities for a meeting (no extraction).

        Args:
            meeting_id: Meeting UUID.

        Returns:
            List of entities (empty if not extracted yet).
        """
        async with self._uow_factory() as uow:
            return await uow.entities.get_by_meeting(meeting_id)

    async def pin_entity(self, entity_id: UUID, is_pinned: bool = True) -> bool:
        """Mark an entity as user-verified (pinned).

        Args:
            entity_id: Entity UUID.
            is_pinned: New pinned status.

        Returns:
            True if entity was found and updated.
        """
        async with self._uow_factory() as uow:
            result = await uow.entities.update_pinned(entity_id, is_pinned)
            if result:
                await uow.commit()
            return result

    async def clear_entities(self, meeting_id: MeetingId) -> int:
        """Delete all entities for a meeting.

        Args:
            meeting_id: Meeting UUID.

        Returns:
            Number of deleted entities.
        """
        async with self._uow_factory() as uow:
            count = await uow.entities.delete_by_meeting(meeting_id)
            await uow.commit()
            logger.info("Cleared %d entities for meeting %s", count, meeting_id)
            return count

    async def has_entities(self, meeting_id: MeetingId) -> bool:
        """Check if a meeting has extracted entities.

        Args:
            meeting_id: Meeting UUID.

        Returns:
            True if at least one entity exists.
        """
        async with self._uow_factory() as uow:
            return await uow.entities.exists_for_meeting(meeting_id)

    def is_engine_ready(self) -> bool:
        """Check if NER engine model is loaded and ready for extraction.

        Verifies both that the spaCy model is loaded and that the NER
        feature is enabled via configuration.

        Returns:
            True if the engine is ready and NER is enabled.
        """
        settings = get_settings()
        return settings.feature_flags.ner_enabled and self._ner_engine.is_ready()
@@ -79,12 +79,24 @@ class DownloadReport:

    @property
    def success_count(self) -> int:
        """Count of successful downloads."""
        """Count of successful downloads.

        Returns:
            Number of results with success=True, 0 if no results.
        """
        if not self.results:
            return 0
        return sum(1 for r in self.results if r.success)

    @property
    def failure_count(self) -> int:
        """Count of failed downloads."""
        """Count of failed downloads.

        Returns:
            Number of results with success=False, 0 if no results.
        """
        if not self.results:
            return 0
        return sum(1 for r in self.results if not r.success)

@@ -234,6 +234,68 @@ class FeatureFlags(BaseSettings):
|
||||
]
|
||||
|
||||
|
||||
class CalendarSettings(BaseSettings):
|
||||
"""Calendar integration OAuth and sync settings.
|
||||
|
||||
Environment variables use NOTEFLOW_CALENDAR_ prefix:
|
||||
NOTEFLOW_CALENDAR_GOOGLE_CLIENT_ID: Google OAuth client ID
|
||||
NOTEFLOW_CALENDAR_GOOGLE_CLIENT_SECRET: Google OAuth client secret
|
||||
NOTEFLOW_CALENDAR_OUTLOOK_CLIENT_ID: Microsoft OAuth client ID
|
||||
NOTEFLOW_CALENDAR_OUTLOOK_CLIENT_SECRET: Microsoft OAuth client secret
|
||||
NOTEFLOW_CALENDAR_REDIRECT_URI: OAuth callback URI
|
||||
NOTEFLOW_CALENDAR_SYNC_HOURS_AHEAD: Hours to look ahead for events
|
||||
NOTEFLOW_CALENDAR_MAX_EVENTS: Maximum events to fetch
|
||||
NOTEFLOW_CALENDAR_SYNC_INTERVAL_MINUTES: Sync interval in minutes
|
||||
"""

    model_config = SettingsConfigDict(
        env_prefix="NOTEFLOW_CALENDAR_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Google OAuth
    google_client_id: Annotated[
        str,
        Field(default="", description="Google OAuth client ID"),
    ]
    google_client_secret: Annotated[
        str,
        Field(default="", description="Google OAuth client secret"),
    ]

    # Microsoft OAuth
    outlook_client_id: Annotated[
        str,
        Field(default="", description="Microsoft OAuth client ID"),
    ]
    outlook_client_secret: Annotated[
        str,
        Field(default="", description="Microsoft OAuth client secret"),
    ]

    # OAuth redirect
    redirect_uri: Annotated[
        str,
        Field(default="noteflow://oauth/callback", description="OAuth callback URI"),
    ]

    # Sync settings
    sync_hours_ahead: Annotated[
        int,
        Field(default=24, ge=1, le=168, description="Hours to look ahead for events"),
    ]
    max_events: Annotated[
        int,
        Field(default=20, ge=1, le=100, description="Maximum events to fetch"),
    ]
    sync_interval_minutes: Annotated[
        int,
        Field(default=15, ge=1, le=1440, description="Sync interval in minutes"),
    ]


class Settings(TriggerSettings):
    """Application settings loaded from environment variables.

@@ -347,6 +409,14 @@ class Settings(TriggerSettings):
        """Return database URL as string."""
        return str(self.database_url)

    @property
    def feature_flags(self) -> FeatureFlags:
        """Return cached feature flags instance.

        Provides convenient access to feature flags via settings object.
        """
        return get_feature_flags()


def _load_settings() -> Settings:
    """Load settings from environment.

@@ -398,3 +468,18 @@ def get_feature_flags() -> FeatureFlags:
        Cached FeatureFlags instance loaded from environment.
    """
    return _load_feature_flags()


def _load_calendar_settings() -> CalendarSettings:
    """Load calendar settings from environment."""
    return CalendarSettings.model_validate({})


@lru_cache
def get_calendar_settings() -> CalendarSettings:
    """Get cached calendar settings instance.

    Returns:
        Cached CalendarSettings instance loaded from environment.
    """
    return _load_calendar_settings()
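The `env_prefix="NOTEFLOW_CALENDAR_"` configuration above means each field is populated from the matching prefixed environment variable. A stdlib-only sketch of that prefix-to-field mapping (the real class delegates this to pydantic-settings; `load_calendar_settings` here is a hypothetical helper, not part of the diff):

```python
def load_calendar_settings(environ: dict[str, str]) -> dict[str, str]:
    """Collect NOTEFLOW_CALENDAR_* variables, strip the prefix, lowercase the key."""
    prefix = "NOTEFLOW_CALENDAR_"
    return {
        key[len(prefix):].lower(): value
        for key, value in environ.items()
        if key.startswith(prefix)
    }


settings = load_calendar_settings({
    "NOTEFLOW_CALENDAR_MAX_EVENTS": "20",
    "OTHER_VAR": "ignored",  # extra="ignore" drops unrelated variables
})
print(settings)  # {'max_events': '20'}
```

pydantic-settings additionally coerces values to the annotated types and enforces the `ge`/`le` bounds declared on each `Field`.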
@@ -1,15 +1,22 @@
"""Domain entities for NoteFlow."""

from .annotation import Annotation
from .integration import Integration, IntegrationStatus, IntegrationType
from .meeting import Meeting
from .named_entity import EntityCategory, NamedEntity
from .segment import Segment, WordTiming
from .summary import ActionItem, KeyPoint, Summary

__all__ = [
    "ActionItem",
    "Annotation",
    "EntityCategory",
    "Integration",
    "IntegrationStatus",
    "IntegrationType",
    "KeyPoint",
    "Meeting",
    "NamedEntity",
    "Segment",
    "Summary",
    "WordTiming",
123
src/noteflow/domain/entities/integration.py
Normal file
@@ -0,0 +1,123 @@
"""Integration entity for external service connections."""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime
from enum import StrEnum
from uuid import UUID, uuid4

from noteflow.domain.utils.time import utc_now


class IntegrationType(StrEnum):
    """Types of integrations supported."""

    AUTH = "auth"
    EMAIL = "email"
    CALENDAR = "calendar"
    PKM = "pkm"
    CUSTOM = "custom"


class IntegrationStatus(StrEnum):
    """Status of an integration connection."""

    DISCONNECTED = "disconnected"
    CONNECTED = "connected"
    ERROR = "error"


@dataclass
class Integration:
    """External service integration entity.

    Represents a connection to an external service like Google Calendar,
    Microsoft Outlook, Notion, etc.
    """

    id: UUID
    workspace_id: UUID
    name: str
    type: IntegrationType
    status: IntegrationStatus = IntegrationStatus.DISCONNECTED
    config: dict[str, object] = field(default_factory=dict)
    last_sync: datetime | None = None
    error_message: str | None = None
    created_at: datetime = field(default_factory=utc_now)
    updated_at: datetime = field(default_factory=utc_now)

    @classmethod
    def create(
        cls,
        workspace_id: UUID,
        name: str,
        integration_type: IntegrationType,
        config: dict[str, object] | None = None,
    ) -> Integration:
        """Create a new integration.

        Args:
            workspace_id: ID of the workspace this integration belongs to.
            name: Display name for the integration.
            integration_type: Type of integration.
            config: Optional configuration dictionary.

        Returns:
            New Integration instance.
        """
        now = utc_now()
        return cls(
            id=uuid4(),
            workspace_id=workspace_id,
            name=name,
            type=integration_type,
            status=IntegrationStatus.DISCONNECTED,
            config=config or {},
            created_at=now,
            updated_at=now,
        )

    def connect(self, provider_email: str | None = None) -> None:
        """Mark integration as connected.

        Args:
            provider_email: Optional email of the authenticated account.
        """
        self.status = IntegrationStatus.CONNECTED
        self.error_message = None
        self.updated_at = utc_now()
        if provider_email:
            self.config["provider_email"] = provider_email

    def disconnect(self) -> None:
        """Mark integration as disconnected."""
        self.status = IntegrationStatus.DISCONNECTED
        self.error_message = None
        self.updated_at = utc_now()

    def mark_error(self, message: str) -> None:
        """Mark integration as having an error.

        Args:
            message: Error message describing the issue.
        """
        self.status = IntegrationStatus.ERROR
        self.error_message = message
        self.updated_at = utc_now()

    def record_sync(self) -> None:
        """Record a successful sync timestamp."""
        self.last_sync = utc_now()
        self.updated_at = self.last_sync

    @property
    def is_connected(self) -> bool:
        """Check if integration is connected."""
        return self.status == IntegrationStatus.CONNECTED

    @property
    def provider_email(self) -> str | None:
        """Get the provider email if available."""
        email = self.config.get("provider_email")
        return str(email) if email else None
@@ -197,7 +197,7 @@ class Meeting:
        """Concatenate all segment text."""
        return " ".join(s.text for s in self.segments)

-    def is_active(self) -> bool:
+    def is_in_active_state(self) -> bool:
        """Check if meeting is in an active state (created or recording).

        Note: STOPPING is not considered active as it's transitioning to stopped.
146
src/noteflow/domain/entities/named_entity.py
Normal file
@@ -0,0 +1,146 @@
"""Named entity domain entity for NER extraction."""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING
from uuid import UUID, uuid4

if TYPE_CHECKING:
    from noteflow.domain.value_objects import MeetingId


class EntityCategory(Enum):
    """Categories for named entities.

    Maps to spaCy entity types:
        PERSON -> PERSON
        ORG -> COMPANY
        GPE/LOC/FAC -> LOCATION
        PRODUCT/WORK_OF_ART -> PRODUCT
        DATE/TIME -> DATE

    Note: TECHNICAL and ACRONYM are placeholders for future custom pattern
    matching (not currently mapped from spaCy's default NER model).
    """

    PERSON = "person"
    COMPANY = "company"
    PRODUCT = "product"
    TECHNICAL = "technical"  # Future: custom pattern matching
    ACRONYM = "acronym"  # Future: custom pattern matching
    LOCATION = "location"
    DATE = "date"
    OTHER = "other"

    @classmethod
    def from_string(cls, value: str) -> EntityCategory:
        """Convert string to EntityCategory.

        Args:
            value: Category string value.

        Returns:
            Corresponding EntityCategory.

        Raises:
            ValueError: If value is not a valid category.
        """
        try:
            return cls(value.lower())
        except ValueError as e:
            raise ValueError(f"Invalid entity category: {value}") from e


@dataclass
class NamedEntity:
    """A named entity extracted from a meeting transcript.

    Represents a person, company, product, location, or other notable term
    identified in the meeting transcript via NER processing.
    """

    id: UUID = field(default_factory=uuid4)
    meeting_id: MeetingId | None = None
    text: str = ""
    normalized_text: str = ""
    category: EntityCategory = EntityCategory.OTHER
    segment_ids: list[int] = field(default_factory=list)
    confidence: float = 0.0
    is_pinned: bool = False

    # Database primary key (set after persistence)
    db_id: UUID | None = None

    def __post_init__(self) -> None:
        """Validate entity data and compute normalized_text."""
        if not 0.0 <= self.confidence <= 1.0:
            raise ValueError(f"confidence must be between 0 and 1, got {self.confidence}")

        # Auto-compute normalized_text if not provided
        if self.text and not self.normalized_text:
            self.normalized_text = self.text.lower().strip()

    @classmethod
    def create(
        cls,
        text: str,
        category: EntityCategory,
        segment_ids: list[int],
        confidence: float,
        meeting_id: MeetingId | None = None,
    ) -> NamedEntity:
        """Create a new named entity with validation and normalization.

        Factory method that handles text normalization, segment deduplication,
        and confidence validation before entity construction.

        Args:
            text: The entity text as it appears in transcript.
            category: Classification category.
            segment_ids: Segments where entity appears (will be deduplicated and sorted).
            confidence: Extraction confidence (0.0-1.0).
            meeting_id: Optional meeting association.

        Returns:
            New NamedEntity instance with normalized fields.

        Raises:
            ValueError: If text is empty or confidence is out of range.
        """
        # Validate required text
        stripped_text = text.strip()
        if not stripped_text:
            raise ValueError("Entity text cannot be empty")

        # Normalize and deduplicate segment_ids
        unique_segments = sorted(set(segment_ids))

        return cls(
            text=stripped_text,
            normalized_text=stripped_text.lower(),
            category=category,
            segment_ids=unique_segments,
            confidence=confidence,
            meeting_id=meeting_id,
        )

    @property
    def occurrence_count(self) -> int:
        """Number of unique segments where this entity appears.

        Returns:
            Count of distinct segment IDs. Returns 0 if segment_ids is empty.
        """
        if not self.segment_ids:
            return 0
        return len(self.segment_ids)

    def merge_segments(self, other_segment_ids: list[int]) -> None:
        """Merge segment IDs from another occurrence of this entity.

        Args:
            other_segment_ids: Segment IDs from another occurrence.
        """
        self.segment_ids = sorted(set(self.segment_ids) | set(other_segment_ids))
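The normalization rules `NamedEntity.create` applies — strip the text, lowercase it for `normalized_text`, deduplicate and sort `segment_ids` — distilled into a standalone sketch (`normalize_entity` is an illustrative helper, not part of the diff):

```python
def normalize_entity(text: str, segment_ids: list[int]) -> tuple[str, str, list[int]]:
    """Return (text, normalized_text, segment_ids) as NamedEntity.create would store them."""
    stripped = text.strip()
    if not stripped:
        raise ValueError("Entity text cannot be empty")
    # Duplicates collapse; order is normalized to ascending segment IDs.
    return stripped, stripped.lower(), sorted(set(segment_ids))


print(normalize_entity("  Acme Corp ", [3, 1, 3]))  # ('Acme Corp', 'acme corp', [1, 3])
```

The same `sorted(set(...))` idiom reappears in `merge_segments`, which unions two occurrence lists.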
@@ -1,13 +1,21 @@
"""Domain ports (interfaces) for NoteFlow."""

from .calendar import (
    CalendarEventInfo,
    CalendarPort,
    OAuthConnectionInfo,
    OAuthPort,
)
from .diarization import (
    CancellationError,
    DiarizationError,
    JobAlreadyActiveError,
    JobNotFoundError,
)
from .ner import NerPort
from .repositories import (
    AnnotationRepository,
    EntityRepository,
    MeetingRepository,
    SegmentRepository,
    SummaryRepository,
@@ -16,11 +24,17 @@ from .unit_of_work import UnitOfWork

__all__ = [
    "AnnotationRepository",
    "CalendarEventInfo",
    "CalendarPort",
    "CancellationError",
    "DiarizationError",
    "EntityRepository",
    "JobAlreadyActiveError",
    "JobNotFoundError",
    "MeetingRepository",
    "NerPort",
    "OAuthConnectionInfo",
    "OAuthPort",
    "SegmentRepository",
    "SummaryRepository",
    "UnitOfWork",
168
src/noteflow/domain/ports/calendar.py
Normal file
@@ -0,0 +1,168 @@
"""Calendar integration port interfaces.

Defines protocols for OAuth operations and calendar event fetching.
Implementations live in infrastructure/calendar/.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    from noteflow.domain.value_objects import OAuthProvider, OAuthTokens


@dataclass(frozen=True, slots=True)
class CalendarEventInfo:
    """Calendar event information from provider API.

    Richer than trigger-layer CalendarEvent, used for API responses
    and caching in CalendarEventModel.
    """

    id: str
    title: str
    start_time: datetime
    end_time: datetime
    attendees: tuple[str, ...]
    location: str | None = None
    description: str | None = None
    meeting_url: str | None = None
    is_recurring: bool = False
    is_all_day: bool = False
    provider: str = ""
    raw: dict[str, object] | None = None


@dataclass(frozen=True, slots=True)
class OAuthConnectionInfo:
    """OAuth connection status for a provider."""

    provider: str
    status: str  # "connected", "disconnected", "error", "expired"
    email: str | None = None
    expires_at: datetime | None = None
    error_message: str | None = None


class OAuthPort(Protocol):
    """Port for OAuth operations.

    Handles PKCE flow for Google and Outlook calendar providers.
    """

    def initiate_auth(
        self,
        provider: OAuthProvider,
        redirect_uri: str,
    ) -> tuple[str, str]:
        """Generate OAuth authorization URL with PKCE.

        Args:
            provider: OAuth provider (google or outlook).
            redirect_uri: Callback URL after authorization.

        Returns:
            Tuple of (authorization_url, state_token).
        """
        ...

    async def complete_auth(
        self,
        provider: OAuthProvider,
        code: str,
        state: str,
    ) -> OAuthTokens:
        """Exchange authorization code for tokens.

        Args:
            provider: OAuth provider.
            code: Authorization code from callback.
            state: State parameter from callback.

        Returns:
            OAuth tokens.

        Raises:
            ValueError: If state doesn't match or code exchange fails.
        """
        ...

    async def refresh_tokens(
        self,
        provider: OAuthProvider,
        refresh_token: str,
    ) -> OAuthTokens:
        """Refresh expired access token.

        Args:
            provider: OAuth provider.
            refresh_token: Refresh token from previous exchange.

        Returns:
            New OAuth tokens.

        Raises:
            ValueError: If refresh fails.
        """
        ...

    async def revoke_tokens(
        self,
        provider: OAuthProvider,
        access_token: str,
    ) -> bool:
        """Revoke OAuth tokens with provider.

        Args:
            provider: OAuth provider.
            access_token: Access token to revoke.

        Returns:
            True if revoked successfully.
        """
        ...


class CalendarPort(Protocol):
    """Port for calendar event operations.

    Each provider (Google, Outlook) implements this protocol.
    """

    async def list_events(
        self,
        access_token: str,
        hours_ahead: int = 24,
        limit: int = 20,
    ) -> list[CalendarEventInfo]:
        """Fetch upcoming calendar events.

        Args:
            access_token: Provider OAuth access token.
            hours_ahead: Hours to look ahead.
            limit: Maximum number of events.

        Returns:
            Calendar events sorted chronologically.

        Raises:
            ValueError: If token is invalid or API call fails.
        """
        ...

    async def get_user_email(self, access_token: str) -> str:
        """Get authenticated user's email address.

        Args:
            access_token: Valid OAuth access token.

        Returns:
            User's email address.

        Raises:
            ValueError: If token is invalid or API call fails.
        """
        ...
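Because `CalendarPort` is a `Protocol`, any object with matching method shapes satisfies it structurally — useful for tests that should not hit a real provider API. A hypothetical in-memory stub (all names and event data here are invented for illustration):

```python
import asyncio
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone


@dataclass(frozen=True)
class FakeEvent:
    id: str
    title: str
    start_time: datetime
    end_time: datetime


class FakeCalendar:
    """In-memory stand-in shaped like CalendarPort's two methods."""

    async def list_events(
        self, access_token: str, hours_ahead: int = 24, limit: int = 20
    ) -> list[FakeEvent]:
        now = datetime.now(timezone.utc)
        events = [
            FakeEvent("1", "Standup", now + timedelta(hours=1), now + timedelta(hours=2)),
            FakeEvent("2", "Retro", now + timedelta(hours=30), now + timedelta(hours=31)),
        ]
        # Honour the hours_ahead window and the limit, as the contract requires.
        cutoff = now + timedelta(hours=hours_ahead)
        return [e for e in events if e.start_time <= cutoff][:limit]

    async def get_user_email(self, access_token: str) -> str:
        return "user@example.com"


titles = [e.title for e in asyncio.run(FakeCalendar().list_events("token"))]
print(titles)  # ['Standup']
```

With the default 24-hour window only the first event survives the cutoff, which is the filtering behaviour `hours_ahead` documents.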
53
src/noteflow/domain/ports/ner.py
Normal file
@@ -0,0 +1,53 @@
"""NER port interface (hexagonal architecture)."""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    from noteflow.domain.entities.named_entity import NamedEntity


class NerPort(Protocol):
    """Port for named entity recognition.

    This is the domain port that the application layer uses.
    Infrastructure adapters (like NerEngine) implement this protocol.
    """

    def extract(self, text: str) -> list[NamedEntity]:
        """Extract named entities from text.

        Args:
            text: Input text to analyze.

        Returns:
            List of extracted entities (deduplicated by normalized text).
        """
        ...

    def extract_from_segments(
        self,
        segments: list[tuple[int, str]],
    ) -> list[NamedEntity]:
        """Extract entities from multiple segments with tracking.

        Processes each segment individually and tracks which segments
        contain each entity via segment_ids. Entities appearing in
        multiple segments are deduplicated with merged segment lists.

        Args:
            segments: List of (segment_id, text) tuples.

        Returns:
            Entities with segment_ids populated (deduplicated across segments).
        """
        ...

    def is_ready(self) -> bool:
        """Check if the NER engine is loaded and ready.

        Returns:
            True if model is loaded and ready for inference.
        """
        ...
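The dedup-with-tracking contract `extract_from_segments` describes — the same entity surfacing in several segments collapses into one record whose `segment_ids` lists every segment it appeared in — can be sketched without any NER model (`merge_occurrences` is an illustrative helper, not part of the diff):

```python
def merge_occurrences(found: list[tuple[int, str]]) -> dict[str, list[int]]:
    """Collapse per-segment entity hits by normalized text, keeping segment IDs."""
    merged: dict[str, list[int]] = {}
    for segment_id, text in found:
        key = text.lower().strip()  # normalized_text is the dedup key
        ids = merged.setdefault(key, [])
        if segment_id not in ids:
            ids.append(segment_id)
    return {key: sorted(ids) for key, ids in merged.items()}


print(merge_occurrences([(1, "Acme"), (3, "acme "), (1, "Acme")]))  # {'acme': [1, 3]}
```

In the real adapter each hit would also carry a category and confidence; merging those (e.g. keeping the maximum confidence) is an implementation choice the port leaves open.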
@@ -7,12 +7,12 @@ from datetime import datetime
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
-    from noteflow.domain.entities import Annotation, Meeting, Segment, Summary
+    from uuid import UUID
+
+    from noteflow.domain.entities import Annotation, Integration, Meeting, Segment, Summary
+    from noteflow.domain.entities.named_entity import NamedEntity
    from noteflow.domain.value_objects import AnnotationId, MeetingId, MeetingState
-    from noteflow.infrastructure.persistence.repositories import (
-        DiarizationJob,
-        StreamingTurn,
-    )
+    from noteflow.infrastructure.persistence.repositories import DiarizationJob, StreamingTurn


class MeetingRepository(Protocol):
@@ -493,3 +493,186 @@ class PreferencesRepository(Protocol):
            True if deleted, False if not found.
        """
        ...


class EntityRepository(Protocol):
    """Repository protocol for NamedEntity operations (NER results)."""

    async def save(self, entity: NamedEntity) -> NamedEntity:
        """Save or update a named entity.

        Args:
            entity: Entity to save.

        Returns:
            Saved entity with db_id populated.
        """
        ...

    async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
        """Save multiple entities efficiently.

        Args:
            entities: Entities to save.

        Returns:
            Saved entities with db_ids populated.
        """
        ...

    async def get(self, entity_id: UUID) -> NamedEntity | None:
        """Get entity by ID.

        Args:
            entity_id: Entity UUID.

        Returns:
            Entity if found, None otherwise.
        """
        ...

    async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
        """Get all entities for a meeting.

        Args:
            meeting_id: Meeting UUID.

        Returns:
            List of entities.
        """
        ...

    async def delete_by_meeting(self, meeting_id: MeetingId) -> int:
        """Delete all entities for a meeting.

        Args:
            meeting_id: Meeting UUID.

        Returns:
            Number of deleted entities.
        """
        ...

    async def update_pinned(self, entity_id: UUID, is_pinned: bool) -> bool:
        """Update the pinned status of an entity.

        Args:
            entity_id: Entity UUID.
            is_pinned: New pinned status.

        Returns:
            True if entity was found and updated.
        """
        ...

    async def exists_for_meeting(self, meeting_id: MeetingId) -> bool:
        """Check if any entities exist for a meeting.

        Args:
            meeting_id: Meeting UUID.

        Returns:
            True if at least one entity exists.
        """
        ...


class IntegrationRepository(Protocol):
    """Repository protocol for external service integrations.

    Manages OAuth-connected services like calendars, email providers, and PKM tools.
    """

    async def get(self, integration_id: UUID) -> Integration | None:
        """Retrieve an integration by ID.

        Args:
            integration_id: Integration UUID.

        Returns:
            Integration if found, None otherwise.
        """
        ...

    async def get_by_provider(
        self,
        provider: str,
        integration_type: str | None = None,
    ) -> Integration | None:
        """Retrieve an integration by provider name.

        Args:
            provider: Provider name (e.g., 'google', 'outlook').
            integration_type: Optional type filter.

        Returns:
            Integration if found, None otherwise.
        """
        ...

    async def create(self, integration: Integration) -> Integration:
        """Persist a new integration.

        Args:
            integration: Integration to create.

        Returns:
            Created integration.
        """
        ...

    async def update(self, integration: Integration) -> Integration:
        """Update an existing integration.

        Args:
            integration: Integration with updated fields.

        Returns:
            Updated integration.

        Raises:
            ValueError: If integration does not exist.
        """
        ...

    async def delete(self, integration_id: UUID) -> bool:
        """Delete an integration and its secrets.

        Args:
            integration_id: Integration UUID.

        Returns:
            True if deleted, False if not found.
        """
        ...

    async def get_secrets(self, integration_id: UUID) -> dict[str, str] | None:
        """Get encrypted secrets for an integration.

        Args:
            integration_id: Integration UUID.

        Returns:
            Dictionary of secret key-value pairs, or None if not found.
        """
        ...

    async def set_secrets(self, integration_id: UUID, secrets: dict[str, str]) -> None:
        """Store encrypted secrets for an integration.

        Args:
            integration_id: Integration UUID.
            secrets: Dictionary of secret key-value pairs.
        """
        ...

    async def list_by_type(self, integration_type: str) -> Sequence[Integration]:
        """List integrations by type.

        Args:
            integration_type: Integration type (e.g., 'calendar', 'email').

        Returns:
            List of integrations of the specified type.
        """
        ...
@@ -8,6 +8,8 @@ if TYPE_CHECKING:
    from .repositories import (
        AnnotationRepository,
        DiarizationJobRepository,
        EntityRepository,
        IntegrationRepository,
        MeetingRepository,
        PreferencesRepository,
        SegmentRepository,
@@ -65,6 +67,16 @@ class UnitOfWork(Protocol):
        """Access the preferences repository."""
        ...

    @property
    def entities(self) -> EntityRepository:
        """Access the entities repository for NER results."""
        ...

    @property
    def integrations(self) -> IntegrationRepository:
        """Access the integrations repository for OAuth connections."""
        ...

    # Feature flags for DB-only capabilities
    @property
    def supports_annotations(self) -> bool:
@@ -90,6 +102,23 @@ class UnitOfWork(Protocol):
        """
        ...

    @property
    def supports_entities(self) -> bool:
        """Check if NER entity persistence is supported.

        Returns False for memory-only implementations.
        """
        ...

    @property
    def supports_integrations(self) -> bool:
        """Check if OAuth integration persistence is supported.

        Returns False for memory-only implementations that don't support
        encrypted secret storage.
        """
        ...

    async def __aenter__(self) -> Self:
        """Enter the unit of work context.
@@ -58,8 +58,13 @@ class SummarizationResult:

    @property
    def is_success(self) -> bool:
-        """Check if summarization succeeded with content."""
-        return bool(self.summary.executive_summary)
+        """Check if summarization succeeded with meaningful content.
+
+        Returns:
+            True if executive_summary has non-whitespace content.
+        """
+        summary_text = self.summary.executive_summary
+        return bool(summary_text and summary_text.strip())


@dataclass(frozen=True)
@@ -16,6 +16,6 @@ def utc_now() -> datetime:
    consistent timezone-aware datetime handling.

    Returns:
-        Current datetime in UTC timezone.
+        Current datetime in UTC timezone with microsecond precision.
    """
    return datetime.now(UTC)
@@ -2,7 +2,9 @@

from __future__ import annotations

-from enum import Enum, IntEnum
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum, IntEnum, StrEnum
from typing import NewType
from uuid import UUID

@@ -79,3 +81,88 @@ class MeetingState(IntEnum):
            MeetingState.ERROR: set(),  # Terminal state
        }
        return target in valid_transitions.get(self, set())


# OAuth value objects for calendar integration


class OAuthProvider(StrEnum):
    """Supported OAuth providers for calendar integration."""

    GOOGLE = "google"
    OUTLOOK = "outlook"


@dataclass(frozen=True, slots=True)
class OAuthState:
    """CSRF state for OAuth PKCE flow.

    Stored in-memory with TTL for security. Contains the code_verifier
    needed to complete the PKCE exchange.
    """

    state: str
    provider: OAuthProvider
    redirect_uri: str
    code_verifier: str
    created_at: datetime
    expires_at: datetime

    def is_expired(self) -> bool:
        """Check if the state has expired."""
        return datetime.now(self.created_at.tzinfo) > self.expires_at


@dataclass(frozen=True, slots=True)
class OAuthTokens:
    """OAuth tokens returned from provider.

    Stored encrypted in IntegrationSecretModel via the repository.
    """

    access_token: str
    refresh_token: str | None
    token_type: str
    expires_at: datetime
    scope: str

    def is_expired(self) -> bool:
        """Check if the access token has expired."""
        return datetime.now(self.expires_at.tzinfo) > self.expires_at

    def to_secrets_dict(self) -> dict[str, str]:
        """Convert to dictionary for encrypted storage."""
        result: dict[str, str] = {
            "access_token": self.access_token,
            "token_type": self.token_type,
            "expires_at": self.expires_at.isoformat(),
            "scope": self.scope,
        }
        if self.refresh_token:
            result["refresh_token"] = self.refresh_token
        return result

    @classmethod
    def from_secrets_dict(cls, secrets: dict[str, str]) -> OAuthTokens:
        """Create from dictionary retrieved from encrypted storage.

        Args:
            secrets: Token dictionary from keystore.

        Returns:
            OAuthTokens with parsed expiration datetime.

        Raises:
            KeyError: If access_token or expires_at missing.
            ValueError: If expires_at is not valid ISO format.
        """
        # Parse expiration datetime from ISO string
        expires_at = datetime.fromisoformat(secrets["expires_at"])

        return cls(
            access_token=secrets["access_token"],
            refresh_token=secrets.get("refresh_token"),
            token_type=secrets.get("token_type", "Bearer"),
            expires_at=expires_at,
            scope=secrets.get("scope", ""),
        )
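`to_secrets_dict` and `from_secrets_dict` round-trip `expires_at` through an ISO-8601 string, since encrypted storage only holds `dict[str, str]`. The serialization step in isolation, showing the round trip is lossless for timezone-aware datetimes:

```python
from datetime import datetime, timezone

# Serialize the way to_secrets_dict does, then parse the way from_secrets_dict does.
expires = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
stored = expires.isoformat()
restored = datetime.fromisoformat(stored)

print(stored)               # 2025-01-01T12:00:00+00:00
print(restored == expires)  # True
```

Keeping the tzinfo in the stored string is what lets the `is_expired` comparison against `datetime.now(self.expires_at.tzinfo)` stay aware-vs-aware after a restore.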
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
    import numpy as np
    from numpy.typing import NDArray

-    from noteflow.grpc._types import ConnectionCallback, TranscriptCallback
+    from noteflow.grpc._types import ConnectionCallback, TranscriptCallback, TranscriptSegment
    from noteflow.grpc.proto import noteflow_pb2_grpc


@@ -52,7 +52,7 @@ class ClientHost(Protocol):
        """Background thread for audio streaming."""
        ...

-    def _notify_transcript(self, segment: object) -> None:
+    def _notify_transcript(self, segment: TranscriptSegment) -> None:
        """Notify transcript callback."""
        ...
@@ -1,8 +1,10 @@
"""gRPC service mixins for NoteFlowServicer."""

from .annotation import AnnotationMixin
from .calendar import CalendarMixin
from .diarization import DiarizationMixin
from .diarization_job import DiarizationJobMixin
from .entities import EntitiesMixin
from .export import ExportMixin
from .meeting import MeetingMixin
from .streaming import StreamingMixin
@@ -10,8 +12,10 @@ from .summarization import SummarizationMixin

__all__ = [
    "AnnotationMixin",
    "CalendarMixin",
    "DiarizationJobMixin",
    "DiarizationMixin",
    "EntitiesMixin",
    "ExportMixin",
    "MeetingMixin",
    "StreamingMixin",
176
src/noteflow/grpc/_mixins/calendar.py
Normal file
@@ -0,0 +1,176 @@
|
||||
```python
"""Calendar integration mixin for gRPC service."""

from __future__ import annotations

from typing import TYPE_CHECKING

import grpc.aio

from noteflow.application.services.calendar_service import CalendarServiceError

from ..proto import noteflow_pb2
from .errors import abort_internal, abort_invalid_argument, abort_unavailable

if TYPE_CHECKING:
    from .protocols import ServicerHost


class CalendarMixin:
    """Mixin providing calendar integration functionality.

    Requires host to implement ServicerHost protocol with _calendar_service.
    Provides OAuth flow and calendar event fetching for Google and Outlook.
    """

    async def ListCalendarEvents(
        self: ServicerHost,
        request: noteflow_pb2.ListCalendarEventsRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.ListCalendarEventsResponse:
        """List upcoming calendar events from connected providers."""
        if self._calendar_service is None:
            await abort_unavailable(context, "Calendar integration not enabled")

        provider = request.provider if request.provider else None
        hours_ahead = request.hours_ahead if request.hours_ahead > 0 else None
        limit = request.limit if request.limit > 0 else None

        try:
            events = await self._calendar_service.list_calendar_events(
                provider=provider,
                hours_ahead=hours_ahead,
                limit=limit,
            )
        except CalendarServiceError as e:
            await abort_internal(context, str(e))

        proto_events = [
            noteflow_pb2.CalendarEvent(
                id=event.id,
                title=event.title,
                start_time=int(event.start_time.timestamp()),
                end_time=int(event.end_time.timestamp()),
                location=event.location or "",
                attendees=list(event.attendees),
                meeting_url=event.meeting_url or "",
                is_recurring=event.is_recurring,
                provider=event.provider,
            )
            for event in events
        ]

        return noteflow_pb2.ListCalendarEventsResponse(
            events=proto_events,
            total_count=len(proto_events),
        )

    async def GetCalendarProviders(
        self: ServicerHost,
        request: noteflow_pb2.GetCalendarProvidersRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.GetCalendarProvidersResponse:
        """Get available calendar providers with authentication status."""
        if self._calendar_service is None:
            await abort_unavailable(context, "Calendar integration not enabled")

        providers = []
        for provider_name, display_name in [
            ("google", "Google Calendar"),
            ("outlook", "Microsoft Outlook"),
        ]:
            status = await self._calendar_service.get_connection_status(provider_name)
            providers.append(
                noteflow_pb2.CalendarProvider(
                    name=provider_name,
                    is_authenticated=status.status == "connected",
                    display_name=display_name,
                )
            )

        return noteflow_pb2.GetCalendarProvidersResponse(providers=providers)

    async def InitiateOAuth(
        self: ServicerHost,
        request: noteflow_pb2.InitiateOAuthRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.InitiateOAuthResponse:
        """Start OAuth flow for a calendar provider."""
        if self._calendar_service is None:
            await abort_unavailable(context, "Calendar integration not enabled")

        try:
            auth_url, state = await self._calendar_service.initiate_oauth(
                provider=request.provider,
                redirect_uri=request.redirect_uri if request.redirect_uri else None,
            )
        except CalendarServiceError as e:
            await abort_invalid_argument(context, str(e))

        return noteflow_pb2.InitiateOAuthResponse(
            auth_url=auth_url,
            state=state,
        )

    async def CompleteOAuth(
        self: ServicerHost,
        request: noteflow_pb2.CompleteOAuthRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.CompleteOAuthResponse:
        """Complete OAuth flow with authorization code."""
        if self._calendar_service is None:
            await abort_unavailable(context, "Calendar integration not enabled")

        try:
            success = await self._calendar_service.complete_oauth(
                provider=request.provider,
                code=request.code,
                state=request.state,
            )
        except CalendarServiceError as e:
            return noteflow_pb2.CompleteOAuthResponse(
                success=False,
                error_message=str(e),
            )

        # Get the provider email after successful connection
        status = await self._calendar_service.get_connection_status(request.provider)

        return noteflow_pb2.CompleteOAuthResponse(
            success=success,
            provider_email=status.email or "",
        )

    async def GetOAuthConnectionStatus(
        self: ServicerHost,
        request: noteflow_pb2.GetOAuthConnectionStatusRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.GetOAuthConnectionStatusResponse:
        """Get OAuth connection status for a provider."""
        if self._calendar_service is None:
            await abort_unavailable(context, "Calendar integration not enabled")

        status = await self._calendar_service.get_connection_status(request.provider)

        connection = noteflow_pb2.OAuthConnection(
            provider=status.provider,
            status=status.status,
            email=status.email or "",
            expires_at=int(status.expires_at.timestamp()) if status.expires_at else 0,
            error_message=status.error_message or "",
            integration_type=request.integration_type or "calendar",
        )

        return noteflow_pb2.GetOAuthConnectionStatusResponse(connection=connection)

    async def DisconnectOAuth(
        self: ServicerHost,
        request: noteflow_pb2.DisconnectOAuthRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.DisconnectOAuthResponse:
        """Disconnect OAuth integration and revoke tokens."""
        if self._calendar_service is None:
            await abort_unavailable(context, "Calendar integration not enabled")

        success = await self._calendar_service.disconnect(request.provider)

        return noteflow_pb2.DisconnectOAuthResponse(success=success)
```
src/noteflow/grpc/_mixins/entities.py (new file, 98 lines)
@@ -0,0 +1,98 @@
```python
"""Entity extraction gRPC mixin."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import grpc.aio

from ..proto import noteflow_pb2
from .converters import parse_meeting_id_or_abort
from .errors import abort_failed_precondition, abort_not_found

if TYPE_CHECKING:
    from noteflow.application.services.ner_service import NerService
    from noteflow.domain.entities.named_entity import NamedEntity

    from .protocols import ServicerHost

logger = logging.getLogger(__name__)


class EntitiesMixin:
    """Mixin for entity extraction RPC methods.

    Implements the ExtractEntities RPC using the NerService application layer.
    Architecture: gRPC → NerService (application) → NerEngine (infrastructure)
    """

    _ner_service: NerService | None

    async def ExtractEntities(
        self: ServicerHost,
        request: noteflow_pb2.ExtractEntitiesRequest,
        context: grpc.aio.ServicerContext,
    ) -> noteflow_pb2.ExtractEntitiesResponse:
        """Extract named entities from meeting transcript.

        Delegates to NerService for extraction, caching, and persistence.
        Returns cached results if available, unless force_refresh is True.
        """
        meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)

        if self._ner_service is None:
            await abort_failed_precondition(
                context,
                "NER service not configured. Set NOTEFLOW_FEATURE_NER_ENABLED=true",
            )

        try:
            result = await self._ner_service.extract_entities(
                meeting_id=meeting_id,
                force_refresh=request.force_refresh,
            )
        except ValueError:
            # Meeting not found
            await abort_not_found(context, "Meeting", request.meeting_id)
        except RuntimeError as e:
            # Feature disabled
            await abort_failed_precondition(context, str(e))

        # Convert to proto
        proto_entities = [
            noteflow_pb2.ExtractedEntity(
                id=str(entity.id),
                text=entity.text,
                category=entity.category.value,
                segment_ids=list(entity.segment_ids),
                confidence=entity.confidence,
                is_pinned=entity.is_pinned,
            )
            for entity in result.entities
        ]

        return noteflow_pb2.ExtractEntitiesResponse(
            entities=proto_entities,
            total_count=result.total_count,
            cached=result.cached,
        )


def entity_to_proto(entity: NamedEntity) -> noteflow_pb2.ExtractedEntity:
    """Convert domain NamedEntity to proto ExtractedEntity.

    Args:
        entity: NamedEntity domain object.

    Returns:
        Proto ExtractedEntity message.
    """
    return noteflow_pb2.ExtractedEntity(
        id=str(entity.id),
        text=entity.text,
        category=entity.category.value,
        segment_ids=list(entity.segment_ids),
        confidence=entity.confidence,
        is_pinned=entity.is_pinned,
    )
```
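The field mapping `entity_to_proto` performs can be exercised without the generated proto classes. The sketch below mirrors it with a dict target and hypothetical stand-ins for the domain types (the real `NamedEntity` and its category enum live in `noteflow.domain.entities.named_entity`; the shapes here are assumptions):

```python
from dataclasses import dataclass
from enum import Enum
from uuid import UUID, uuid4


class EntityCategory(Enum):
    # Hypothetical subset of the domain category enum
    PERSON = "person"
    ORG = "org"


@dataclass(frozen=True)
class NamedEntity:
    # Minimal stand-in for the domain entity
    id: UUID
    text: str
    category: EntityCategory
    segment_ids: tuple[int, ...]
    confidence: float
    is_pinned: bool


def entity_to_dict(entity: NamedEntity) -> dict:
    # Same field-by-field mapping as entity_to_proto, with a plain dict
    # in place of noteflow_pb2.ExtractedEntity
    return {
        "id": str(entity.id),
        "text": entity.text,
        "category": entity.category.value,
        "segment_ids": list(entity.segment_ids),
        "confidence": entity.confidence,
        "is_pinned": entity.is_pinned,
    }
```

Note the deliberate copies: `str(entity.id)` because proto carries the UUID as a string, and `list(...)` so the repeated field gets a mutable sequence rather than the domain tuple.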
```diff
@@ -12,6 +12,8 @@ from numpy.typing import NDArray
 if TYPE_CHECKING:
     from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

+    from noteflow.application.services.calendar_service import CalendarService
+    from noteflow.application.services.ner_service import NerService
     from noteflow.domain.entities import Meeting
     from noteflow.domain.ports.unit_of_work import UnitOfWork
     from noteflow.infrastructure.asr import FasterWhisperEngine, Segmenter, StreamingVad
@@ -41,10 +43,12 @@ class ServicerHost(Protocol):
     _meetings_dir: Path
     _crypto: AesGcmCryptoBox

-    # Engines
+    # Engines and services
     _asr_engine: FasterWhisperEngine | None
     _diarization_engine: DiarizationEngine | None
     _summarization_service: object | None
+    _ner_service: NerService | None
+    _calendar_service: CalendarService | None
     _diarization_refinement_enabled: bool

     # Audio writers
```
```diff
@@ -185,7 +185,7 @@ class StreamingMixin:
         await repo.meetings.update(meeting)
         await repo.commit()

-        next_segment_id = await repo.segments.get_next_segment_id(meeting.id)
+        next_segment_id = await repo.segments.compute_next_segment_id(meeting.id)
         self._open_meeting_audio_writer(
             meeting_id, dek, wrapped_dek, asset_path=meeting.asset_path
         )
```
```diff
@@ -80,7 +80,7 @@ class StreamingSession:
         return self._meeting_id

     @property
-    def is_active(self) -> bool:
+    def is_streaming(self) -> bool:
         """Check if the session is currently streaming."""
         return self._started and self._thread is not None and self._thread.is_alive()
```
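The renamed `is_streaming` property combines three checks: the session was started, a worker thread exists, and that thread is still alive. A self-contained sketch of the same liveness pattern, using an event-driven worker so the behavior is deterministic (the real session wraps audio capture, not an event wait):

```python
import threading


class StreamingSession:
    # Sketch of the liveness check only
    def __init__(self) -> None:
        self._started = False
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        # Worker simply blocks until stop() fires the event
        self._thread = threading.Thread(target=self._stop.wait)
        self._thread.start()
        self._started = True

    def stop(self) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join()

    @property
    def is_streaming(self) -> bool:
        return self._started and self._thread is not None and self._thread.is_alive()
```

Checking `is_alive()` in addition to the `_started` flag means a session whose worker died unexpectedly reports itself as not streaming, rather than trusting stale state.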
```diff
@@ -93,13 +93,13 @@ class MeetingStore:
         """List meetings with optional filtering.

         Args:
-            states: Optional list of states to filter by.
-            limit: Maximum number of meetings to return.
-            offset: Number of meetings to skip.
-            sort_desc: Sort by created_at descending if True.
+            states: Filter by these states (all if None).
+            limit: Max meetings per page.
+            offset: Pagination offset.
+            sort_desc: Sort by created_at descending.

         Returns:
-            Tuple of (meetings list, total count).
+            Tuple of (paginated meeting list, total matching count).
         """
         with self._lock:
             meetings = list(self._meetings.values())
@@ -166,8 +166,8 @@ class MeetingStore:
         meeting.add_segment(segment)
         return meeting

-    def get_segments(self, meeting_id: str) -> list[Segment]:
-        """Get a copy of segments for a meeting.
+    def fetch_segments(self, meeting_id: str) -> list[Segment]:
+        """Fetch a copy of segments for in-memory meeting.

         Args:
             meeting_id: Meeting ID.
@@ -277,8 +277,8 @@ class MeetingStore:
         meeting.ended_at = end_time
         return True

-    def get_next_segment_id(self, meeting_id: str) -> int:
-        """Get next segment ID for a meeting.
+    def compute_next_segment_id(self, meeting_id: str) -> int:
+        """Compute next segment ID for in-memory meeting.

         Args:
             meeting_id: Meeting ID.
```
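The updated `list` docstring describes filter-then-paginate semantics: apply the state filter, sort, then slice, returning both the page and the total matching count (not the page length). A runnable sketch of that contract, with plain dicts standing in for Meeting objects:

```python
def list_page(meetings, states=None, limit=50, offset=0, sort_desc=True):
    """Filter by state, sort by created_at, then slice one page.

    Returns (page, total) where total counts all matches, not just the page.
    """
    matched = [m for m in meetings if states is None or m["state"] in states]
    matched.sort(key=lambda m: m["created_at"], reverse=sort_desc)
    return matched[offset:offset + limit], len(matched)
```

Returning the total alongside the page is what lets a client render "page 2 of N" without issuing a second count query.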
```diff
@@ -48,8 +48,12 @@ service NoteFlowService {
   // Calendar integration (Sprint 5)
   rpc ListCalendarEvents(ListCalendarEventsRequest) returns (ListCalendarEventsResponse);
   rpc GetCalendarProviders(GetCalendarProvidersRequest) returns (GetCalendarProvidersResponse);
-  rpc InitiateCalendarAuth(InitiateCalendarAuthRequest) returns (InitiateCalendarAuthResponse);
-  rpc CompleteCalendarAuth(CompleteCalendarAuthRequest) returns (CompleteCalendarAuthResponse);
+
+  // OAuth integration (generic for calendar, email, PKM, etc.)
+  rpc InitiateOAuth(InitiateOAuthRequest) returns (InitiateOAuthResponse);
+  rpc CompleteOAuth(CompleteOAuthRequest) returns (CompleteOAuthResponse);
+  rpc GetOAuthConnectionStatus(GetOAuthConnectionStatusRequest) returns (GetOAuthConnectionStatusResponse);
+  rpc DisconnectOAuth(DisconnectOAuthRequest) returns (DisconnectOAuthResponse);
 }

 // =============================================================================
@@ -702,15 +706,22 @@ message GetCalendarProvidersResponse {
   repeated CalendarProvider providers = 1;
 }

-message InitiateCalendarAuthRequest {
-  // Provider to authenticate: google, outlook
+// =============================================================================
+// OAuth Integration Messages (generic for calendar, email, PKM, etc.)
+// =============================================================================
+
+message InitiateOAuthRequest {
+  // Provider to authenticate: google, outlook, notion, etc.
   string provider = 1;

   // Redirect URI for OAuth callback
   string redirect_uri = 2;
+
+  // Integration type: calendar, email, pkm, custom
+  string integration_type = 3;
 }

-message InitiateCalendarAuthResponse {
+message InitiateOAuthResponse {
   // Authorization URL to redirect user to
   string auth_url = 1;

@@ -718,7 +729,7 @@ message InitiateOAuthResponse {
   string state = 2;
 }

-message CompleteCalendarAuthRequest {
+message CompleteOAuthRequest {
   // Provider being authenticated
   string provider = 1;

@@ -729,13 +740,62 @@ message CompleteOAuthRequest {
   string state = 3;
 }

-message CompleteCalendarAuthResponse {
+message CompleteOAuthResponse {
   // Whether authentication succeeded
   bool success = 1;

   // Error message if failed
   string error_message = 2;

-  // Email of authenticated calendar account
+  // Email of authenticated account
   string provider_email = 3;
 }
+
+message OAuthConnection {
+  // Provider name: google, outlook, notion
+  string provider = 1;
+
+  // Connection status: disconnected, connected, error
+  string status = 2;
+
+  // Email of authenticated account
+  string email = 3;
+
+  // Token expiration timestamp (Unix epoch seconds)
+  int64 expires_at = 4;
+
+  // Error message if status is error
+  string error_message = 5;
+
+  // Integration type: calendar, email, pkm, custom
+  string integration_type = 6;
+}
+
+message GetOAuthConnectionStatusRequest {
+  // Provider to check: google, outlook, notion
+  string provider = 1;
+
+  // Optional integration type filter
+  string integration_type = 2;
+}
+
+message GetOAuthConnectionStatusResponse {
+  // Connection details
+  OAuthConnection connection = 1;
+}
+
+message DisconnectOAuthRequest {
+  // Provider to disconnect
+  string provider = 1;
+
+  // Optional integration type
+  string integration_type = 2;
+}
+
+message DisconnectOAuthResponse {
+  // Whether disconnection succeeded
+  bool success = 1;
+
+  // Error message if failed
+  string error_message = 2;
+}
```
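Because `integration_type` is a plain proto3 string, a client that omits it sends the empty string, and the server-side fallback to `"calendar"` seen in `GetOAuthConnectionStatus` relies on exactly that. A sketch with a dataclass standing in for the generated message:

```python
from dataclasses import dataclass


@dataclass
class GetOAuthConnectionStatusRequest:
    # Stand-in for the generated message: unset proto3 strings arrive as ""
    provider: str = ""
    integration_type: str = ""


def effective_integration_type(req: GetOAuthConnectionStatusRequest) -> str:
    # "" is falsy, so an omitted field falls back to the default integration type
    return req.integration_type or "calendar"
```

If an empty integration type ever needs to be distinguishable from an omitted one, the field would have to be marked `optional` in the proto to regain presence tracking.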
File diff suppressed because one or more lines are too long
```diff
@@ -620,15 +620,17 @@ class GetCalendarProvidersResponse(_message.Message):
     providers: _containers.RepeatedCompositeFieldContainer[CalendarProvider]
     def __init__(self, providers: _Optional[_Iterable[_Union[CalendarProvider, _Mapping]]] = ...) -> None: ...

-class InitiateCalendarAuthRequest(_message.Message):
-    __slots__ = ("provider", "redirect_uri")
+class InitiateOAuthRequest(_message.Message):
+    __slots__ = ("provider", "redirect_uri", "integration_type")
     PROVIDER_FIELD_NUMBER: _ClassVar[int]
     REDIRECT_URI_FIELD_NUMBER: _ClassVar[int]
+    INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
     provider: str
     redirect_uri: str
-    def __init__(self, provider: _Optional[str] = ..., redirect_uri: _Optional[str] = ...) -> None: ...
+    integration_type: str
+    def __init__(self, provider: _Optional[str] = ..., redirect_uri: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...

-class InitiateCalendarAuthResponse(_message.Message):
+class InitiateOAuthResponse(_message.Message):
     __slots__ = ("auth_url", "state")
     AUTH_URL_FIELD_NUMBER: _ClassVar[int]
     STATE_FIELD_NUMBER: _ClassVar[int]
@@ -636,7 +638,7 @@ class InitiateOAuthResponse(_message.Message):
     state: str
     def __init__(self, auth_url: _Optional[str] = ..., state: _Optional[str] = ...) -> None: ...

-class CompleteCalendarAuthRequest(_message.Message):
+class CompleteOAuthRequest(_message.Message):
     __slots__ = ("provider", "code", "state")
     PROVIDER_FIELD_NUMBER: _ClassVar[int]
     CODE_FIELD_NUMBER: _ClassVar[int]
@@ -646,7 +648,7 @@ class CompleteOAuthRequest(_message.Message):
     state: str
     def __init__(self, provider: _Optional[str] = ..., code: _Optional[str] = ..., state: _Optional[str] = ...) -> None: ...

-class CompleteCalendarAuthResponse(_message.Message):
+class CompleteOAuthResponse(_message.Message):
     __slots__ = ("success", "error_message", "provider_email")
     SUCCESS_FIELD_NUMBER: _ClassVar[int]
     ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
@@ -655,3 +657,49 @@ class CompleteOAuthResponse(_message.Message):
     error_message: str
     provider_email: str
     def __init__(self, success: bool = ..., error_message: _Optional[str] = ..., provider_email: _Optional[str] = ...) -> None: ...
+
+class OAuthConnection(_message.Message):
+    __slots__ = ("provider", "status", "email", "expires_at", "error_message", "integration_type")
+    PROVIDER_FIELD_NUMBER: _ClassVar[int]
+    STATUS_FIELD_NUMBER: _ClassVar[int]
+    EMAIL_FIELD_NUMBER: _ClassVar[int]
+    EXPIRES_AT_FIELD_NUMBER: _ClassVar[int]
+    ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
+    INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
+    provider: str
+    status: str
+    email: str
+    expires_at: int
+    error_message: str
+    integration_type: str
+    def __init__(self, provider: _Optional[str] = ..., status: _Optional[str] = ..., email: _Optional[str] = ..., expires_at: _Optional[int] = ..., error_message: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...
+
+class GetOAuthConnectionStatusRequest(_message.Message):
+    __slots__ = ("provider", "integration_type")
+    PROVIDER_FIELD_NUMBER: _ClassVar[int]
+    INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
+    provider: str
+    integration_type: str
+    def __init__(self, provider: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...
+
+class GetOAuthConnectionStatusResponse(_message.Message):
+    __slots__ = ("connection",)
+    CONNECTION_FIELD_NUMBER: _ClassVar[int]
+    connection: OAuthConnection
+    def __init__(self, connection: _Optional[_Union[OAuthConnection, _Mapping]] = ...) -> None: ...
+
+class DisconnectOAuthRequest(_message.Message):
+    __slots__ = ("provider", "integration_type")
+    PROVIDER_FIELD_NUMBER: _ClassVar[int]
+    INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
+    provider: str
+    integration_type: str
+    def __init__(self, provider: _Optional[str] = ..., integration_type: _Optional[str] = ...) -> None: ...
+
+class DisconnectOAuthResponse(_message.Message):
+    __slots__ = ("success", "error_message")
+    SUCCESS_FIELD_NUMBER: _ClassVar[int]
+    ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
+    success: bool
+    error_message: str
+    def __init__(self, success: bool = ..., error_message: _Optional[str] = ...) -> None: ...
```
```diff
@@ -143,15 +143,25 @@ class NoteFlowServiceStub(object):
                 request_serializer=noteflow__pb2.GetCalendarProvidersRequest.SerializeToString,
                 response_deserializer=noteflow__pb2.GetCalendarProvidersResponse.FromString,
                 _registered_method=True)
-        self.InitiateCalendarAuth = channel.unary_unary(
-                '/noteflow.NoteFlowService/InitiateCalendarAuth',
-                request_serializer=noteflow__pb2.InitiateCalendarAuthRequest.SerializeToString,
-                response_deserializer=noteflow__pb2.InitiateCalendarAuthResponse.FromString,
+        self.InitiateOAuth = channel.unary_unary(
+                '/noteflow.NoteFlowService/InitiateOAuth',
+                request_serializer=noteflow__pb2.InitiateOAuthRequest.SerializeToString,
+                response_deserializer=noteflow__pb2.InitiateOAuthResponse.FromString,
                 _registered_method=True)
-        self.CompleteCalendarAuth = channel.unary_unary(
-                '/noteflow.NoteFlowService/CompleteCalendarAuth',
-                request_serializer=noteflow__pb2.CompleteCalendarAuthRequest.SerializeToString,
-                response_deserializer=noteflow__pb2.CompleteCalendarAuthResponse.FromString,
+        self.CompleteOAuth = channel.unary_unary(
+                '/noteflow.NoteFlowService/CompleteOAuth',
+                request_serializer=noteflow__pb2.CompleteOAuthRequest.SerializeToString,
+                response_deserializer=noteflow__pb2.CompleteOAuthResponse.FromString,
+                _registered_method=True)
+        self.GetOAuthConnectionStatus = channel.unary_unary(
+                '/noteflow.NoteFlowService/GetOAuthConnectionStatus',
+                request_serializer=noteflow__pb2.GetOAuthConnectionStatusRequest.SerializeToString,
+                response_deserializer=noteflow__pb2.GetOAuthConnectionStatusResponse.FromString,
+                _registered_method=True)
+        self.DisconnectOAuth = channel.unary_unary(
+                '/noteflow.NoteFlowService/DisconnectOAuth',
+                request_serializer=noteflow__pb2.DisconnectOAuthRequest.SerializeToString,
+                response_deserializer=noteflow__pb2.DisconnectOAuthResponse.FromString,
                 _registered_method=True)


@@ -297,13 +307,26 @@ class NoteFlowServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')

-    def InitiateCalendarAuth(self, request, context):
+    def InitiateOAuth(self, request, context):
+        """OAuth integration (generic for calendar, email, PKM, etc.)
+        """
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def CompleteOAuth(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')

-    def CompleteCalendarAuth(self, request, context):
+    def GetOAuthConnectionStatus(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
+
+    def DisconnectOAuth(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
@@ -417,15 +440,25 @@ def add_NoteFlowServiceServicer_to_server(servicer, server):
                     request_deserializer=noteflow__pb2.GetCalendarProvidersRequest.FromString,
                     response_serializer=noteflow__pb2.GetCalendarProvidersResponse.SerializeToString,
             ),
-            'InitiateCalendarAuth': grpc.unary_unary_rpc_method_handler(
-                    servicer.InitiateCalendarAuth,
-                    request_deserializer=noteflow__pb2.InitiateCalendarAuthRequest.FromString,
-                    response_serializer=noteflow__pb2.InitiateCalendarAuthResponse.SerializeToString,
+            'InitiateOAuth': grpc.unary_unary_rpc_method_handler(
+                    servicer.InitiateOAuth,
+                    request_deserializer=noteflow__pb2.InitiateOAuthRequest.FromString,
+                    response_serializer=noteflow__pb2.InitiateOAuthResponse.SerializeToString,
             ),
-            'CompleteCalendarAuth': grpc.unary_unary_rpc_method_handler(
-                    servicer.CompleteCalendarAuth,
-                    request_deserializer=noteflow__pb2.CompleteCalendarAuthRequest.FromString,
-                    response_serializer=noteflow__pb2.CompleteCalendarAuthResponse.SerializeToString,
+            'CompleteOAuth': grpc.unary_unary_rpc_method_handler(
+                    servicer.CompleteOAuth,
+                    request_deserializer=noteflow__pb2.CompleteOAuthRequest.FromString,
+                    response_serializer=noteflow__pb2.CompleteOAuthResponse.SerializeToString,
+            ),
+            'GetOAuthConnectionStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetOAuthConnectionStatus,
+                    request_deserializer=noteflow__pb2.GetOAuthConnectionStatusRequest.FromString,
+                    response_serializer=noteflow__pb2.GetOAuthConnectionStatusResponse.SerializeToString,
+            ),
+            'DisconnectOAuth': grpc.unary_unary_rpc_method_handler(
+                    servicer.DisconnectOAuth,
+                    request_deserializer=noteflow__pb2.DisconnectOAuthRequest.FromString,
+                    response_serializer=noteflow__pb2.DisconnectOAuthResponse.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -1010,7 +1043,7 @@ class NoteFlowService(object):
             _registered_method=True)

     @staticmethod
-    def InitiateCalendarAuth(request,
+    def InitiateOAuth(request,
             target,
             options=(),
             channel_credentials=None,
@@ -1023,9 +1056,9 @@ class NoteFlowService(object):
         return grpc.experimental.unary_unary(
             request,
             target,
-            '/noteflow.NoteFlowService/InitiateCalendarAuth',
-            noteflow__pb2.InitiateCalendarAuthRequest.SerializeToString,
-            noteflow__pb2.InitiateCalendarAuthResponse.FromString,
+            '/noteflow.NoteFlowService/InitiateOAuth',
+            noteflow__pb2.InitiateOAuthRequest.SerializeToString,
+            noteflow__pb2.InitiateOAuthResponse.FromString,
             options,
             channel_credentials,
             insecure,
@@ -1037,7 +1070,7 @@ class NoteFlowService(object):
             _registered_method=True)

     @staticmethod
-    def CompleteCalendarAuth(request,
+    def CompleteOAuth(request,
             target,
             options=(),
             channel_credentials=None,
@@ -1050,9 +1083,63 @@ class NoteFlowService(object):
         return grpc.experimental.unary_unary(
             request,
             target,
-            '/noteflow.NoteFlowService/CompleteCalendarAuth',
-            noteflow__pb2.CompleteCalendarAuthRequest.SerializeToString,
-            noteflow__pb2.CompleteCalendarAuthResponse.FromString,
+            '/noteflow.NoteFlowService/CompleteOAuth',
+            noteflow__pb2.CompleteOAuthRequest.SerializeToString,
+            noteflow__pb2.CompleteOAuthResponse.FromString,
             options,
             channel_credentials,
             insecure,
             call_credentials,
             compression,
             wait_for_ready,
             timeout,
             metadata,
             _registered_method=True)
+
+    @staticmethod
+    def GetOAuthConnectionStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/noteflow.NoteFlowService/GetOAuthConnectionStatus',
+            noteflow__pb2.GetOAuthConnectionStatusRequest.SerializeToString,
+            noteflow__pb2.GetOAuthConnectionStatusResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def DisconnectOAuth(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/noteflow.NoteFlowService/DisconnectOAuth',
+            noteflow__pb2.DisconnectOAuthRequest.SerializeToString,
+            noteflow__pb2.DisconnectOAuthResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
```
```diff
@@ -7,17 +7,19 @@ import asyncio
 import logging
 import signal
 import time
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, TypedDict

 import grpc.aio
 from pydantic import ValidationError

 from noteflow.application.services import RecoveryService
+from noteflow.application.services.ner_service import NerService
 from noteflow.application.services.summarization_service import SummarizationService
 from noteflow.config.settings import Settings, get_settings
 from noteflow.infrastructure.asr import FasterWhisperEngine
 from noteflow.infrastructure.asr.engine import VALID_MODEL_SIZES
 from noteflow.infrastructure.diarization import DiarizationEngine
+from noteflow.infrastructure.ner import NerEngine
 from noteflow.infrastructure.persistence.database import (
     create_async_session_factory,
     ensure_schema_ready,
@@ -41,6 +43,16 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+class _DiarizationEngineKwargs(TypedDict, total=False):
+    """Type-safe kwargs for DiarizationEngine initialization."""
+
+    device: str
+    hf_token: str | None
+    streaming_latency: float
+    min_speakers: int
+    max_speakers: int
+
+
 class NoteFlowServer:
     """Async gRPC server for NoteFlow."""

@@ -54,6 +66,7 @@ class NoteFlowServer:
         summarization_service: SummarizationService | None = None,
         diarization_engine: DiarizationEngine | None = None,
         diarization_refinement_enabled: bool = True,
+        ner_service: NerService | None = None,
     ) -> None:
         """Initialize the server.

@@ -66,6 +79,7 @@ class NoteFlowServer:
             summarization_service: Optional summarization service for generating summaries.
             diarization_engine: Optional diarization engine for speaker identification.
             diarization_refinement_enabled: Whether to allow diarization refinement RPCs.
+            ner_service: Optional NER service for entity extraction.
         """
         self._port = port
         self._asr_model = asr_model
@@ -75,6 +89,7 @@ class NoteFlowServer:
         self._summarization_service = summarization_service
         self._diarization_engine = diarization_engine
         self._diarization_refinement_enabled = diarization_refinement_enabled
+        self._ner_service = ner_service
         self._server: grpc.aio.Server | None = None
         self._servicer: NoteFlowServicer | None = None

@@ -105,13 +120,14 @@ class NoteFlowServer:
             self._summarization_service = create_summarization_service()
             logger.info("Summarization service initialized (default factory)")

-        # Create servicer with session factory, summarization, and diarization
+        # Create servicer with session factory, summarization, diarization, and NER
        self._servicer = NoteFlowServicer(
            asr_engine=asr_engine,
            session_factory=self._session_factory,
            summarization_service=self._summarization_service,
            diarization_engine=self._diarization_engine,
            diarization_refinement_enabled=self._diarization_refinement_enabled,
+           ner_service=self._ner_service,
        )

        # Create async gRPC server
@@ -217,6 +233,24 @@ async def run_server_with_config(config: GrpcServerConfig) -> None:

     summarization_service.on_consent_change = persist_consent

+    # Create NER service if enabled
+    ner_service: NerService | None = None
+    settings = get_settings()
+    if settings.feature_flags.ner_enabled:
+        if session_factory:
+            logger.info("Initializing NER service (spaCy)...")
+            ner_engine = NerEngine()
+            ner_service = NerService(
+                ner_engine=ner_engine,
+                uow_factory=lambda: SqlAlchemyUnitOfWork(session_factory),
+            )
+            logger.info("NER service initialized (model loaded on demand)")
+        else:
+            logger.warning(
+                "NER feature enabled but no database configured. "
+                "NER requires database for entity persistence."
+            )
+
     # Create diarization engine if enabled
     diarization_engine: DiarizationEngine | None = None
     diarization = config.diarization
@@ -228,7 +262,7 @@ async def run_server_with_config(config: GrpcServerConfig) -> None:
         )
     else:
         logger.info("Initializing diarization engine on %s...", diarization.device)
-        diarization_kwargs: dict[str, Any] = {
+        diarization_kwargs: _DiarizationEngineKwargs = {
             "device": diarization.device,
             "hf_token": diarization.hf_token,
         }
@@ -250,6 +284,7 @@ async def run_server_with_config(config: GrpcServerConfig) -> None:
         summarization_service=summarization_service,
         diarization_engine=diarization_engine,
```
|
||||
diarization_refinement_enabled=diarization.refinement_enabled,
|
||||
ner_service=ner_service,
|
||||
)
|
||||
|
||||
# Set up graceful shutdown
|
||||
|
||||
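The `_DiarizationEngineKwargs` TypedDict added above uses `total=False`, so every key is optional while kwargs stay type-checked at call sites. A minimal standalone sketch of the pattern (the `EngineKwargs` and `make_engine` names are illustrative stand-ins, not NoteFlow code):

```python
from typing import TypedDict


class EngineKwargs(TypedDict, total=False):
    # total=False: every key may be omitted, but present keys are type-checked
    device: str
    min_speakers: int


def make_engine(**kwargs: object) -> dict[str, object]:
    # Stand-in for an engine constructor taking keyword arguments
    return dict(kwargs)


opts: EngineKwargs = {"device": "cpu"}  # min_speakers omitted is fine
print(make_engine(**opts))  # {'device': 'cpu'}
```

Unlike `dict[str, Any]`, a type checker rejects misspelled keys (e.g. `{"devcie": "cpu"}`) against `EngineKwargs`, which is the point of the change in the hunk above.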
@@ -27,8 +27,10 @@ from noteflow.infrastructure.security.keystore import KeyringKeyStore
 
 from ._mixins import (
     AnnotationMixin,
+    CalendarMixin,
     DiarizationJobMixin,
     DiarizationMixin,
+    EntitiesMixin,
     ExportMixin,
     MeetingMixin,
     StreamingMixin,
@@ -41,6 +43,8 @@ if TYPE_CHECKING:
     from numpy.typing import NDArray
     from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 
+    from noteflow.application.services.calendar_service import CalendarService
+    from noteflow.application.services.ner_service import NerService
     from noteflow.application.services.summarization_service import SummarizationService
     from noteflow.infrastructure.asr import FasterWhisperEngine
     from noteflow.infrastructure.diarization import DiarizationEngine, SpeakerTurn
@@ -56,6 +60,8 @@ class NoteFlowServicer(
     SummarizationMixin,
     AnnotationMixin,
     ExportMixin,
+    EntitiesMixin,
+    CalendarMixin,
     noteflow_pb2_grpc.NoteFlowServiceServicer,
 ):
     """Async gRPC service implementation for NoteFlow with PostgreSQL persistence."""
@@ -75,6 +81,8 @@ class NoteFlowServicer(
         summarization_service: SummarizationService | None = None,
         diarization_engine: DiarizationEngine | None = None,
         diarization_refinement_enabled: bool = True,
+        ner_service: NerService | None = None,
+        calendar_service: CalendarService | None = None,
     ) -> None:
         """Initialize the service.
 
@@ -87,12 +95,16 @@ class NoteFlowServicer(
             summarization_service: Optional summarization service for generating summaries.
             diarization_engine: Optional diarization engine for speaker identification.
             diarization_refinement_enabled: Whether to allow post-meeting diarization refinement.
+            ner_service: Optional NER service for entity extraction.
+            calendar_service: Optional calendar service for OAuth and event fetching.
         """
         self._asr_engine = asr_engine
         self._session_factory = session_factory
         self._summarization_service = summarization_service
         self._diarization_engine = diarization_engine
         self._diarization_refinement_enabled = diarization_refinement_enabled
+        self._ner_service = ner_service
+        self._calendar_service = calendar_service
         self._start_time = time.time()
         # Fallback to in-memory store if no database configured
         self._memory_store: MeetingStore | None = (
14
src/noteflow/infrastructure/calendar/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""Calendar integration infrastructure.

OAuth management and provider adapters for Google Calendar and Outlook.
"""

from .google_adapter import GoogleCalendarAdapter
from .oauth_manager import OAuthManager
from .outlook_adapter import OutlookCalendarAdapter

__all__ = [
    "GoogleCalendarAdapter",
    "OAuthManager",
    "OutlookCalendarAdapter",
]
198
src/noteflow/infrastructure/calendar/google_adapter.py
Normal file
@@ -0,0 +1,198 @@
"""Google Calendar API adapter.

Implements CalendarPort for Google Calendar using the Google Calendar API v3.
"""

from __future__ import annotations

import logging
from datetime import UTC, datetime, timedelta

import httpx

from noteflow.domain.ports.calendar import CalendarEventInfo, CalendarPort

logger = logging.getLogger(__name__)


class GoogleCalendarError(Exception):
    """Google Calendar API error."""


class GoogleCalendarAdapter(CalendarPort):
    """Google Calendar API adapter.

    Fetches calendar events and user info using Google Calendar API v3.
    Requires a valid OAuth access token with calendar.readonly scope.
    """

    CALENDAR_API_BASE = "https://www.googleapis.com/calendar/v3"
    USERINFO_API_URL = "https://www.googleapis.com/oauth2/v2/userinfo"

    async def list_events(
        self,
        access_token: str,
        hours_ahead: int = 24,
        limit: int = 20,
    ) -> list[CalendarEventInfo]:
        """Fetch upcoming calendar events from Google Calendar.

        Args:
            access_token: Google OAuth token with calendar.readonly scope.
            hours_ahead: Number of hours to look ahead.
            limit: Maximum events to fetch (capped by API).

        Returns:
            List of Google Calendar events ordered by start time.

        Raises:
            GoogleCalendarError: If API call fails.
        """
        now = datetime.now(UTC)
        time_min = now.isoformat()
        time_max = (now + timedelta(hours=hours_ahead)).isoformat()

        url = f"{self.CALENDAR_API_BASE}/calendars/primary/events"
        params: dict[str, str | int] = {
            "timeMin": time_min,
            "timeMax": time_max,
            "maxResults": limit,
            "singleEvents": "true",
            "orderBy": "startTime",
        }

        headers = {"Authorization": f"Bearer {access_token}"}

        async with httpx.AsyncClient() as client:
            response = await client.get(url, params=params, headers=headers)

        if response.status_code == 401:
            raise GoogleCalendarError("Access token expired or invalid")

        if response.status_code != 200:
            error_msg = response.text
            logger.error("Google Calendar API error: %s", error_msg)
            raise GoogleCalendarError(f"API error: {error_msg}")

        data = response.json()
        items = data.get("items", [])

        return [self._parse_event(item) for item in items]

    async def get_user_email(self, access_token: str) -> str:
        """Get authenticated user's email address.

        Args:
            access_token: Valid OAuth access token.

        Returns:
            User's email address.

        Raises:
            GoogleCalendarError: If API call fails.
        """
        headers = {"Authorization": f"Bearer {access_token}"}

        async with httpx.AsyncClient() as client:
            response = await client.get(self.USERINFO_API_URL, headers=headers)

        if response.status_code == 401:
            raise GoogleCalendarError("Access token expired or invalid")

        if response.status_code != 200:
            error_msg = response.text
            logger.error("Google userinfo API error: %s", error_msg)
            raise GoogleCalendarError(f"API error: {error_msg}")

        data = response.json()
        email = data.get("email")
        if not email:
            raise GoogleCalendarError("No email in userinfo response")

        return str(email)

    def _parse_event(self, item: dict[str, object]) -> CalendarEventInfo:
        """Parse Google Calendar event into CalendarEventInfo."""
        event_id = str(item.get("id", ""))
        title = str(item.get("summary", "Untitled"))

        # Parse start/end times
        start_data = item.get("start", {})
        end_data = item.get("end", {})

        is_all_day = "date" in start_data if isinstance(start_data, dict) else False
        start_time = self._parse_datetime(start_data)
        end_time = self._parse_datetime(end_data)

        # Parse attendees
        attendees_data = item.get("attendees", [])
        attendees = tuple(
            str(a.get("email", ""))
            for a in attendees_data
            if isinstance(a, dict) and a.get("email")
        ) if isinstance(attendees_data, list) else ()

        # Extract meeting URL from conferenceData or hangoutLink
        meeting_url = self._extract_meeting_url(item)

        # Check if recurring
        is_recurring = bool(item.get("recurringEventId"))

        location = item.get("location")
        description = item.get("description")

        return CalendarEventInfo(
            id=event_id,
            title=title,
            start_time=start_time,
            end_time=end_time,
            attendees=attendees,
            location=str(location) if location else None,
            description=str(description) if description else None,
            meeting_url=meeting_url,
            is_recurring=is_recurring,
            is_all_day=is_all_day,
            provider="google",
            raw=dict(item) if isinstance(item, dict) else None,
        )

    def _parse_datetime(self, dt_data: object) -> datetime:
        """Parse datetime from Google Calendar format."""
        if not isinstance(dt_data, dict):
            return datetime.now(UTC)

        # All-day events use "date", timed events use "dateTime"
        dt_str = dt_data.get("dateTime") or dt_data.get("date")

        if not dt_str or not isinstance(dt_str, str):
            return datetime.now(UTC)

        # Handle Z suffix for UTC
        if dt_str.endswith("Z"):
            dt_str = f"{dt_str[:-1]}+00:00"

        try:
            return datetime.fromisoformat(dt_str)
        except ValueError:
            logger.warning("Failed to parse datetime: %s", dt_str)
            return datetime.now(UTC)

    def _extract_meeting_url(self, item: dict[str, object]) -> str | None:
        """Extract video meeting URL from event data."""
        # Try hangoutLink first (Google Meet)
        hangout_link = item.get("hangoutLink")
        if isinstance(hangout_link, str) and hangout_link:
            return hangout_link

        # Try conferenceData for other providers
        conference_data = item.get("conferenceData")
        if isinstance(conference_data, dict):
            entry_points = conference_data.get("entryPoints", [])
            if isinstance(entry_points, list):
                for entry in entry_points:
                    if isinstance(entry, dict) and entry.get("entryPointType") == "video":
                        uri = entry.get("uri")
                        if isinstance(uri, str) and uri:
                            return uri

        return None
424
src/noteflow/infrastructure/calendar/oauth_manager.py
Normal file
@@ -0,0 +1,424 @@
"""OAuth manager with PKCE support for calendar providers.

Implements the OAuthPort protocol using httpx for async HTTP requests.
Uses PKCE (Proof Key for Code Exchange) for secure OAuth 2.0 flow.
"""

from __future__ import annotations

import base64
import hashlib
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, ClassVar
from urllib.parse import urlencode

import httpx

from noteflow.domain.ports.calendar import OAuthPort
from noteflow.domain.value_objects import OAuthProvider, OAuthState, OAuthTokens

if TYPE_CHECKING:
    from noteflow.config.settings import CalendarSettings

logger = logging.getLogger(__name__)


class OAuthError(Exception):
    """OAuth operation failed."""


class OAuthManager(OAuthPort):
    """OAuth manager implementing PKCE flow for Google and Outlook.

    Manages OAuth authorization URLs, token exchange, refresh, and revocation.
    State tokens are stored in-memory with TTL for security.

    Deployment Note:
        State tokens are stored in-memory (self._pending_states dict). This works
        correctly for single-worker deployments (current Tauri desktop model) but
        would require database-backed state storage for multi-worker/load-balanced
        deployments where OAuth initiate and complete requests may hit different
        workers.
    """

    # OAuth endpoints
    GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
    GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
    GOOGLE_REVOKE_URL = "https://oauth2.googleapis.com/revoke"

    OUTLOOK_AUTH_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    OUTLOOK_TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
    OUTLOOK_REVOKE_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/logout"

    # OAuth scopes
    GOOGLE_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/calendar.readonly",
        "https://www.googleapis.com/auth/userinfo.email",
        "openid",
    ]
    OUTLOOK_SCOPES: ClassVar[list[str]] = [
        "Calendars.Read",
        "User.Read",
        "offline_access",
        "openid",
    ]

    # State TTL (10 minutes)
    STATE_TTL_SECONDS = 600

    def __init__(self, settings: CalendarSettings) -> None:
        """Initialize OAuth manager with calendar settings.

        Args:
            settings: Calendar settings with OAuth credentials.
        """
        self._settings = settings
        self._pending_states: dict[str, OAuthState] = {}

    def initiate_auth(
        self,
        provider: OAuthProvider,
        redirect_uri: str,
    ) -> tuple[str, str]:
        """Generate OAuth authorization URL with PKCE.

        Args:
            provider: OAuth provider (google or outlook).
            redirect_uri: Callback URL after authorization.

        Returns:
            Tuple of (authorization_url, state_token).

        Raises:
            OAuthError: If provider credentials are not configured.
        """
        self._cleanup_expired_states()
        self._validate_provider_config(provider)

        # Generate PKCE code verifier and challenge
        code_verifier = self._generate_code_verifier()
        code_challenge = self._generate_code_challenge(code_verifier)

        # Generate state token for CSRF protection
        state_token = secrets.token_urlsafe(32)

        # Store state for validation during callback
        now = datetime.now(UTC)
        oauth_state = OAuthState(
            state=state_token,
            provider=provider,
            redirect_uri=redirect_uri,
            code_verifier=code_verifier,
            created_at=now,
            expires_at=now + timedelta(seconds=self.STATE_TTL_SECONDS),
        )
        self._pending_states[state_token] = oauth_state

        # Build authorization URL
        auth_url = self._build_auth_url(
            provider=provider,
            redirect_uri=redirect_uri,
            state=state_token,
            code_challenge=code_challenge,
        )

        logger.info("Initiated OAuth flow for provider=%s", provider.value)
        return auth_url, state_token

    async def complete_auth(
        self,
        provider: OAuthProvider,
        code: str,
        state: str,
    ) -> OAuthTokens:
        """Exchange authorization code for tokens.

        Args:
            provider: OAuth provider.
            code: Authorization code from callback.
            state: State parameter from callback.

        Returns:
            OAuth tokens.

        Raises:
            OAuthError: If state is invalid, expired, or token exchange fails.
        """
        # Validate and retrieve state
        oauth_state = self._pending_states.pop(state, None)
        if oauth_state is None:
            raise OAuthError("Invalid or expired state token")

        if oauth_state.is_expired():
            raise OAuthError("State token has expired")

        if oauth_state.provider != provider:
            raise OAuthError(
                f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
            )

        # Exchange code for tokens
        tokens = await self._exchange_code(
            provider=provider,
            code=code,
            redirect_uri=oauth_state.redirect_uri,
            code_verifier=oauth_state.code_verifier,
        )

        logger.info("Completed OAuth flow for provider=%s", provider.value)
        return tokens

    async def refresh_tokens(
        self,
        provider: OAuthProvider,
        refresh_token: str,
    ) -> OAuthTokens:
        """Refresh expired access token.

        Args:
            provider: OAuth provider.
            refresh_token: Refresh token from previous exchange.

        Returns:
            New OAuth tokens.

        Raises:
            OAuthError: If refresh fails.
        """
        token_url = self._get_token_url(provider)
        client_id, client_secret = self._get_credentials(provider)

        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": client_id,
        }

        # Google requires client_secret for refresh
        if provider == OAuthProvider.GOOGLE:
            data["client_secret"] = client_secret

        async with httpx.AsyncClient() as client:
            response = await client.post(token_url, data=data)

        if response.status_code != 200:
            error_detail = response.text
            logger.error(
                "Token refresh failed for provider=%s: %s",
                provider.value,
                error_detail,
            )
            raise OAuthError(f"Token refresh failed: {error_detail}")

        token_data = response.json()
        tokens = self._parse_token_response(token_data, refresh_token)

        logger.info("Refreshed tokens for provider=%s", provider.value)
        return tokens

    async def revoke_tokens(
        self,
        provider: OAuthProvider,
        access_token: str,
    ) -> bool:
        """Revoke OAuth tokens with provider.

        Args:
            provider: OAuth provider.
            access_token: Access token to revoke.

        Returns:
            True if revoked successfully.
        """
        revoke_url = self._get_revoke_url(provider)

        async with httpx.AsyncClient() as client:
            if provider == OAuthProvider.GOOGLE:
                response = await client.post(
                    revoke_url,
                    params={"token": access_token},
                )
            else:
                # Outlook uses logout endpoint with token hint
                response = await client.get(
                    revoke_url,
                    params={"post_logout_redirect_uri": self._settings.redirect_uri},
                )

        if response.status_code in (200, 204):
            logger.info("Revoked tokens for provider=%s", provider.value)
            return True

        logger.warning(
            "Token revocation returned status=%d for provider=%s",
            response.status_code,
            provider.value,
        )
        return False

    def _validate_provider_config(self, provider: OAuthProvider) -> None:
        """Validate that provider credentials are configured."""
        client_id, client_secret = self._get_credentials(provider)
        if not client_id or not client_secret:
            raise OAuthError(
                f"OAuth credentials not configured for {provider.value}"
            )

    def _get_credentials(self, provider: OAuthProvider) -> tuple[str, str]:
        """Get client credentials for provider."""
        if provider == OAuthProvider.GOOGLE:
            return (
                self._settings.google_client_id,
                self._settings.google_client_secret,
            )
        return (
            self._settings.outlook_client_id,
            self._settings.outlook_client_secret,
        )

    def _get_auth_url(self, provider: OAuthProvider) -> str:
        """Get authorization URL for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_AUTH_URL
        return self.OUTLOOK_AUTH_URL

    def _get_token_url(self, provider: OAuthProvider) -> str:
        """Get token URL for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_TOKEN_URL
        return self.OUTLOOK_TOKEN_URL

    def _get_revoke_url(self, provider: OAuthProvider) -> str:
        """Get revoke URL for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_REVOKE_URL
        return self.OUTLOOK_REVOKE_URL

    def _get_scopes(self, provider: OAuthProvider) -> list[str]:
        """Get OAuth scopes for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_SCOPES
        return self.OUTLOOK_SCOPES

    def _build_auth_url(
        self,
        provider: OAuthProvider,
        redirect_uri: str,
        state: str,
        code_challenge: str,
    ) -> str:
        """Build OAuth authorization URL with PKCE parameters."""
        client_id, _ = self._get_credentials(provider)
        scopes = self._get_scopes(provider)
        base_url = self._get_auth_url(provider)

        params = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": " ".join(scopes),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }

        # Provider-specific parameters
        if provider == OAuthProvider.GOOGLE:
            params["access_type"] = "offline"
            params["prompt"] = "consent"
        elif provider == OAuthProvider.OUTLOOK:
            params["response_mode"] = "query"

        return f"{base_url}?{urlencode(params)}"

    async def _exchange_code(
        self,
        provider: OAuthProvider,
        code: str,
        redirect_uri: str,
        code_verifier: str,
    ) -> OAuthTokens:
        """Exchange authorization code for tokens."""
        token_url = self._get_token_url(provider)
        client_id, client_secret = self._get_credentials(provider)

        data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": redirect_uri,
            "client_id": client_id,
            "code_verifier": code_verifier,
        }

        # Google requires client_secret even with PKCE
        if provider == OAuthProvider.GOOGLE:
            data["client_secret"] = client_secret

        async with httpx.AsyncClient() as client:
            response = await client.post(token_url, data=data)

        if response.status_code != 200:
            error_detail = response.text
            logger.error(
                "Token exchange failed for provider=%s: %s",
                provider.value,
                error_detail,
            )
            raise OAuthError(f"Token exchange failed: {error_detail}")

        token_data = response.json()
        return self._parse_token_response(token_data)

    def _parse_token_response(
        self,
        data: dict[str, object],
        existing_refresh_token: str | None = None,
    ) -> OAuthTokens:
        """Parse token response into OAuthTokens."""
        access_token = str(data.get("access_token", ""))
        if not access_token:
            raise OAuthError("No access_token in response")

        # Calculate expiry time
        expires_in_raw = data.get("expires_in", 3600)
        expires_in = int(expires_in_raw) if isinstance(expires_in_raw, (int, float, str)) else 3600
        expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)

        # Refresh token may not be returned on refresh
        refresh_token = data.get("refresh_token")
        if isinstance(refresh_token, str):
            final_refresh_token: str | None = refresh_token
        else:
            final_refresh_token = existing_refresh_token

        return OAuthTokens(
            access_token=access_token,
            refresh_token=final_refresh_token,
            token_type=str(data.get("token_type", "Bearer")),
            expires_at=expires_at,
            scope=str(data.get("scope", "")),
        )

    @staticmethod
    def _generate_code_verifier() -> str:
        """Generate a cryptographically random code verifier for PKCE."""
        return secrets.token_urlsafe(64)

    @staticmethod
    def _generate_code_challenge(verifier: str) -> str:
        """Generate code challenge from verifier using S256 method."""
        digest = hashlib.sha256(verifier.encode("ascii")).digest()
        return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

    def _cleanup_expired_states(self) -> None:
        """Remove expired state tokens."""
        now = datetime.now(UTC)
        expired_keys = [
            key
            for key, state in self._pending_states.items()
            if state.is_expired() or now > state.expires_at
        ]
        for key in expired_keys:
            del self._pending_states[key]
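The S256 derivation in `_generate_code_challenge` above (SHA-256, then unpadded base64url) is the method defined by RFC 7636; a quick standalone check against the spec's Appendix B test vector:

```python
import base64
import hashlib


def code_challenge(verifier: str) -> str:
    # S256: BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) without '=' padding
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


# RFC 7636 Appendix B test vector
print(code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"))
# E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM
```

Matching the published test vector confirms the padding strip and `urlsafe` alphabet are both required; plain `b64encode` or retained `=` padding would produce a challenge the provider rejects.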
226
src/noteflow/infrastructure/calendar/outlook_adapter.py
Normal file
@@ -0,0 +1,226 @@
"""Microsoft Outlook Calendar API adapter.

Implements CalendarPort for Outlook using Microsoft Graph API.
"""

from __future__ import annotations

import logging
from datetime import UTC, datetime, timedelta

import httpx

from noteflow.domain.ports.calendar import CalendarEventInfo, CalendarPort

logger = logging.getLogger(__name__)


class OutlookCalendarError(Exception):
    """Outlook Calendar API error."""


class OutlookCalendarAdapter(CalendarPort):
    """Microsoft Graph Calendar API adapter.

    Fetches calendar events and user info using Microsoft Graph API.
    Requires a valid OAuth access token with Calendars.Read scope.
    """

    GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"

    async def list_events(
        self,
        access_token: str,
        hours_ahead: int = 24,
        limit: int = 20,
    ) -> list[CalendarEventInfo]:
        """Fetch upcoming calendar events from Outlook Calendar.

        Args:
            access_token: Microsoft Graph OAuth token with Calendars.Read scope.
            hours_ahead: Hours to look ahead from current time.
            limit: Maximum events to return (capped by Graph API).

        Returns:
            List of Outlook calendar events ordered by start datetime.

        Raises:
            OutlookCalendarError: If API call fails.
        """
        now = datetime.now(UTC)
        start_time = now.strftime("%Y-%m-%dT%H:%M:%SZ")
        end_time = (now + timedelta(hours=hours_ahead)).strftime("%Y-%m-%dT%H:%M:%SZ")

        url = f"{self.GRAPH_API_BASE}/me/calendarView"
        params: dict[str, str | int] = {
            "startDateTime": start_time,
            "endDateTime": end_time,
            "$top": limit,
            "$orderby": "start/dateTime",
            "$select": (
                "id,subject,start,end,location,bodyPreview,"
                "attendees,isAllDay,seriesMasterId,onlineMeeting,onlineMeetingUrl"
            ),
        }

        headers = {
            "Authorization": f"Bearer {access_token}",
            "Prefer": 'outlook.timezone="UTC"',
        }

        async with httpx.AsyncClient() as client:
            response = await client.get(url, params=params, headers=headers)

        if response.status_code == 401:
            raise OutlookCalendarError("Access token expired or invalid")

        if response.status_code != 200:
            error_msg = response.text
            logger.error("Microsoft Graph API error: %s", error_msg)
            raise OutlookCalendarError(f"API error: {error_msg}")

        data = response.json()
        items = data.get("value", [])

        return [self._parse_event(item) for item in items]

    async def get_user_email(self, access_token: str) -> str:
        """Get authenticated user's email address.

        Args:
            access_token: Valid OAuth access token.

        Returns:
            User's email address.

        Raises:
            OutlookCalendarError: If API call fails.
        """
        url = f"{self.GRAPH_API_BASE}/me"
        params = {"$select": "mail,userPrincipalName"}
        headers = {"Authorization": f"Bearer {access_token}"}

        async with httpx.AsyncClient() as client:
            response = await client.get(url, params=params, headers=headers)

        if response.status_code == 401:
            raise OutlookCalendarError("Access token expired or invalid")

        if response.status_code != 200:
            error_msg = response.text
            logger.error("Microsoft Graph API error: %s", error_msg)
            raise OutlookCalendarError(f"API error: {error_msg}")

        data = response.json()
        # Prefer mail, fall back to userPrincipalName
        email = data.get("mail") or data.get("userPrincipalName")
        if not email:
            raise OutlookCalendarError("No email in user profile response")

        return str(email)

    def _parse_event(self, item: dict[str, object]) -> CalendarEventInfo:
        """Parse Microsoft Graph event into CalendarEventInfo."""
        event_id = str(item.get("id", ""))
        title = str(item.get("subject", "Untitled"))

        # Parse start/end times
        start_data = item.get("start", {})
        end_data = item.get("end", {})

        start_time = self._parse_datetime(start_data)
        end_time = self._parse_datetime(end_data)

        # Check if all-day event
        is_all_day = bool(item.get("isAllDay", False))

        # Parse attendees
        attendees_data = item.get("attendees", [])
        attendees = self._parse_attendees(attendees_data)

        # Extract meeting URL
        meeting_url = self._extract_meeting_url(item)

        # Check if recurring (has seriesMasterId)
        is_recurring = bool(item.get("seriesMasterId"))

        # Location
        location_data = item.get("location", {})
        location = (
            str(location_data.get("displayName"))
            if isinstance(location_data, dict) and location_data.get("displayName")
            else None
        )

        # Description (bodyPreview)
        description = item.get("bodyPreview")

        return CalendarEventInfo(
            id=event_id,
            title=title,
            start_time=start_time,
            end_time=end_time,
            attendees=attendees,
            location=location,
            description=str(description) if description else None,
            meeting_url=meeting_url,
            is_recurring=is_recurring,
            is_all_day=is_all_day,
            provider="outlook",
            raw=dict(item) if isinstance(item, dict) else None,
        )

    def _parse_datetime(self, dt_data: object) -> datetime:
        """Parse datetime from Microsoft Graph format."""
        if not isinstance(dt_data, dict):
            return datetime.now(UTC)

        dt_str = dt_data.get("dateTime")
        timezone = dt_data.get("timeZone", "UTC")

        if not dt_str or not isinstance(dt_str, str):
            return datetime.now(UTC)

        try:
            # Graph API returns ISO format without timezone suffix
            dt = datetime.fromisoformat(dt_str)
            # If no timezone info, assume UTC (we requested UTC in Prefer header)
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=UTC)
            return dt
        except ValueError:
            logger.warning("Failed to parse datetime: %s (tz: %s)", dt_str, timezone)
            return datetime.now(UTC)

    def _parse_attendees(self, attendees_data: object) -> tuple[str, ...]:
        """Parse attendees from Microsoft Graph format."""
        if not isinstance(attendees_data, list):
            return ()

        emails: list[str] = []
        for attendee in attendees_data:
            if not isinstance(attendee, dict):
                continue
            email_address = attendee.get("emailAddress", {})
            if isinstance(email_address, dict):
                email = email_address.get("address")
                if email and isinstance(email, str):
                    emails.append(email)

        return tuple(emails)

    def _extract_meeting_url(self, item: dict[str, object]) -> str | None:
        """Extract online meeting URL from event data."""
        # Try onlineMeetingUrl first (Teams link)
        online_url = item.get("onlineMeetingUrl")
        if isinstance(online_url, str) and online_url:
            return online_url

        # Try onlineMeeting object
|
||||
online_meeting = item.get("onlineMeeting")
|
||||
if isinstance(online_meeting, dict):
|
||||
join_url = online_meeting.get("joinUrl")
|
||||
if isinstance(join_url, str) and join_url:
|
||||
return join_url
|
||||
|
||||
return None
|
||||
@@ -1,9 +1,11 @@
"""Infrastructure converters for data transformation between layers."""

from noteflow.infrastructure.converters.asr_converters import AsrConverter
from noteflow.infrastructure.converters.calendar_converters import CalendarEventConverter
from noteflow.infrastructure.converters.orm_converters import OrmConverter

__all__ = [
    "AsrConverter",
    "CalendarEventConverter",
    "OrmConverter",
]
116
src/noteflow/infrastructure/converters/calendar_converters.py
Normal file
@@ -0,0 +1,116 @@
"""Calendar data converters.

Convert between CalendarEventModel (ORM) and CalendarEventInfo (domain).
"""

from __future__ import annotations

from typing import TYPE_CHECKING
from uuid import UUID

from noteflow.domain.ports.calendar import CalendarEventInfo
from noteflow.infrastructure.triggers.calendar import CalendarEvent

if TYPE_CHECKING:
    from noteflow.infrastructure.persistence.models.integrations.integration import (
        CalendarEventModel,
    )


class CalendarEventConverter:
    """Convert between CalendarEventModel and CalendarEventInfo."""

    @staticmethod
    def orm_to_info(model: CalendarEventModel, provider: str = "") -> CalendarEventInfo:
        """Convert ORM model to domain CalendarEventInfo.

        Args:
            model: SQLAlchemy CalendarEventModel instance.
            provider: Provider name (google, outlook).

        Returns:
            Domain CalendarEventInfo value object.
        """
        return CalendarEventInfo(
            id=model.external_id,
            title=model.title,
            start_time=model.start_time,
            end_time=model.end_time,
            attendees=tuple(model.attendees) if model.attendees else (),
            location=model.location,
            description=model.description,
            meeting_url=model.meeting_link,
            is_recurring=False,  # Not stored in current model
            is_all_day=model.is_all_day,
            provider=provider,
            raw=dict(model.raw) if model.raw else None,
        )

    @staticmethod
    def info_to_orm_kwargs(
        event: CalendarEventInfo,
        integration_id: UUID,
        calendar_id: str = "primary",
        calendar_name: str = "Primary",
    ) -> dict[str, object]:
        """Convert CalendarEventInfo to ORM model kwargs.

        Args:
            event: Domain CalendarEventInfo.
            integration_id: Parent integration UUID.
            calendar_id: Calendar identifier from provider.
            calendar_name: Display name of the calendar.

        Returns:
            Kwargs dict for CalendarEventModel construction.
        """
        return {
            "integration_id": integration_id,
            "external_id": event.id,
            "calendar_id": calendar_id,
            "calendar_name": calendar_name,
            "title": event.title,
            "description": event.description,
            "start_time": event.start_time,
            "end_time": event.end_time,
            "location": event.location,
            "attendees": list(event.attendees) if event.attendees else None,
            "is_all_day": event.is_all_day,
            "meeting_link": event.meeting_url,
            "raw": event.raw or {},
        }

    @staticmethod
    def info_to_trigger_event(event: CalendarEventInfo) -> CalendarEvent:
        """Convert CalendarEventInfo to trigger-layer CalendarEvent.

        The trigger-layer CalendarEvent is a lightweight dataclass used
        for signal detection in the trigger system.

        Args:
            event: Domain CalendarEventInfo.

        Returns:
            Trigger-layer CalendarEvent value object.
        """
        return CalendarEvent(
            start=event.start_time,
            end=event.end_time,
            title=event.title,
        )

    @staticmethod
    def orm_to_trigger_event(model: CalendarEventModel) -> CalendarEvent:
        """Convert ORM model to trigger-layer CalendarEvent.

        Args:
            model: SQLAlchemy CalendarEventModel instance.

        Returns:
            Trigger-layer CalendarEvent value object.
        """
        return CalendarEvent(
            start=model.start_time,
            end=model.end_time,
            title=model.title,
        )
@@ -0,0 +1,68 @@
"""Integration domain ↔ ORM converters."""

from __future__ import annotations

from typing import TYPE_CHECKING

from noteflow.domain.entities.integration import (
    Integration,
    IntegrationStatus,
    IntegrationType,
)

if TYPE_CHECKING:
    from noteflow.infrastructure.persistence.models.integrations import IntegrationModel


class IntegrationConverter:
    """Convert between Integration domain objects and ORM models."""

    @staticmethod
    def orm_to_domain(model: IntegrationModel) -> Integration:
        """Convert ORM model to domain entity.

        Args:
            model: SQLAlchemy IntegrationModel instance.

        Returns:
            Domain Integration entity.
        """
        return Integration(
            id=model.id,
            workspace_id=model.workspace_id,
            name=model.name,
            type=IntegrationType(model.type),
            status=IntegrationStatus(model.status),
            config=dict(model.config) if model.config else {},
            last_sync=model.last_sync,
            error_message=model.error_message,
            created_at=model.created_at,
            updated_at=model.updated_at,
        )

    @staticmethod
    def to_orm_kwargs(entity: Integration) -> dict[str, object]:
        """Convert domain entity to ORM model kwargs.

        Returns a dict of kwargs rather than instantiating IntegrationModel
        directly to avoid circular imports and allow the repository to
        handle ORM construction.

        Args:
            entity: Domain Integration.

        Returns:
            Kwargs dict for IntegrationModel construction.
        """
        return {
            "id": entity.id,
            "workspace_id": entity.workspace_id,
            "name": entity.name,
            "type": entity.type.value,
            "status": entity.status.value,
            "config": entity.config,
            "last_sync": entity.last_sync,
            "error_message": entity.error_message,
            "created_at": entity.created_at,
            "updated_at": entity.updated_at,
        }
69
src/noteflow/infrastructure/converters/ner_converters.py
Normal file
@@ -0,0 +1,69 @@
"""NER domain ↔ ORM converters."""

from __future__ import annotations

from typing import TYPE_CHECKING

from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId

if TYPE_CHECKING:
    from noteflow.infrastructure.persistence.models import NamedEntityModel


class NerConverter:
    """Convert between NamedEntity domain objects and ORM models."""

    @staticmethod
    def orm_to_domain(model: NamedEntityModel) -> NamedEntity:
        """Convert ORM model to domain entity.

        Args:
            model: SQLAlchemy NamedEntityModel instance.

        Returns:
            Domain NamedEntity with validated and normalized fields.

        Raises:
            ValueError: If model data is invalid (e.g., confidence out of range).
        """
        # Validate and convert segment_ids to ensure it's a proper list
        segment_ids = list(model.segment_ids) if model.segment_ids else []

        # Create domain entity - validation happens in __post_init__
        return NamedEntity(
            id=model.id,
            meeting_id=MeetingId(model.meeting_id),
            text=model.text,
            normalized_text=model.normalized_text,
            category=EntityCategory(model.category),
            segment_ids=segment_ids,
            confidence=model.confidence,
            is_pinned=model.is_pinned,
            db_id=model.id,
        )

    @staticmethod
    def to_orm_kwargs(entity: NamedEntity) -> dict[str, object]:
        """Convert domain entity to ORM model kwargs.

        Returns a dict of kwargs rather than instantiating NamedEntityModel
        directly to avoid circular imports and allow the repository to
        handle ORM construction.

        Args:
            entity: Domain NamedEntity.

        Returns:
            Kwargs dict for NamedEntityModel construction.
        """
        return {
            "id": entity.id,
            "meeting_id": entity.meeting_id,
            "text": entity.text,
            "normalized_text": entity.normalized_text,
            "category": entity.category.value,
            "segment_ids": entity.segment_ids,
            "confidence": entity.confidence,
            "is_pinned": entity.is_pinned,
        }
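Note how the converter round-trips the enum field through its string value: `category.value` on the way into the ORM, `EntityCategory(model.category)` on the way back. A minimal sketch of that round trip, using an illustrative enum rather than the real `EntityCategory`:

```python
from enum import Enum


class Category(Enum):
    # Illustrative stand-in for EntityCategory.
    PERSON = "person"
    COMPANY = "company"
    OTHER = "other"


# Domain -> ORM: the column stores the plain string value.
stored = Category.COMPANY.value

# ORM -> domain: the enum member is reconstructed from the stored string;
# an unknown string would raise ValueError here, surfacing bad rows early.
restored = Category(stored)
```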
@@ -6,7 +6,7 @@ Export meeting transcripts to PDF format.
from __future__ import annotations

import html
-from typing import TYPE_CHECKING, Protocol
+from typing import TYPE_CHECKING, Protocol, cast

from noteflow.infrastructure.export._formatting import format_datetime, format_timestamp

@@ -36,7 +36,10 @@ def _get_weasy_html() -> type[_WeasyHTMLProtocol] | None:
        return None

    weasyprint = importlib.import_module("weasyprint")
-    return weasyprint.HTML
+    html_class = getattr(weasyprint, "HTML", None)
+    if html_class is None:
+        return None
+    return cast(type[_WeasyHTMLProtocol], html_class)


def _escape(text: str) -> str:
5
src/noteflow/infrastructure/ner/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Named Entity Recognition infrastructure."""

from noteflow.infrastructure.ner.engine import NerEngine

__all__ = ["NerEngine"]
252
src/noteflow/infrastructure/ner/engine.py
Normal file
@@ -0,0 +1,252 @@
"""NER engine implementation using spaCy.

Provides named entity extraction with lazy model loading and segment tracking.
"""

from __future__ import annotations

import asyncio
import logging
from functools import partial
from typing import TYPE_CHECKING, Final

from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity

if TYPE_CHECKING:
    from spacy.language import Language

logger = logging.getLogger(__name__)

# Map spaCy entity types to our categories
_SPACY_CATEGORY_MAP: Final[dict[str, EntityCategory]] = {
    # People
    "PERSON": EntityCategory.PERSON,
    # Organizations
    "ORG": EntityCategory.COMPANY,
    # Products and creative works
    "PRODUCT": EntityCategory.PRODUCT,
    "WORK_OF_ART": EntityCategory.PRODUCT,
    # Locations
    "GPE": EntityCategory.LOCATION,  # Geo-political entity (countries, cities)
    "LOC": EntityCategory.LOCATION,  # Non-GPE locations (mountains, rivers)
    "FAC": EntityCategory.LOCATION,  # Facilities (buildings, airports)
    # Dates and times
    "DATE": EntityCategory.DATE,
    "TIME": EntityCategory.DATE,
    # Others (filtered out or mapped to OTHER)
    "MONEY": EntityCategory.OTHER,
    "PERCENT": EntityCategory.OTHER,
    "CARDINAL": EntityCategory.OTHER,
    "ORDINAL": EntityCategory.OTHER,
    "QUANTITY": EntityCategory.OTHER,
    "NORP": EntityCategory.OTHER,  # Nationalities, religions
    "EVENT": EntityCategory.OTHER,
    "LAW": EntityCategory.OTHER,
    "LANGUAGE": EntityCategory.OTHER,
}

# Entity types to skip (low value for meeting context)
_SKIP_ENTITY_TYPES: Final[frozenset[str]] = frozenset({
    "CARDINAL",
    "ORDINAL",
    "QUANTITY",
    "PERCENT",
    "MONEY",
})

# Valid model names
VALID_SPACY_MODELS: Final[tuple[str, ...]] = (
    "en_core_web_sm",
    "en_core_web_md",
    "en_core_web_lg",
    "en_core_web_trf",
)


class NerEngine:
    """Named entity recognition engine using spaCy.

    Lazy-loads the spaCy model on first use to avoid startup delay.
    Implements the NerPort protocol for hexagonal architecture.

    Uses chunking by segment (speaker turn/paragraph) to avoid OOM on
    long transcripts while maintaining segment tracking.
    """

    def __init__(self, model_name: str = "en_core_web_trf") -> None:
        """Initialize NER engine.

        Args:
            model_name: spaCy model to use. Defaults to en_core_web_trf
                (transformer-based for higher accuracy).
        """
        if model_name not in VALID_SPACY_MODELS:
            raise ValueError(
                f"Invalid model name: {model_name}. "
                f"Valid models: {', '.join(VALID_SPACY_MODELS)}"
            )
        self._model_name = model_name
        self._nlp: Language | None = None

    def load_model(self) -> None:
        """Load the spaCy model.

        Raises:
            RuntimeError: If model loading fails.
        """
        import spacy

        logger.info("Loading spaCy model: %s", self._model_name)
        try:
            self._nlp = spacy.load(self._model_name)
            logger.info("spaCy model loaded successfully")
        except OSError as e:
            msg = (
                f"Failed to load spaCy model '{self._model_name}'. "
                f"Run: python -m spacy download {self._model_name}"
            )
            raise RuntimeError(msg) from e

    def _ensure_loaded(self) -> Language:
        """Ensure model is loaded, loading if necessary.

        Returns:
            The loaded spaCy Language model.
        """
        if self._nlp is None:
            self.load_model()
        assert self._nlp is not None, "load_model() should set self._nlp"
        return self._nlp

    def is_ready(self) -> bool:
        """Check if model is loaded."""
        return self._nlp is not None

    def unload(self) -> None:
        """Unload the model to free memory."""
        self._nlp = None
        logger.info("spaCy model unloaded")

    @property
    def model_name(self) -> str:
        """Return the model name."""
        return self._model_name

    def extract(self, text: str) -> list[NamedEntity]:
        """Extract named entities from text.

        Args:
            text: Input text to analyze.

        Returns:
            List of extracted entities (deduplicated by normalized text).
        """
        if not text or not text.strip():
            return []

        nlp = self._ensure_loaded()
        doc = nlp(text)

        entities: list[NamedEntity] = []
        seen: set[str] = set()

        for ent in doc.ents:
            # Normalize for deduplication
            normalized = ent.text.lower().strip()
            if not normalized or normalized in seen:
                continue

            # Skip low-value entity types
            if ent.label_ in _SKIP_ENTITY_TYPES:
                continue

            seen.add(normalized)
            category = _SPACY_CATEGORY_MAP.get(ent.label_, EntityCategory.OTHER)

            entities.append(
                NamedEntity.create(
                    text=ent.text,
                    category=category,
                    segment_ids=[],  # Filled by caller
                    confidence=0.8,  # spaCy doesn't provide per-entity confidence
                )
            )

        return entities

    def extract_from_segments(
        self,
        segments: list[tuple[int, str]],
    ) -> list[NamedEntity]:
        """Extract entities from multiple segments with segment tracking.

        Processes each segment individually (chunking by speaker turn/paragraph)
        to avoid OOM on long transcripts. Entities appearing in multiple segments
        are deduplicated with merged segment lists.

        Args:
            segments: List of (segment_id, text) tuples.

        Returns:
            Entities with segment_ids populated (deduplicated across segments).
        """
        if not segments:
            return []

        # Track entities and their segment occurrences
        # Key: normalized text, Value: NamedEntity
        all_entities: dict[str, NamedEntity] = {}

        for segment_id, text in segments:
            if not text or not text.strip():
                continue

            segment_entities = self.extract(text)

            for entity in segment_entities:
                key = entity.normalized_text

                if key in all_entities:
                    # Merge segment IDs
                    all_entities[key].merge_segments([segment_id])
                else:
                    # New entity - set its segment_ids
                    entity.segment_ids = [segment_id]
                    all_entities[key] = entity

        return list(all_entities.values())

    async def extract_async(self, text: str) -> list[NamedEntity]:
        """Extract entities asynchronously using executor.

        Offloads blocking extraction to a thread pool executor to avoid
        blocking the asyncio event loop.

        Args:
            text: Input text to analyze.

        Returns:
            List of extracted entities.
        """
        return await self._run_in_executor(partial(self.extract, text))

    async def extract_from_segments_async(
        self,
        segments: list[tuple[int, str]],
    ) -> list[NamedEntity]:
        """Extract entities from segments asynchronously.

        Args:
            segments: List of (segment_id, text) tuples.

        Returns:
            Entities with segment_ids populated.
        """
        return await self._run_in_executor(partial(self.extract_from_segments, segments))

    async def _run_in_executor(
        self,
        func: partial[list[NamedEntity]],
    ) -> list[NamedEntity]:
        """Run sync extraction in thread pool executor."""
        return await asyncio.get_running_loop().run_in_executor(None, func)
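The cross-segment deduplication in `extract_from_segments` — normalize each mention, merge segment IDs when the same entity recurs — can be sketched independently of spaCy (the helper name and data shape here are illustrative):

```python
def dedupe_entities(segments: list[tuple[int, list[str]]]) -> dict[str, list[int]]:
    # Map normalized entity text -> list of segment ids it appears in,
    # mirroring how extract_from_segments merges segment_ids per entity.
    merged: dict[str, list[int]] = {}
    for segment_id, entity_texts in segments:
        for text in entity_texts:
            key = text.lower().strip()
            if not key:
                continue
            ids = merged.setdefault(key, [])
            if segment_id not in ids:
                ids.append(segment_id)
    return merged


result = dedupe_entities([(0, ["Acme"]), (1, ["acme", "Bob"]), (2, ["Bob"])])
# "Acme" and "acme" collapse to one entity seen in segments 0 and 1.
```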
@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
import logging
from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

import asyncpg
from sqlalchemy import text

@@ -27,14 +27,40 @@ logger = logging.getLogger(__name__)
_original_asyncpg_connect = asyncpg.connect


-async def _patched_asyncpg_connect(*args: Any, **kwargs: Any) -> Any:
+async def _patched_asyncpg_connect(
+    dsn: str | None = None,
+    *,
+    host: str | None = None,
+    port: int | None = None,
+    user: str | None = None,
+    password: str | None = None,
+    passfile: str | None = None,
+    database: str | None = None,
+    timeout: float = 60,
+    ssl: object | None = None,
+    direct_tls: bool = False,
+    **kwargs: object,
+) -> asyncpg.Connection:
    """Patched asyncpg.connect that filters out unsupported 'schema' parameter.

    SQLAlchemy may try to pass 'schema' to asyncpg.connect(), but asyncpg
    doesn't support this parameter. This wrapper filters it out.
    """
-    kwargs.pop("schema", None)
-    return await _original_asyncpg_connect(*args, **kwargs)
+    # Remove schema if present - SQLAlchemy passes this but asyncpg doesn't support it
+    filtered_kwargs = {k: v for k, v in kwargs.items() if k != "schema"}
+    return await _original_asyncpg_connect(
+        dsn,
+        host=host,
+        port=port,
+        user=user,
+        password=password,
+        passfile=passfile,
+        database=database,
+        timeout=timeout,
+        ssl=ssl,
+        direct_tls=direct_tls,
+        **filtered_kwargs,
+    )


# Patch asyncpg.connect to filter out schema parameter
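The wrapper's core move — strip one unsupported keyword, then delegate to the original callable — can be sketched generically. This is a simplified illustration of the pattern, not the project's patch (names are made up); the lambda stands in for the wrapped function:

```python
from collections.abc import Callable


def drop_kwarg(func: Callable[..., object], banned: str) -> Callable[..., object]:
    # Wrap func so callers may pass `banned` harmlessly; it is removed
    # before delegation, like the 'schema' filter above.
    def wrapper(*args: object, **kwargs: object) -> object:
        kwargs.pop(banned, None)
        return func(*args, **kwargs)

    return wrapper


# The stand-in function reports the keyword names it actually received.
connect = drop_kwarg(lambda **kw: sorted(kw), "schema")
result = connect(host="db", schema="public")
```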
@@ -14,7 +14,11 @@ from noteflow.domain.entities import Meeting, Segment, Summary
from noteflow.domain.value_objects import MeetingState

if TYPE_CHECKING:
    from uuid import UUID

    from noteflow.domain.entities import Annotation
    from noteflow.domain.entities.integration import Integration
    from noteflow.domain.entities.named_entity import NamedEntity
    from noteflow.domain.value_objects import AnnotationId, MeetingId
    from noteflow.grpc.meeting_store import MeetingStore
    from noteflow.infrastructure.persistence.repositories import (

@@ -97,7 +101,7 @@ class MemorySegmentRepository:
        include_words: bool = True,
    ) -> Sequence[Segment]:
        """Get all segments for a meeting."""
-        return self._store.get_segments(str(meeting_id))
+        return self._store.fetch_segments(str(meeting_id))

    async def search_semantic(
        self,

@@ -127,9 +131,14 @@ class MemorySegmentRepository:
        This method exists for interface compatibility.
        """

-    async def get_next_segment_id(self, meeting_id: MeetingId) -> int:
-        """Get next segment ID for a meeting."""
-        return self._store.get_next_segment_id(str(meeting_id))
+    async def compute_next_segment_id(self, meeting_id: MeetingId) -> int:
+        """Compute next available segment ID for a meeting.
+
+        Returns:
+            Next sequential segment ID (0 if meeting has no segments).
+        """
+        meeting_id_str = str(meeting_id)
+        return self._store.compute_next_segment_id(meeting_id_str)


class MemorySummaryRepository:

@@ -262,3 +271,109 @@
    async def delete(self, key: str) -> bool:
        """Not supported in memory mode."""
        raise NotImplementedError("Preferences require database persistence")


class UnsupportedEntityRepository:
    """Entity repository that raises for unsupported operations.

    Used in memory mode where NER entities require database persistence.
    """

    async def save(self, entity: NamedEntity) -> NamedEntity:
        """Not supported in memory mode."""
        raise NotImplementedError("NER entities require database persistence")

    async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
        """Not supported in memory mode."""
        raise NotImplementedError("NER entities require database persistence")

    async def get(self, entity_id: UUID) -> NamedEntity | None:
        """Not supported in memory mode."""
        raise NotImplementedError("NER entities require database persistence")

    async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
        """Not supported in memory mode."""
        raise NotImplementedError("NER entities require database persistence")

    async def delete_by_meeting(self, meeting_id: MeetingId) -> int:
        """Not supported in memory mode."""
        raise NotImplementedError("NER entities require database persistence")

    async def update_pinned(self, entity_id: UUID, is_pinned: bool) -> bool:
        """Not supported in memory mode."""
        raise NotImplementedError("NER entities require database persistence")

    async def exists_for_meeting(self, meeting_id: MeetingId) -> bool:
        """Not supported in memory mode."""
        raise NotImplementedError("NER entities require database persistence")


class InMemoryIntegrationRepository:
    """In-memory integration repository for testing.

    Provides a functional integration repository backed by dictionaries
    for use in tests and memory mode.
    """

    def __init__(self) -> None:
        """Initialize with empty storage."""
        self._integrations: dict[UUID, Integration] = {}
        self._secrets: dict[UUID, dict[str, str]] = {}

    async def get(self, integration_id: UUID) -> Integration | None:
        """Retrieve an integration by ID."""
        return self._integrations.get(integration_id)

    async def get_by_provider(
        self,
        provider: str,
        integration_type: str | None = None,
    ) -> Integration | None:
        """Retrieve an integration by provider name."""
        for integration in self._integrations.values():
            provider_match = (
                integration.config.get("provider") == provider
                or provider.lower() in integration.name.lower()
            )
            type_match = integration_type is None or integration.type.value == integration_type
            if provider_match and type_match:
                return integration
        return None

    async def create(self, integration: Integration) -> Integration:
        """Persist a new integration."""
        self._integrations[integration.id] = integration
        return integration

    async def update(self, integration: Integration) -> Integration:
        """Update an existing integration."""
        if integration.id not in self._integrations:
            msg = f"Integration {integration.id} not found"
            raise ValueError(msg)
        self._integrations[integration.id] = integration
        return integration

    async def delete(self, integration_id: UUID) -> bool:
        """Delete an integration and its secrets."""
        if integration_id not in self._integrations:
            return False
        del self._integrations[integration_id]
        self._secrets.pop(integration_id, None)
        return True

    async def get_secrets(self, integration_id: UUID) -> dict[str, str] | None:
        """Get secrets for an integration."""
        if integration_id not in self._integrations:
            return None
        return self._secrets.get(integration_id, {})

    async def set_secrets(self, integration_id: UUID, secrets: dict[str, str]) -> None:
        """Store secrets for an integration."""
        self._secrets[integration_id] = secrets

    async def list_by_type(self, integration_type: str) -> Sequence[Integration]:
        """List integrations by type."""
        return [
            i for i in self._integrations.values()
            if i.type.value == integration_type
        ]
@@ -12,6 +12,8 @@ from typing import TYPE_CHECKING, Self
from noteflow.domain.ports.repositories import (
    AnnotationRepository,
    DiarizationJobRepository,
    EntityRepository,
    IntegrationRepository,
    MeetingRepository,
    PreferencesRepository,
    SegmentRepository,

@@ -19,11 +21,13 @@ from noteflow.domain.ports.repositories import (
)

from .repositories import (
    InMemoryIntegrationRepository,
    MemoryMeetingRepository,
    MemorySegmentRepository,
    MemorySummaryRepository,
    UnsupportedAnnotationRepository,
    UnsupportedDiarizationJobRepository,
    UnsupportedEntityRepository,
    UnsupportedPreferencesRepository,
)

@@ -60,6 +64,8 @@ class MemoryUnitOfWork:
        self._annotations = UnsupportedAnnotationRepository()
        self._diarization_jobs = UnsupportedDiarizationJobRepository()
        self._preferences = UnsupportedPreferencesRepository()
        self._entities = UnsupportedEntityRepository()
        self._integrations = InMemoryIntegrationRepository()

    # Core repositories
    @property

@@ -93,6 +99,16 @@ class MemoryUnitOfWork:
        """Get preferences repository (unsupported)."""
        return self._preferences

    @property
    def entities(self) -> EntityRepository:
        """Get entities repository (unsupported)."""
        return self._entities

    @property
    def integrations(self) -> IntegrationRepository:
        """Get integrations repository."""
        return self._integrations

    # Feature flags - limited in memory mode
    @property
    def supports_annotations(self) -> bool:

@@ -109,6 +125,16 @@ class MemoryUnitOfWork:
        """User preferences not supported in memory mode."""
        return False

    @property
    def supports_entities(self) -> bool:
        """Entity extraction not supported in memory mode."""
        return False

    @property
    def supports_integrations(self) -> bool:
        """Integration persistence supported in memory mode."""
        return True

    async def __aenter__(self) -> Self:
        """Enter the unit of work context.
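The `supports_*` flags let callers probe a capability before touching the corresponding repository, rather than catching `NotImplementedError` from an `Unsupported*` stub. A minimal sketch of that pattern (the class and function names are illustrative):

```python
class MemoryBackend:
    # Capability flags mirror the MemoryUnitOfWork properties above.
    @property
    def supports_entities(self) -> bool:
        return False

    @property
    def supports_integrations(self) -> bool:
        return True


def sync_entities(backend: MemoryBackend) -> str:
    # Callers branch on the flag instead of catching NotImplementedError.
    return "synced" if backend.supports_entities else "skipped"


outcome = sync_entities(MemoryBackend())
```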
@@ -7,6 +7,8 @@ from .diarization_job_repo import (
    SqlAlchemyDiarizationJobRepository,
    StreamingTurn,
)
from .entity_repo import SqlAlchemyEntityRepository
from .integration_repo import SqlAlchemyIntegrationRepository
from .meeting_repo import SqlAlchemyMeetingRepository
from .preferences_repo import SqlAlchemyPreferencesRepository
from .segment_repo import SqlAlchemySegmentRepository

@@ -17,6 +19,8 @@ __all__ = [
    "DiarizationJob",
    "SqlAlchemyAnnotationRepository",
    "SqlAlchemyDiarizationJobRepository",
    "SqlAlchemyEntityRepository",
    "SqlAlchemyIntegrationRepository",
    "SqlAlchemyMeetingRepository",
    "SqlAlchemyPreferencesRepository",
    "SqlAlchemySegmentRepository",
@@ -0,0 +1,140 @@
"""SQLAlchemy implementation of entity repository for named entities."""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from uuid import UUID

from sqlalchemy import delete, select

from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.infrastructure.converters.ner_converters import NerConverter
from noteflow.infrastructure.persistence.models import NamedEntityModel
from noteflow.infrastructure.persistence.repositories._base import BaseRepository

if TYPE_CHECKING:
    from noteflow.domain.value_objects import MeetingId


class SqlAlchemyEntityRepository(BaseRepository):
    """SQLAlchemy implementation of entity repository for NER results."""

    async def save(self, entity: NamedEntity) -> NamedEntity:
        """Save or update a named entity.

        Uses merge to handle both insert and update cases.

        Args:
            entity: The entity to save.

        Returns:
            Saved entity with db_id populated.
        """
        kwargs = NerConverter.to_orm_kwargs(entity)
        model = NamedEntityModel(**kwargs)
        merged = await self._session.merge(model)
        await self._session.flush()
        entity.db_id = merged.id
        return entity

    async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
        """Save multiple entities efficiently.

        Uses individual merges to handle upsert semantics with the unique
        constraint on (meeting_id, normalized_text).

        Args:
            entities: List of entities to save.

        Returns:
            Saved entities with db_ids populated.
        """
        for entity in entities:
            kwargs = NerConverter.to_orm_kwargs(entity)
            model = NamedEntityModel(**kwargs)
            merged = await self._session.merge(model)
            entity.db_id = merged.id

        await self._session.flush()
        return entities

    async def get(self, entity_id: UUID) -> NamedEntity | None:
        """Get entity by ID.

        Args:
            entity_id: The entity UUID.

        Returns:
            Entity if found, None otherwise.
        """
        stmt = select(NamedEntityModel).where(NamedEntityModel.id == entity_id)
        model = await self._execute_scalar(stmt)
        return NerConverter.orm_to_domain(model) if model else None

    async def get_by_meeting(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
        """Get all entities for a meeting.

        Args:
            meeting_id: The meeting UUID.

        Returns:
            List of entities ordered by category and text.
        """
        stmt = (
            select(NamedEntityModel)
            .where(NamedEntityModel.meeting_id == UUID(str(meeting_id)))
            .order_by(NamedEntityModel.category, NamedEntityModel.text)
        )
        models = await self._execute_scalars(stmt)
        return [NerConverter.orm_to_domain(m) for m in models]

    async def delete_by_meeting(self, meeting_id: MeetingId) -> int:
        """Delete all entities for a meeting.

        Args:
            meeting_id: The meeting UUID.

        Returns:
            Number of deleted entities.
        """
        stmt = delete(NamedEntityModel).where(
            NamedEntityModel.meeting_id == UUID(str(meeting_id))
        )
        return await self._execute_delete(stmt)

    async def update_pinned(self, entity_id: UUID, is_pinned: bool) -> bool:
        """Update the pinned status of an entity.

        Args:
            entity_id: The entity UUID.
            is_pinned: New pinned status.

        Returns:
            True if entity was found and updated.
        """
        stmt = select(NamedEntityModel).where(NamedEntityModel.id == entity_id)
        model = await self._execute_scalar(stmt)

        if model is None:
            return False

        model.is_pinned = is_pinned
        await self._session.flush()
        return True

    async def exists_for_meeting(self, meeting_id: MeetingId) -> bool:
        """Check if any entities exist for a meeting.

        More efficient than fetching all entities when only checking existence.

        Args:
            meeting_id: The meeting UUID.

        Returns:
            True if at least one entity exists.
        """
        stmt = select(NamedEntityModel).where(
            NamedEntityModel.meeting_id == UUID(str(meeting_id))
        )
        return await self._execute_exists(stmt)
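The repository's `save_batch` leans on the unique constraint on `(meeting_id, normalized_text)`: merging a row with an existing key updates it rather than inserting a duplicate. A plain-Python sketch of that upsert behavior, with a dict standing in for the table (the `Entity` dataclass here is illustrative, not the domain model):

```python
# Illustrative sketch (no SQLAlchemy) of upsert semantics keyed on
# (meeting_id, normalized_text): a repeated key updates the existing
# row instead of creating a second one.

from dataclasses import dataclass


@dataclass
class Entity:
    meeting_id: str
    normalized_text: str
    category: str


def upsert_batch(store: dict, entities: list[Entity]) -> dict:
    for e in entities:
        # The unique constraint behaves like a dict key: last write wins.
        store[(e.meeting_id, e.normalized_text)] = e
    return store


store: dict = {}
upsert_batch(store, [Entity("m1", "acme", "ORG"), Entity("m1", "acme", "COMPANY")])
print(len(store))  # 1: the second save updated the first row
```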
@@ -0,0 +1,215 @@
"""SQLAlchemy repository for Integration entities."""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from uuid import UUID

from sqlalchemy import delete, select

from noteflow.domain.entities.integration import Integration
from noteflow.infrastructure.converters.integration_converters import IntegrationConverter
from noteflow.infrastructure.persistence.models.integrations import (
    IntegrationModel,
    IntegrationSecretModel,
)
from noteflow.infrastructure.persistence.repositories._base import BaseRepository

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession


class SqlAlchemyIntegrationRepository(BaseRepository):
    """SQLAlchemy implementation of IntegrationRepository.

    Manages external service integrations and their encrypted secrets.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize repository with database session.

        Args:
            session: SQLAlchemy async session.
        """
        super().__init__(session)

    async def get(self, integration_id: UUID) -> Integration | None:
        """Retrieve an integration by ID.

        Args:
            integration_id: Integration UUID.

        Returns:
            Integration if found, None otherwise.
        """
        stmt = select(IntegrationModel).where(IntegrationModel.id == integration_id)
        model = await self._execute_scalar(stmt)
        return IntegrationConverter.orm_to_domain(model) if model else None

    async def get_by_provider(
        self,
        provider: str,
        integration_type: str | None = None,
    ) -> Integration | None:
        """Retrieve an integration by provider name.

        Args:
            provider: Provider name (stored in config['provider'] or name).
            integration_type: Optional type filter.

        Returns:
            Integration if found, None otherwise.
        """
        stmt = select(IntegrationModel).where(
            IntegrationModel.config["provider"].astext == provider,
        )
        if integration_type:
            stmt = stmt.where(IntegrationModel.type == integration_type)

        model = await self._execute_scalar(stmt)
        if model:
            return IntegrationConverter.orm_to_domain(model)

        # Fallback: check by name if provider not in config
        fallback_stmt = select(IntegrationModel).where(
            IntegrationModel.name.ilike(f"%{provider}%"),
        )
        if integration_type:
            fallback_stmt = fallback_stmt.where(IntegrationModel.type == integration_type)

        fallback_model = await self._execute_scalar(fallback_stmt)
        return IntegrationConverter.orm_to_domain(fallback_model) if fallback_model else None

    async def create(self, integration: Integration) -> Integration:
        """Persist a new integration.

        Args:
            integration: Integration to create.

        Returns:
            Created integration.
        """
        kwargs = IntegrationConverter.to_orm_kwargs(integration)
        model = IntegrationModel(**kwargs)
        await self._add_and_flush(model)
        return IntegrationConverter.orm_to_domain(model)

    async def update(self, integration: Integration) -> Integration:
        """Update an existing integration.

        Args:
            integration: Integration with updated fields.

        Returns:
            Updated integration.

        Raises:
            ValueError: If integration does not exist.
        """
        stmt = select(IntegrationModel).where(IntegrationModel.id == integration.id)
        model = await self._execute_scalar(stmt)
        if not model:
            msg = f"Integration {integration.id} not found"
            raise ValueError(msg)

        # Update fields
        model.name = integration.name
        model.type = integration.type.value
        model.status = integration.status.value
        model.config = integration.config
        model.last_sync = integration.last_sync
        model.error_message = integration.error_message
        model.updated_at = integration.updated_at

        await self._session.flush()
        return IntegrationConverter.orm_to_domain(model)

    async def delete(self, integration_id: UUID) -> bool:
        """Delete an integration and its secrets.

        Args:
            integration_id: Integration UUID.

        Returns:
            True if deleted, False if not found.
        """
        stmt = select(IntegrationModel).where(IntegrationModel.id == integration_id)
        model = await self._execute_scalar(stmt)
        if not model:
            return False

        await self._delete_and_flush(model)
        return True

    async def get_secrets(self, integration_id: UUID) -> dict[str, str] | None:
        """Get secrets for an integration.

        Note: Secrets are stored as bytes. The caller is responsible for
        encryption/decryption. This method decodes bytes to UTF-8 strings.

        Args:
            integration_id: Integration UUID.

        Returns:
            Dictionary of secret key-value pairs, or None if integration not found.
        """
        # Verify integration exists
        exists_stmt = select(IntegrationModel.id).where(
            IntegrationModel.id == integration_id,
        )
        if not await self._execute_exists(exists_stmt):
            return None

        # Fetch secrets
        stmt = select(IntegrationSecretModel).where(
            IntegrationSecretModel.integration_id == integration_id,
        )
        models = await self._execute_scalars(stmt)

        return {model.secret_key: model.secret_value.decode("utf-8") for model in models}

    async def set_secrets(self, integration_id: UUID, secrets: dict[str, str]) -> None:
        """Store secrets for an integration.

        Note: Secrets are stored as bytes. The caller is responsible for
        encryption/decryption. This method encodes strings to UTF-8 bytes.

        Args:
            integration_id: Integration UUID.
            secrets: Dictionary of secret key-value pairs.
        """
        # Delete existing secrets
        delete_stmt = delete(IntegrationSecretModel).where(
            IntegrationSecretModel.integration_id == integration_id,
        )
        await self._session.execute(delete_stmt)

        # Insert new secrets
        if secrets:
            models = [
                IntegrationSecretModel(
                    integration_id=integration_id,
                    secret_key=key,
                    secret_value=value.encode("utf-8"),
                )
                for key, value in secrets.items()
            ]
            await self._add_all_and_flush(models)

    async def list_by_type(self, integration_type: str) -> Sequence[Integration]:
        """List integrations by type.

        Args:
            integration_type: Integration type (e.g., 'calendar', 'email').

        Returns:
            List of integrations of the specified type.
        """
        stmt = (
            select(IntegrationModel)
            .where(IntegrationModel.type == integration_type)
            .order_by(IntegrationModel.created_at.desc())
        )
        models = await self._execute_scalars(stmt)
        return [IntegrationConverter.orm_to_domain(m) for m in models]
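`set_secrets` and `get_secrets` store values as UTF-8 bytes and decode them on read, leaving encryption to the caller. A minimal sketch of just that encoding round-trip (the function names here are illustrative helpers, not the repository API):

```python
# Sketch of the secrets round-trip: values go into storage as UTF-8 bytes
# and come back out as strings. Encryption would wrap these steps but is
# the caller's responsibility in the repository design above.

def encode_secrets(secrets: dict[str, str]) -> dict[str, bytes]:
    return {k: v.encode("utf-8") for k, v in secrets.items()}


def decode_secrets(stored: dict[str, bytes]) -> dict[str, str]:
    return {k: v.decode("utf-8") for k, v in stored.items()}


stored = encode_secrets({"access_token": "tok-123", "refresh_token": "ref-456"})
print(decode_secrets(stored)["access_token"])  # tok-123
```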
@@ -207,8 +207,8 @@ class SqlAlchemySegmentRepository(BaseRepository):
        model.speaker_confidence = speaker_confidence
        await self._session.flush()

-    async def get_next_segment_id(self, meeting_id: MeetingId) -> int:
-        """Get the next segment_id for a meeting.
+    async def compute_next_segment_id(self, meeting_id: MeetingId) -> int:
+        """Compute the next segment_id for a meeting.

        Args:
            meeting_id: Meeting identifier.
@@ -16,6 +16,8 @@ from noteflow.infrastructure.persistence.database import (
from .repositories import (
    SqlAlchemyAnnotationRepository,
    SqlAlchemyDiarizationJobRepository,
+    SqlAlchemyEntityRepository,
+    SqlAlchemyIntegrationRepository,
    SqlAlchemyMeetingRepository,
    SqlAlchemyPreferencesRepository,
    SqlAlchemySegmentRepository,
@@ -46,6 +48,8 @@ class SqlAlchemyUnitOfWork:
        self._session: AsyncSession | None = None
        self._annotations_repo: SqlAlchemyAnnotationRepository | None = None
        self._diarization_jobs_repo: SqlAlchemyDiarizationJobRepository | None = None
+        self._entities_repo: SqlAlchemyEntityRepository | None = None
+        self._integrations_repo: SqlAlchemyIntegrationRepository | None = None
        self._meetings_repo: SqlAlchemyMeetingRepository | None = None
        self._preferences_repo: SqlAlchemyPreferencesRepository | None = None
        self._segments_repo: SqlAlchemySegmentRepository | None = None
@@ -98,6 +102,20 @@ class SqlAlchemyUnitOfWork:
            raise RuntimeError("UnitOfWork not in context")
        return self._diarization_jobs_repo

+    @property
+    def entities(self) -> SqlAlchemyEntityRepository:
+        """Get entities repository for NER results."""
+        if self._entities_repo is None:
+            raise RuntimeError("UnitOfWork not in context")
+        return self._entities_repo
+
+    @property
+    def integrations(self) -> SqlAlchemyIntegrationRepository:
+        """Get integrations repository for OAuth connections."""
+        if self._integrations_repo is None:
+            raise RuntimeError("UnitOfWork not in context")
+        return self._integrations_repo
+
    @property
    def meetings(self) -> SqlAlchemyMeetingRepository:
        """Get meetings repository."""
@@ -142,6 +160,16 @@ class SqlAlchemyUnitOfWork:
        """User preferences persistence is fully supported with database."""
        return True

+    @property
+    def supports_entities(self) -> bool:
+        """NER entity persistence is fully supported with database."""
+        return True
+
+    @property
+    def supports_integrations(self) -> bool:
+        """OAuth integration persistence is fully supported with database."""
+        return True
+
    async def __aenter__(self) -> Self:
        """Enter the unit of work context.

@@ -153,6 +181,8 @@ class SqlAlchemyUnitOfWork:
        self._session = self._session_factory()
        self._annotations_repo = SqlAlchemyAnnotationRepository(self._session)
        self._diarization_jobs_repo = SqlAlchemyDiarizationJobRepository(self._session)
+        self._entities_repo = SqlAlchemyEntityRepository(self._session)
+        self._integrations_repo = SqlAlchemyIntegrationRepository(self._session)
        self._meetings_repo = SqlAlchemyMeetingRepository(self._session)
        self._preferences_repo = SqlAlchemyPreferencesRepository(self._session)
        self._segments_repo = SqlAlchemySegmentRepository(self._session)
@@ -184,6 +214,8 @@ class SqlAlchemyUnitOfWork:
        self._session = None
        self._annotations_repo = None
        self._diarization_jobs_repo = None
+        self._entities_repo = None
+        self._integrations_repo = None
        self._meetings_repo = None
        self._preferences_repo = None
        self._segments_repo = None
@@ -5,7 +5,7 @@ from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, TypedDict, cast

from noteflow.domain.entities import ActionItem, KeyPoint, Summary
from noteflow.domain.summarization import InvalidResponseError
@@ -18,6 +18,30 @@ if TYPE_CHECKING:

_logger = logging.getLogger(__name__)


+class _KeyPointData(TypedDict, total=False):
+    """Expected structure for key point data from LLM response."""
+
+    text: str
+    segment_ids: list[int]
+
+
+class _ActionItemData(TypedDict, total=False):
+    """Expected structure for action item data from LLM response."""
+
+    text: str
+    assignee: str
+    priority: int
+    segment_ids: list[int]
+
+
+class _LLMResponseData(TypedDict, total=False):
+    """Expected structure for LLM response JSON."""
+
+    executive_summary: str
+    key_points: list[_KeyPointData]
+    action_items: list[_ActionItemData]
+
+
_TONE_INSTRUCTIONS: dict[str, str] = {
    "professional": "Use formal, business-appropriate language.",
    "casual": "Use conversational, approachable language.",
@@ -153,7 +177,7 @@ def _strip_markdown_fences(text: str) -> str:


def _parse_key_point(
-    data: dict[str, Any],
+    data: _KeyPointData,
    valid_ids: set[int],
    segments: Sequence[Segment],
) -> KeyPoint:
@@ -181,7 +205,7 @@ def _parse_key_point(
    )


-def _parse_action_item(data: dict[str, Any], valid_ids: set[int]) -> ActionItem:
+def _parse_action_item(data: _ActionItemData, valid_ids: set[int]) -> ActionItem:
    """Parse a single action item from LLM response data.

    Args:
@@ -219,7 +243,7 @@ def parse_llm_response(response_text: str, request: SummarizationRequest) -> Sum
    text = _strip_markdown_fences(response_text)

    try:
-        data = json.loads(text)
+        data = cast(_LLMResponseData, json.loads(text))
    except json.JSONDecodeError as e:
        raise InvalidResponseError(f"Invalid JSON response: {e}") from e
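The `TypedDict` + `cast` change above gives the type checker a shape for the parsed LLM JSON without adding runtime validation; `cast` is a no-op at runtime, so missing keys still need explicit checks. A self-contained sketch of the pattern (the `LLMResponse` class and sample payload here are illustrative):

```python
# Sketch of the TypedDict + cast pattern from the diff: cast() only informs
# the type checker; json.loads still returns a plain dict at runtime.

import json
from typing import TypedDict, cast


class LLMResponse(TypedDict, total=False):
    executive_summary: str
    key_points: list[dict]


raw = '{"executive_summary": "Team agreed on Q3 roadmap.", "key_points": []}'
data = cast(LLMResponse, json.loads(raw))
print(data["executive_summary"])  # Team agreed on Q3 roadmap.
```

Because `total=False`, every key is optional to the checker, which matches the defensive per-key parsing the module does afterwards.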
@@ -88,7 +88,7 @@ class CalendarProvider:
        return event_start <= window_end and event_end >= window_start


-def parse_calendar_events(raw_events: object) -> list[CalendarEvent]:
+def parse_calendar_event_config(raw_events: object) -> list[CalendarEvent]:
    """Parse calendar events from config/env payloads."""
    if raw_events is None:
        return []
@@ -108,8 +108,8 @@ def parse_calendar_events(raw_events: object) -> list[CalendarEvent]:
            events.append(item)
            continue
        if isinstance(item, dict):
-            start = _parse_datetime(item.get("start"))
-            end = _parse_datetime(item.get("end"))
+            start = _parse_event_datetime(item.get("start"))
+            end = _parse_event_datetime(item.get("end"))
            if start and end:
                events.append(CalendarEvent(start=start, end=end, title=item.get("title")))
    return events
@@ -126,7 +126,7 @@ def _load_events_from_json(raw: str) -> list[dict[str, object]]:
    return [parsed] if isinstance(parsed, dict) else []


-def _parse_datetime(value: object) -> datetime | None:
+def _parse_event_datetime(value: object) -> datetime | None:
    if isinstance(value, datetime):
        return value
    if not isinstance(value, str) or not value:
456
tests/application/test_calendar_service.py
Normal file
@@ -0,0 +1,456 @@
"""Tests for calendar service."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

from noteflow.domain.entities import Integration, IntegrationStatus, IntegrationType
from noteflow.domain.ports.calendar import CalendarEventInfo, OAuthConnectionInfo
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens

if TYPE_CHECKING:
    from noteflow.config.settings import CalendarSettings


@pytest.fixture
def calendar_settings() -> CalendarSettings:
    """Create test calendar settings."""
    from noteflow.config.settings import CalendarSettings

    return CalendarSettings(
        google_client_id="test-google-client-id",
        google_client_secret="test-google-client-secret",
        outlook_client_id="test-outlook-client-id",
        outlook_client_secret="test-outlook-client-secret",
    )


@pytest.fixture
def mock_oauth_manager() -> MagicMock:
    """Create mock OAuth manager."""
    manager = MagicMock()
    manager.initiate_auth.return_value = ("https://auth.example.com", "state-123")
    manager.complete_auth = AsyncMock(
        return_value=OAuthTokens(
            access_token="access-token",
            refresh_token="refresh-token",
            token_type="Bearer",
            expires_at=datetime.now(UTC) + timedelta(hours=1),
            scope="email calendar",
        )
    )
    manager.refresh_tokens = AsyncMock(
        return_value=OAuthTokens(
            access_token="new-access-token",
            refresh_token="refresh-token",
            token_type="Bearer",
            expires_at=datetime.now(UTC) + timedelta(hours=1),
            scope="email calendar",
        )
    )
    manager.revoke_tokens = AsyncMock()
    return manager


@pytest.fixture
def mock_google_adapter() -> MagicMock:
    """Create mock Google Calendar adapter."""
    adapter = MagicMock()
    adapter.list_events = AsyncMock(
        return_value=[
            CalendarEventInfo(
                id="event-1",
                title="Test Meeting",
                start_time=datetime.now(UTC) + timedelta(hours=1),
                end_time=datetime.now(UTC) + timedelta(hours=2),
                attendees=("alice@example.com",),
                provider="google",
            )
        ]
    )
    adapter.get_user_email = AsyncMock(return_value="user@gmail.com")
    return adapter


@pytest.fixture
def mock_outlook_adapter() -> MagicMock:
    """Create mock Outlook Calendar adapter."""
    adapter = MagicMock()
    adapter.list_events = AsyncMock(return_value=[])
    adapter.get_user_email = AsyncMock(return_value="user@outlook.com")
    return adapter


@pytest.fixture
def mock_uow() -> MagicMock:
    """Create mock unit of work."""
    uow = MagicMock()
    uow.__aenter__ = AsyncMock(return_value=uow)
    uow.__aexit__ = AsyncMock(return_value=None)
    uow.integrations = MagicMock()
    uow.integrations.get_by_type_and_provider = AsyncMock(return_value=None)
    uow.integrations.add = AsyncMock()
    uow.integrations.get_secrets = AsyncMock(return_value=None)
    uow.integrations.set_secrets = AsyncMock()
    uow.commit = AsyncMock()
    return uow


class TestCalendarServiceInitiateOAuth:
    """Tests for CalendarService.initiate_oauth."""

    @pytest.mark.asyncio
    async def test_initiate_oauth_returns_auth_url_and_state(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """initiate_oauth should return auth URL and state."""
        from noteflow.application.services.calendar_service import CalendarService

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        auth_url, state = await service.initiate_oauth("google", "http://localhost/callback")

        assert auth_url == "https://auth.example.com"
        assert state == "state-123"
        mock_oauth_manager.initiate_auth.assert_called_once()


class TestCalendarServiceCompleteOAuth:
    """Tests for CalendarService.complete_oauth."""

    @pytest.mark.asyncio
    async def test_complete_oauth_stores_tokens(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """complete_oauth should store tokens in integration secrets."""
        from noteflow.application.services.calendar_service import CalendarService

        mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
        mock_uow.integrations.create = AsyncMock()
        mock_uow.integrations.update = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        result = await service.complete_oauth("google", "auth-code", "state-123")

        assert result is True
        mock_oauth_manager.complete_auth.assert_called_once()
        mock_uow.integrations.set_secrets.assert_called_once()
        mock_uow.commit.assert_called()

    @pytest.mark.asyncio
    async def test_complete_oauth_creates_integration_if_not_exists(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """complete_oauth should create new integration if none exists."""
        from noteflow.application.services.calendar_service import CalendarService

        mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
        mock_uow.integrations.create = AsyncMock()
        mock_uow.integrations.update = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        await service.complete_oauth("google", "auth-code", "state-123")

        mock_uow.integrations.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_complete_oauth_updates_existing_integration(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """complete_oauth should update existing integration."""
        from noteflow.application.services.calendar_service import CalendarService

        existing_integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Calendar",
            integration_type=IntegrationType.CALENDAR,
            config={"provider": "google"},
        )
        mock_uow.integrations.get_by_provider = AsyncMock(return_value=existing_integration)
        mock_uow.integrations.create = AsyncMock()
        mock_uow.integrations.update = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        await service.complete_oauth("google", "auth-code", "state-123")

        mock_uow.integrations.create.assert_not_called()
        assert existing_integration.status == IntegrationStatus.CONNECTED


class TestCalendarServiceGetConnectionStatus:
    """Tests for CalendarService.get_connection_status."""

    @pytest.mark.asyncio
    async def test_get_connection_status_returns_connected_info(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """get_connection_status should return connection info for connected provider."""
        from noteflow.application.services.calendar_service import CalendarService

        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Calendar",
            integration_type=IntegrationType.CALENDAR,
            config={"provider": "google"},
        )
        integration.connect(provider_email="user@gmail.com")
        mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
        mock_uow.integrations.get_secrets = AsyncMock(return_value={
            "access_token": "token",
            "refresh_token": "refresh",
            "token_type": "Bearer",
            "expires_at": (datetime.now(UTC) + timedelta(hours=1)).isoformat(),
            "scope": "calendar",
        })

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        status = await service.get_connection_status("google")

        assert status.status == "connected"
        assert status.provider == "google"

    @pytest.mark.asyncio
    async def test_get_connection_status_returns_disconnected_when_no_integration(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """get_connection_status should return disconnected when no integration."""
        from noteflow.application.services.calendar_service import CalendarService

        mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        status = await service.get_connection_status("google")

        assert status.status == "disconnected"


class TestCalendarServiceDisconnect:
    """Tests for CalendarService.disconnect."""

    @pytest.mark.asyncio
    async def test_disconnect_revokes_tokens_and_deletes_integration(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """disconnect should revoke tokens and delete integration."""
        from noteflow.application.services.calendar_service import CalendarService

        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Calendar",
            integration_type=IntegrationType.CALENDAR,
            config={"provider": "google"},
        )
        integration.connect(provider_email="user@gmail.com")
        mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
        mock_uow.integrations.get_secrets = AsyncMock(return_value={"access_token": "token"})
        mock_uow.integrations.delete = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        result = await service.disconnect("google")

        assert result is True
        mock_oauth_manager.revoke_tokens.assert_called_once()
        mock_uow.integrations.delete.assert_called_once()
        mock_uow.commit.assert_called()


class TestCalendarServiceListEvents:
    """Tests for CalendarService.list_calendar_events."""

    @pytest.mark.asyncio
    async def test_list_events_fetches_from_connected_provider(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """list_calendar_events should fetch events from connected provider."""
        from noteflow.application.services.calendar_service import CalendarService

        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Calendar",
            integration_type=IntegrationType.CALENDAR,
            config={"provider": "google"},
        )
        integration.connect(provider_email="user@gmail.com")
        mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
        mock_uow.integrations.get_secrets = AsyncMock(return_value={
            "access_token": "token",
            "refresh_token": "refresh",
            "token_type": "Bearer",
            "expires_at": (datetime.now(UTC) + timedelta(hours=1)).isoformat(),
            "scope": "calendar",
        })
        mock_uow.integrations.update = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        events = await service.list_calendar_events(provider="google")

        assert len(events) == 1
        assert events[0].title == "Test Meeting"
        mock_google_adapter.list_events.assert_called_once()

    @pytest.mark.asyncio
    async def test_list_events_refreshes_expired_token(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """list_calendar_events should refresh expired token before fetching."""
        from noteflow.application.services.calendar_service import CalendarService

        integration = Integration.create(
|
||||
assert existing_integration.status == IntegrationStatus.CONNECTED
|
||||
|
||||
|
||||
class TestCalendarServiceGetConnectionStatus:
|
||||
"""Tests for CalendarService.get_connection_status."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_status_returns_connected_info(
|
||||
self,
|
||||
calendar_settings: CalendarSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_google_adapter: MagicMock,
|
||||
mock_outlook_adapter: MagicMock,
|
||||
mock_uow: MagicMock,
|
||||
) -> None:
|
||||
"""get_connection_status should return connection info for connected provider."""
|
||||
from noteflow.application.services.calendar_service import CalendarService
|
||||
|
||||
integration = Integration.create(
|
||||
workspace_id=uuid4(),
|
||||
name="Google Calendar",
|
||||
integration_type=IntegrationType.CALENDAR,
|
||||
config={"provider": "google"},
|
||||
)
|
||||
integration.connect(provider_email="user@gmail.com")
|
||||
mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
|
||||
mock_uow.integrations.get_secrets = AsyncMock(return_value={
|
||||
"access_token": "token",
|
||||
"refresh_token": "refresh",
|
||||
"token_type": "Bearer",
|
||||
"expires_at": (datetime.now(UTC) + timedelta(hours=1)).isoformat(),
|
||||
"scope": "calendar",
|
||||
})
|
||||
|
||||
service = CalendarService(
|
||||
uow_factory=lambda: mock_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
google_adapter=mock_google_adapter,
|
||||
outlook_adapter=mock_outlook_adapter,
|
||||
)
|
||||
|
||||
status = await service.get_connection_status("google")
|
||||
|
||||
assert status.status == "connected"
|
||||
assert status.provider == "google"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_status_returns_disconnected_when_no_integration(
|
||||
self,
|
||||
calendar_settings: CalendarSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_google_adapter: MagicMock,
|
||||
mock_outlook_adapter: MagicMock,
|
||||
mock_uow: MagicMock,
|
||||
) -> None:
|
||||
"""get_connection_status should return disconnected when no integration."""
|
||||
from noteflow.application.services.calendar_service import CalendarService
|
||||
|
||||
mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
|
||||
|
||||
service = CalendarService(
|
||||
uow_factory=lambda: mock_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
google_adapter=mock_google_adapter,
|
||||
            outlook_adapter=mock_outlook_adapter,
        )

        status = await service.get_connection_status("google")

        assert status.status == "disconnected"


class TestCalendarServiceDisconnect:
    """Tests for CalendarService.disconnect."""

    @pytest.mark.asyncio
    async def test_disconnect_revokes_tokens_and_deletes_integration(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """disconnect should revoke tokens and delete integration."""
        from noteflow.application.services.calendar_service import CalendarService

        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Calendar",
            integration_type=IntegrationType.CALENDAR,
            config={"provider": "google"},
        )
        integration.connect(provider_email="user@gmail.com")
        mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
        mock_uow.integrations.get_secrets = AsyncMock(return_value={"access_token": "token"})
        mock_uow.integrations.delete = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        result = await service.disconnect("google")

        assert result is True
        mock_oauth_manager.revoke_tokens.assert_called_once()
        mock_uow.integrations.delete.assert_called_once()
        mock_uow.commit.assert_called()


class TestCalendarServiceListEvents:
    """Tests for CalendarService.list_calendar_events."""

    @pytest.mark.asyncio
    async def test_list_events_fetches_from_connected_provider(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """list_calendar_events should fetch events from connected provider."""
        from noteflow.application.services.calendar_service import CalendarService

        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Calendar",
            integration_type=IntegrationType.CALENDAR,
            config={"provider": "google"},
        )
        integration.connect(provider_email="user@gmail.com")
        mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
        mock_uow.integrations.get_secrets = AsyncMock(return_value={
            "access_token": "token",
            "refresh_token": "refresh",
            "token_type": "Bearer",
            "expires_at": (datetime.now(UTC) + timedelta(hours=1)).isoformat(),
            "scope": "calendar",
        })
        mock_uow.integrations.update = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        events = await service.list_calendar_events(provider="google")

        assert len(events) == 1
        assert events[0].title == "Test Meeting"
        mock_google_adapter.list_events.assert_called_once()

    @pytest.mark.asyncio
    async def test_list_events_refreshes_expired_token(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """list_calendar_events should refresh expired token before fetching."""
        from noteflow.application.services.calendar_service import CalendarService

        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Calendar",
            integration_type=IntegrationType.CALENDAR,
            config={"provider": "google"},
        )
        integration.connect(provider_email="user@gmail.com")
        mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
        mock_uow.integrations.get_secrets = AsyncMock(return_value={
            "access_token": "expired-token",
            "refresh_token": "refresh-token",
            "token_type": "Bearer",
            "expires_at": (datetime.now(UTC) - timedelta(hours=1)).isoformat(),
            "scope": "calendar",
        })
        mock_uow.integrations.update = AsyncMock()

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        await service.list_calendar_events(provider="google")

        mock_oauth_manager.refresh_tokens.assert_called_once()
        mock_uow.integrations.set_secrets.assert_called()

    @pytest.mark.asyncio
    async def test_list_events_raises_when_not_connected(
        self,
        calendar_settings: CalendarSettings,
        mock_oauth_manager: MagicMock,
        mock_google_adapter: MagicMock,
        mock_outlook_adapter: MagicMock,
        mock_uow: MagicMock,
    ) -> None:
        """list_calendar_events should raise error when provider not connected."""
        from noteflow.application.services.calendar_service import CalendarService, CalendarServiceError

        mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)

        service = CalendarService(
            uow_factory=lambda: mock_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
            google_adapter=mock_google_adapter,
            outlook_adapter=mock_outlook_adapter,
        )

        with pytest.raises(CalendarServiceError, match="not connected"):
            await service.list_calendar_events(provider="google")
@@ -83,5 +83,6 @@ def test_get_supported_formats_returns_names_and_extensions() -> None:

    formats = {name.lower(): ext for name, ext in service.get_supported_formats()}

-    assert formats["markdown"] == ".md"
-    assert formats["html"] == ".html"
+    assert formats["markdown"] == ".md", "Markdown format should have .md extension"
+    assert formats["html"] == ".html", "HTML format should have .html extension"
+    assert formats["pdf"] == ".pdf", "PDF format should have .pdf extension"
364
tests/application/test_ner_service.py
Normal file
@@ -0,0 +1,364 @@
"""Tests for NER application service."""

from __future__ import annotations

from collections.abc import Sequence
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4

import pytest

from noteflow.application.services.ner_service import ExtractionResult, NerService
from noteflow.domain.entities import Meeting, Segment
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId


class MockNerEngine:
    """Mock NER engine for testing."""

    def __init__(self, entities: list[NamedEntity] | None = None) -> None:
        self._entities = entities or []
        self._ready = False
        self.extract_call_count = 0

    def extract(self, text: str) -> list[NamedEntity]:
        """Extract entities from text (mock)."""
        self._ready = True
        self.extract_call_count += 1
        return self._entities

    def extract_from_segments(
        self, segments: list[tuple[int, str]]
    ) -> list[NamedEntity]:
        """Extract entities from segments (mock)."""
        self._ready = True
        self.extract_call_count += 1
        return self._entities

    def is_ready(self) -> bool:
        """Check if engine is ready."""
        return self._ready


def _create_entity(
    text: str,
    category: EntityCategory = EntityCategory.PERSON,
    segment_ids: list[int] | None = None,
) -> NamedEntity:
    """Create a test entity."""
    return NamedEntity(
        text=text,
        normalized_text=text.lower(),
        category=category,
        segment_ids=segment_ids or [0],
        confidence=0.9,
    )


def _create_segment(segment_id: int, text: str) -> Segment:
    """Create a test segment."""
    return Segment(
        segment_id=segment_id,
        text=text,
        start_time=segment_id * 5.0,
        end_time=(segment_id + 1) * 5.0,
    )


class TestNerServiceExtraction:
    """Tests for entity extraction."""

    @pytest.fixture
    def mock_uow_factory(self, mock_uow: MagicMock) -> MagicMock:
        """Create mock UoW factory."""
        mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])
        mock_uow.entities.save_batch = AsyncMock(return_value=[])
        mock_uow.entities.delete_by_meeting = AsyncMock(return_value=0)
        mock_uow.segments.get_by_meeting = AsyncMock(return_value=[])
        return MagicMock(return_value=mock_uow)

    @pytest.fixture
    def sample_meeting(self) -> Meeting:
        """Create a sample meeting with segments."""
        meeting = Meeting.create(title="Test Meeting")
        meeting.segments.extend([
            _create_segment(0, "John talked to Mary."),
            _create_segment(1, "They discussed the project."),
        ])
        return meeting

    @pytest.fixture
    def mock_ner_engine(self) -> MockNerEngine:
        """Create mock NER engine with sample entities."""
        return MockNerEngine(
            entities=[
                _create_entity("John", EntityCategory.PERSON, [0]),
                _create_entity("Mary", EntityCategory.PERSON, [0]),
            ]
        )

    async def test_extract_entities_returns_result(
        self,
        mock_uow: MagicMock,
        mock_uow_factory: MagicMock,
        mock_ner_engine: MockNerEngine,
        sample_meeting: Meeting,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Extract entities returns ExtractionResult."""
        mock_uow.meetings.get = AsyncMock(return_value=sample_meeting)
        mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])
        mock_uow.entities.save_batch = AsyncMock(return_value=[])
        mock_uow.segments.get_by_meeting = AsyncMock(return_value=sample_meeting.segments)

        # Mock feature flag
        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
        )

        service = NerService(mock_ner_engine, mock_uow_factory)
        result = await service.extract_entities(sample_meeting.id)

        assert isinstance(result, ExtractionResult), "Should return ExtractionResult"
        assert result.total_count == 2, "Should have 2 entities"
        assert not result.cached, "Should not be cached on first extraction"
        assert len(result.entities) == 2, "Entities list should match total_count"

    async def test_extract_entities_uses_cache(
        self,
        mock_uow: MagicMock,
        mock_uow_factory: MagicMock,
        mock_ner_engine: MockNerEngine,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Returns cached entities without re-extraction."""
        cached_entities = [_create_entity("Cached", EntityCategory.PERSON)]
        mock_uow.entities.get_by_meeting = AsyncMock(return_value=cached_entities)

        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
        )

        service = NerService(mock_ner_engine, mock_uow_factory)
        result = await service.extract_entities(MeetingId(uuid4()))

        assert result.cached
        assert result.total_count == 1
        assert mock_ner_engine.extract_call_count == 0

    async def test_extract_entities_force_refresh_bypasses_cache(
        self,
        mock_uow: MagicMock,
        mock_uow_factory: MagicMock,
        mock_ner_engine: MockNerEngine,
        sample_meeting: Meeting,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Force refresh re-extracts even with cached data."""
        cached_entities = [_create_entity("Cached", EntityCategory.PERSON)]
        mock_uow.meetings.get = AsyncMock(return_value=sample_meeting)
        mock_uow.entities.get_by_meeting = AsyncMock(return_value=cached_entities)
        mock_uow.entities.save_batch = AsyncMock(return_value=[])
        mock_uow.entities.delete_by_meeting = AsyncMock(return_value=1)
        mock_uow.segments.get_by_meeting = AsyncMock(return_value=sample_meeting.segments)

        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
        )

        # Mark engine as ready to skip warmup extraction
        mock_ner_engine._ready = True
        initial_count = mock_ner_engine.extract_call_count

        service = NerService(mock_ner_engine, mock_uow_factory)
        result = await service.extract_entities(sample_meeting.id, force_refresh=True)

        assert not result.cached
        assert mock_ner_engine.extract_call_count == initial_count + 1
        mock_uow.entities.delete_by_meeting.assert_called_once()

    async def test_extract_entities_meeting_not_found(
        self,
        mock_uow: MagicMock,
        mock_uow_factory: MagicMock,
        mock_ner_engine: MockNerEngine,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Raises ValueError when meeting not found."""
        mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])
        mock_uow.meetings.get = AsyncMock(return_value=None)

        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
        )

        service = NerService(mock_ner_engine, mock_uow_factory)
        meeting_id = MeetingId(uuid4())

        with pytest.raises(ValueError, match=str(meeting_id)):
            await service.extract_entities(meeting_id)

    async def test_extract_entities_feature_disabled(
        self,
        mock_uow_factory: MagicMock,
        mock_ner_engine: MockNerEngine,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Raises RuntimeError when feature flag is disabled."""
        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=False)),
        )

        service = NerService(mock_ner_engine, mock_uow_factory)

        with pytest.raises(RuntimeError, match="disabled"):
            await service.extract_entities(MeetingId(uuid4()))

    async def test_extract_entities_no_segments_returns_empty(
        self,
        mock_uow: MagicMock,
        mock_uow_factory: MagicMock,
        mock_ner_engine: MockNerEngine,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Returns empty result for meeting with no segments."""
        meeting = Meeting.create(title="Empty Meeting")
        mock_uow.meetings.get = AsyncMock(return_value=meeting)
        mock_uow.entities.get_by_meeting = AsyncMock(return_value=[])

        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
        )

        service = NerService(mock_ner_engine, mock_uow_factory)
        result = await service.extract_entities(meeting.id)

        assert result.total_count == 0
        assert result.entities == []
        assert not result.cached


class TestNerServicePinning:
    """Tests for entity pinning."""

    async def test_pin_entity_success(self, mock_uow: MagicMock) -> None:
        """Pin entity updates database and returns True."""
        mock_uow.entities.update_pinned = AsyncMock(return_value=True)
        mock_uow_factory = MagicMock(return_value=mock_uow)

        service = NerService(MockNerEngine(), mock_uow_factory)
        entity_id = uuid4()

        result = await service.pin_entity(entity_id, is_pinned=True)

        assert result is True
        mock_uow.entities.update_pinned.assert_called_once_with(entity_id, True)
        mock_uow.commit.assert_called_once()

    async def test_pin_entity_not_found(self, mock_uow: MagicMock) -> None:
        """Pin entity returns False when entity not found."""
        mock_uow.entities.update_pinned = AsyncMock(return_value=False)
        mock_uow_factory = MagicMock(return_value=mock_uow)

        service = NerService(MockNerEngine(), mock_uow_factory)
        entity_id = uuid4()

        result = await service.pin_entity(entity_id)

        assert result is False
        mock_uow.commit.assert_not_called()


class TestNerServiceClear:
    """Tests for clearing entities."""

    async def test_clear_entities(self, mock_uow: MagicMock) -> None:
        """Clear entities deletes all for meeting."""
        mock_uow.entities.delete_by_meeting = AsyncMock(return_value=5)
        mock_uow_factory = MagicMock(return_value=mock_uow)

        service = NerService(MockNerEngine(), mock_uow_factory)
        meeting_id = MeetingId(uuid4())

        count = await service.clear_entities(meeting_id)

        assert count == 5
        mock_uow.entities.delete_by_meeting.assert_called_once_with(meeting_id)
        mock_uow.commit.assert_called_once()


class TestNerServiceHelpers:
    """Tests for helper methods."""

    async def test_get_entities(self, mock_uow: MagicMock) -> None:
        """Get entities returns cached entities without extraction."""
        entities = [_create_entity("Test")]
        mock_uow.entities.get_by_meeting = AsyncMock(return_value=entities)
        mock_uow_factory = MagicMock(return_value=mock_uow)

        service = NerService(MockNerEngine(), mock_uow_factory)
        meeting_id = MeetingId(uuid4())

        result = await service.get_entities(meeting_id)

        assert result == entities
        mock_uow.entities.get_by_meeting.assert_called_once_with(meeting_id)

    async def test_has_entities_true(self, mock_uow: MagicMock) -> None:
        """Has entities returns True when entities exist."""
        mock_uow.entities.exists_for_meeting = AsyncMock(return_value=True)
        mock_uow_factory = MagicMock(return_value=mock_uow)

        service = NerService(MockNerEngine(), mock_uow_factory)
        meeting_id = MeetingId(uuid4())

        result = await service.has_entities(meeting_id)

        assert result is True

    async def test_has_entities_false(self, mock_uow: MagicMock) -> None:
        """Has entities returns False when no entities exist."""
        mock_uow.entities.exists_for_meeting = AsyncMock(return_value=False)
        mock_uow_factory = MagicMock(return_value=mock_uow)

        service = NerService(MockNerEngine(), mock_uow_factory)
        meeting_id = MeetingId(uuid4())

        result = await service.has_entities(meeting_id)

        assert result is False

    def test_is_engine_ready_checks_both_flag_and_engine(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Is engine ready checks both feature flag and engine state."""
        engine = MockNerEngine()
        service = NerService(engine, MagicMock())

        # NER disabled - should be False regardless of engine state
        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=False)),
        )
        engine._ready = True
        assert not service.is_engine_ready(), "Should be False when NER disabled"

        # NER enabled but engine not ready
        monkeypatch.setattr(
            "noteflow.application.services.ner_service.get_settings",
            lambda: MagicMock(feature_flags=MagicMock(ner_enabled=True)),
        )
        engine._ready = False
        assert not service.is_engine_ready(), "Should be False when engine not ready"

        # NER enabled and engine ready
        engine._ready = True
        assert service.is_engine_ready(), "Should be True when both conditions met"
@@ -136,6 +136,7 @@ def mock_uow() -> MagicMock:
    uow.annotations = MagicMock()
    uow.preferences = MagicMock()
    uow.diarization_jobs = MagicMock()
+    uow.entities = MagicMock()
    return uow

@@ -211,31 +211,31 @@ class TestMeetingProperties:
        meeting.started_at = utc_now() - timedelta(seconds=5)
        assert meeting.duration_seconds >= 5.0

-    def test_is_active_created(self) -> None:
-        """Test is_active returns True for CREATED state."""
+    def test_is_in_active_state_created(self) -> None:
+        """Test is_in_active_state returns True for CREATED state."""
        meeting = Meeting.create()
-        assert meeting.is_active() is True
+        assert meeting.is_in_active_state() is True

-    def test_is_active_recording(self) -> None:
-        """Test is_active returns True for RECORDING state."""
+    def test_is_in_active_state_recording(self) -> None:
+        """Test is_in_active_state returns True for RECORDING state."""
        meeting = Meeting.create()
        meeting.start_recording()
-        assert meeting.is_active() is True
+        assert meeting.is_in_active_state() is True

-    def test_is_active_stopping(self) -> None:
-        """Test is_active returns False for STOPPING state."""
+    def test_is_in_active_state_stopping(self) -> None:
+        """Test is_in_active_state returns False for STOPPING state."""
        meeting = Meeting.create()
        meeting.start_recording()
        meeting.begin_stopping()
-        assert meeting.is_active() is False
+        assert meeting.is_in_active_state() is False

-    def test_is_active_stopped(self) -> None:
-        """Test is_active returns False for STOPPED state."""
+    def test_is_in_active_state_stopped(self) -> None:
+        """Test is_in_active_state returns False for STOPPED state."""
        meeting = Meeting.create()
        meeting.start_recording()
        meeting.begin_stopping()
        meeting.stop_recording()
-        assert meeting.is_active() is False
+        assert meeting.is_in_active_state() is False

    def test_has_summary_false(self) -> None:
        """Test has_summary returns False when no summary."""
271
tests/domain/test_named_entity.py
Normal file
@@ -0,0 +1,271 @@
"""Tests for NamedEntity domain entity."""

from __future__ import annotations

from uuid import uuid4

import pytest

from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId


class TestEntityCategory:
    """Tests for EntityCategory enum."""

    @pytest.mark.parametrize(
        ("value", "expected"),
        [
            ("person", EntityCategory.PERSON),
            ("company", EntityCategory.COMPANY),
            ("product", EntityCategory.PRODUCT),
            ("technical", EntityCategory.TECHNICAL),
            ("acronym", EntityCategory.ACRONYM),
            ("location", EntityCategory.LOCATION),
            ("date", EntityCategory.DATE),
            ("other", EntityCategory.OTHER),
        ],
    )
    def test_from_string_valid_values(
        self, value: str, expected: EntityCategory
    ) -> None:
        """Convert lowercase string to EntityCategory."""
        assert EntityCategory.from_string(value) == expected

    @pytest.mark.parametrize("value", ["PERSON", "Person", "COMPANY"])
    def test_from_string_case_insensitive(self, value: str) -> None:
        """Convert mixed case string to EntityCategory."""
        result = EntityCategory.from_string(value)
        assert result in EntityCategory

    def test_from_string_invalid_raises(self) -> None:
        """Invalid category string raises ValueError."""
        with pytest.raises(ValueError, match="Invalid entity category"):
            EntityCategory.from_string("invalid_category")


class TestNamedEntityValidation:
    """Tests for NamedEntity validation in __post_init__."""

    @pytest.mark.parametrize("confidence", [-0.1, 1.1, 2.0, -1.0])
    def test_invalid_confidence_raises(self, confidence: float) -> None:
        """Confidence outside 0-1 range raises ValueError."""
        with pytest.raises(ValueError, match="confidence must be between 0 and 1"):
            NamedEntity(
                text="John",
                category=EntityCategory.PERSON,
                confidence=confidence,
            )

    @pytest.mark.parametrize("confidence", [0.0, 0.5, 1.0, 0.95])
    def test_valid_confidence_boundaries(self, confidence: float) -> None:
        """Confidence at valid boundaries is accepted."""
        entity = NamedEntity(
            text="John",
            category=EntityCategory.PERSON,
            confidence=confidence,
        )
        assert entity.confidence == confidence

    def test_auto_computes_normalized_text(self) -> None:
        """Normalized text is auto-computed from text when not provided."""
        entity = NamedEntity(
            text="John SMITH",
            category=EntityCategory.PERSON,
            confidence=0.9,
        )
        assert entity.normalized_text == "john smith"

    def test_preserves_explicit_normalized_text(self) -> None:
        """Explicit normalized_text is preserved."""
        entity = NamedEntity(
            text="John Smith",
            normalized_text="custom_normalization",
            category=EntityCategory.PERSON,
            confidence=0.9,
        )
        assert entity.normalized_text == "custom_normalization"


class TestNamedEntityCreate:
    """Tests for NamedEntity.create() factory method."""

    def test_create_with_valid_input(self) -> None:
        """Create entity with valid input returns properly initialized entity."""
        meeting_id = MeetingId(uuid4())
        entity = NamedEntity.create(
            text="Acme Corporation",
            category=EntityCategory.COMPANY,
            segment_ids=[0, 2, 1],
            confidence=0.85,
            meeting_id=meeting_id,
        )

        assert entity.text == "Acme Corporation", "Text should be preserved"
        assert entity.normalized_text == "acme corporation", "Normalized text should be lowercase"
        assert entity.category == EntityCategory.COMPANY, "Category should be preserved"
        assert entity.segment_ids == [0, 1, 2], "Segment IDs should be sorted"
        assert entity.confidence == 0.85, "Confidence should be preserved"
        assert entity.meeting_id == meeting_id, "Meeting ID should be preserved"

    def test_create_strips_whitespace(self) -> None:
        """Create strips leading/trailing whitespace from text."""
        entity = NamedEntity.create(
            text=" John Smith ",
            category=EntityCategory.PERSON,
            segment_ids=[0],
            confidence=0.9,
        )
        assert entity.text == "John Smith"
        assert entity.normalized_text == "john smith"

    def test_create_deduplicates_segment_ids(self) -> None:
        """Create deduplicates and sorts segment IDs."""
        entity = NamedEntity.create(
            text="Test",
            category=EntityCategory.OTHER,
            segment_ids=[3, 1, 1, 3, 2],
            confidence=0.8,
        )
        assert entity.segment_ids == [1, 2, 3]

    def test_create_empty_text_raises(self) -> None:
        """Create with empty text raises ValueError."""
        with pytest.raises(ValueError, match="Entity text cannot be empty"):
            NamedEntity.create(
                text="",
                category=EntityCategory.PERSON,
                segment_ids=[0],
                confidence=0.9,
            )

    def test_create_whitespace_only_text_raises(self) -> None:
        """Create with whitespace-only text raises ValueError."""
        with pytest.raises(ValueError, match="Entity text cannot be empty"):
            NamedEntity.create(
                text=" ",
                category=EntityCategory.PERSON,
                segment_ids=[0],
                confidence=0.9,
            )

    def test_create_invalid_confidence_raises(self) -> None:
        """Create with invalid confidence raises ValueError."""
        with pytest.raises(ValueError, match="confidence must be between 0 and 1"):
            NamedEntity.create(
                text="John",
                category=EntityCategory.PERSON,
                segment_ids=[0],
                confidence=1.5,
            )


class TestNamedEntityOccurrenceCount:
    """Tests for occurrence_count property."""

    def test_occurrence_count_with_segments(self) -> None:
        """Occurrence count returns number of unique segment IDs."""
        entity = NamedEntity(
            text="Test",
            category=EntityCategory.OTHER,
            segment_ids=[0, 1, 2],
            confidence=0.8,
        )
        assert entity.occurrence_count == 3

    def test_occurrence_count_empty_segments(self) -> None:
        """Occurrence count returns 0 for empty segment_ids."""
        entity = NamedEntity(
            text="Test",
            category=EntityCategory.OTHER,
            segment_ids=[],
            confidence=0.8,
        )
        assert entity.occurrence_count == 0

    def test_occurrence_count_single_segment(self) -> None:
        """Occurrence count returns 1 for single segment."""
        entity = NamedEntity(
            text="Test",
            category=EntityCategory.OTHER,
            segment_ids=[5],
            confidence=0.8,
        )
        assert entity.occurrence_count == 1


class TestNamedEntityMergeSegments:
    """Tests for merge_segments method."""

    def test_merge_segments_adds_new(self) -> None:
        """Merge segments adds new segment IDs."""
        entity = NamedEntity(
            text="John",
            category=EntityCategory.PERSON,
            segment_ids=[0, 1],
            confidence=0.9,
        )
        entity.merge_segments([3, 4])
        assert entity.segment_ids == [0, 1, 3, 4]

    def test_merge_segments_deduplicates(self) -> None:
        """Merge segments deduplicates overlapping IDs."""
        entity = NamedEntity(
            text="John",
            category=EntityCategory.PERSON,
            segment_ids=[0, 1, 2],
            confidence=0.9,
        )
        entity.merge_segments([1, 2, 3])
        assert entity.segment_ids == [0, 1, 2, 3]

    def test_merge_segments_sorts(self) -> None:
        """Merge segments keeps result sorted."""
        entity = NamedEntity(
            text="John",
            category=EntityCategory.PERSON,
            segment_ids=[5, 10],
            confidence=0.9,
        )
        entity.merge_segments([1, 3])
        assert entity.segment_ids == [1, 3, 5, 10]

    def test_merge_empty_segments(self) -> None:
        """Merge with empty list preserves original segments."""
        entity = NamedEntity(
            text="John",
            category=EntityCategory.PERSON,
            segment_ids=[0, 1],
            confidence=0.9,
        )
        entity.merge_segments([])
        assert entity.segment_ids == [0, 1]


class TestNamedEntityDefaults:
    """Tests for NamedEntity default values."""

    def test_default_meeting_id_is_none(self) -> None:
        """Default meeting_id is None."""
        entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
        assert entity.meeting_id is None

    def test_default_segment_ids_is_empty(self) -> None:
        """Default segment_ids is empty list."""
        entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
        assert entity.segment_ids == []

    def test_default_is_pinned_is_false(self) -> None:
        """Default is_pinned is False."""
        entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
        assert entity.is_pinned is False

    def test_default_db_id_is_none(self) -> None:
        """Default db_id is None."""
        entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
        assert entity.db_id is None

    def test_id_is_auto_generated(self) -> None:
        """UUID id is auto-generated."""
        entity = NamedEntity(text="Test", category=EntityCategory.OTHER, confidence=0.5)
        assert entity.id is not None
1
tests/infrastructure/calendar/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for calendar infrastructure."""
269
tests/infrastructure/calendar/test_google_adapter.py
Normal file
@@ -0,0 +1,269 @@
"""Tests for Google Calendar adapter."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


class TestGoogleCalendarAdapterListEvents:
    """Tests for GoogleCalendarAdapter.list_events."""

    @pytest.mark.asyncio
    async def test_list_events_returns_calendar_events(self) -> None:
        """list_events should return parsed calendar events."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter

        adapter = GoogleCalendarAdapter()

        now = datetime.now(UTC)
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "items": [
                {
                    "id": "event-1",
                    "summary": "Team Standup",
                    "start": {"dateTime": (now + timedelta(hours=1)).isoformat()},
                    "end": {"dateTime": (now + timedelta(hours=2)).isoformat()},
                    "attendees": [
                        {"email": "alice@example.com"},
                        {"email": "bob@example.com"},
                    ],
                    "hangoutLink": "https://meet.google.com/abc-defg-hij",
                },
                {
                    "id": "event-2",
                    "summary": "All-Day Planning",
                    "start": {"date": now.strftime("%Y-%m-%d")},
                    "end": {"date": (now + timedelta(days=1)).strftime("%Y-%m-%d")},
                },
            ]
        }

        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = mock_response
            events = await adapter.list_events("access-token", hours_ahead=24, limit=10)

        assert len(events) == 2
        assert events[0].title == "Team Standup"
        assert events[0].attendees == ("alice@example.com", "bob@example.com")
        assert events[0].meeting_url == "https://meet.google.com/abc-defg-hij"
        assert events[0].provider == "google"
        assert events[1].title == "All-Day Planning"
        assert events[1].is_all_day is True

    @pytest.mark.asyncio
    async def test_list_events_handles_empty_response(self) -> None:
        """list_events should return empty list when no events."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter

        adapter = GoogleCalendarAdapter()

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"items": []}

        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = mock_response
            events = await adapter.list_events("access-token")

        assert events == []

    @pytest.mark.asyncio
    async def test_list_events_raises_on_expired_token(self) -> None:
|
||||
"""list_events should raise error on 401 response."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Token expired"
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GoogleCalendarError, match="expired or invalid"):
|
||||
await adapter.list_events("expired-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_events_raises_on_api_error(self) -> None:
|
||||
"""list_events should raise error on non-200 response."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal server error"
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GoogleCalendarError, match="API error"):
|
||||
await adapter.list_events("access-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_events_parses_conference_data(self) -> None:
|
||||
"""list_events should extract meeting URL from conferenceData."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
now = datetime.now(UTC)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"items": [
|
||||
{
|
||||
"id": "event-zoom",
|
||||
"summary": "Zoom Meeting",
|
||||
"start": {"dateTime": now.isoformat()},
|
||||
"end": {"dateTime": (now + timedelta(hours=1)).isoformat()},
|
||||
"conferenceData": {
|
||||
"entryPoints": [
|
||||
{"entryPointType": "video", "uri": "https://zoom.us/j/123456"},
|
||||
{"entryPointType": "phone", "uri": "tel:+1234567890"},
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
events = await adapter.list_events("access-token")
|
||||
|
||||
assert events[0].meeting_url == "https://zoom.us/j/123456"
|
||||
|
||||
|
||||
class TestGoogleCalendarAdapterGetUserEmail:
|
||||
"""Tests for GoogleCalendarAdapter.get_user_email."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_email_returns_email(self) -> None:
|
||||
"""get_user_email should return user's email address."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
email = await adapter.get_user_email("access-token")
|
||||
|
||||
assert email == "user@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_email_raises_on_missing_email(self) -> None:
|
||||
"""get_user_email should raise when email not in response."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"name": "No Email User"}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GoogleCalendarError, match="No email"):
|
||||
await adapter.get_user_email("access-token")
|
||||
|
||||
|
||||
class TestGoogleCalendarAdapterDateParsing:
|
||||
"""Tests for date/time parsing in GoogleCalendarAdapter."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parses_utc_datetime_with_z_suffix(self) -> None:
|
||||
"""Should parse datetime with Z suffix correctly."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"items": [
|
||||
{
|
||||
"id": "event-utc",
|
||||
"summary": "UTC Event",
|
||||
"start": {"dateTime": "2024-03-15T10:00:00Z"},
|
||||
"end": {"dateTime": "2024-03-15T11:00:00Z"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
events = await adapter.list_events("access-token")
|
||||
|
||||
assert events[0].start_time.tzinfo is not None
|
||||
assert events[0].start_time.hour == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parses_datetime_with_offset(self) -> None:
|
||||
"""Should parse datetime with timezone offset correctly."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"items": [
|
||||
{
|
||||
"id": "event-offset",
|
||||
"summary": "Offset Event",
|
||||
"start": {"dateTime": "2024-03-15T10:00:00-08:00"},
|
||||
"end": {"dateTime": "2024-03-15T11:00:00-08:00"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
events = await adapter.list_events("access-token")
|
||||
|
||||
assert events[0].start_time.tzinfo is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identifies_recurring_events(self) -> None:
|
||||
"""Should identify recurring events via recurringEventId."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
now = datetime.now(UTC)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"items": [
|
||||
{
|
||||
"id": "event-instance",
|
||||
"summary": "Weekly Meeting",
|
||||
"start": {"dateTime": now.isoformat()},
|
||||
"end": {"dateTime": (now + timedelta(hours=1)).isoformat()},
|
||||
"recurringEventId": "event-master",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
events = await adapter.list_events("access-token")
|
||||
|
||||
assert events[0].is_recurring is True
|
||||
266
tests/infrastructure/calendar/test_oauth_manager.py
Normal file
@@ -0,0 +1,266 @@
"""Tests for OAuth manager."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from noteflow.config.settings import CalendarSettings
from noteflow.domain.value_objects import OAuthProvider
from noteflow.infrastructure.calendar.oauth_manager import OAuthError


@pytest.fixture
def calendar_settings() -> CalendarSettings:
    """Create test calendar settings."""
    return CalendarSettings(
        google_client_id="test-google-client-id",
        google_client_secret="test-google-client-secret",
        outlook_client_id="test-outlook-client-id",
        outlook_client_secret="test-outlook-client-secret",
    )


class TestOAuthManagerInitiateAuth:
    """Tests for OAuthManager.initiate_auth."""

    def test_initiate_google_auth_returns_url_and_state(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """initiate_auth should return auth URL and state for Google."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)
        auth_url, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")

        assert auth_url.startswith("https://accounts.google.com/o/oauth2/v2/auth")
        assert "client_id=test-google-client-id" in auth_url
        assert "redirect_uri=http" in auth_url
        assert "code_challenge=" in auth_url
        assert "code_challenge_method=S256" in auth_url
        assert len(state) > 0

    def test_initiate_outlook_auth_returns_url_and_state(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """initiate_auth should return auth URL and state for Outlook."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)
        auth_url, state = manager.initiate_auth(OAuthProvider.OUTLOOK, "http://localhost:8080/callback")

        assert auth_url.startswith("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
        assert "client_id=test-outlook-client-id" in auth_url
        assert len(state) > 0

    def test_initiate_auth_stores_pending_state(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """initiate_auth should store pending state for later validation."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)
        _, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")

        assert state in manager._pending_states
        assert manager._pending_states[state].provider == OAuthProvider.GOOGLE

    def test_initiate_auth_missing_credentials_raises(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """initiate_auth should raise for missing OAuth credentials."""
        from noteflow.infrastructure.calendar import OAuthManager

        # Create settings with missing Google credentials
        settings = CalendarSettings(
            google_client_id="",
            google_client_secret="",
            outlook_client_id="test-id",
            outlook_client_secret="test-secret",
        )
        manager = OAuthManager(settings)

        with pytest.raises(OAuthError, match="not configured"):
            manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")


class TestOAuthManagerCompleteAuth:
    """Tests for OAuthManager.complete_auth."""

    @pytest.mark.asyncio
    async def test_complete_auth_exchanges_code_for_tokens(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """complete_auth should exchange authorization code for tokens."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)
        _, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "access_token": "access-token-123",
            "refresh_token": "refresh-token-456",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "email calendar.readonly",
        }

        with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            tokens = await manager.complete_auth(OAuthProvider.GOOGLE, "auth-code-xyz", state)

        assert tokens.access_token == "access-token-123"
        assert tokens.refresh_token == "refresh-token-456"
        assert tokens.token_type == "Bearer"

    @pytest.mark.asyncio
    async def test_complete_auth_invalid_state_raises(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """complete_auth should raise for invalid state."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)

        with pytest.raises(OAuthError, match="Invalid or expired"):
            await manager.complete_auth(OAuthProvider.GOOGLE, "code", "invalid-state")

    @pytest.mark.asyncio
    async def test_complete_auth_expired_state_raises(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """complete_auth should raise for expired state."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)
        _, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")

        # Expire the state manually by replacing the OAuthState
        old_state = manager._pending_states[state]
        from noteflow.domain.value_objects import OAuthState
        expired_state = OAuthState(
            state=old_state.state,
            provider=old_state.provider,
            redirect_uri=old_state.redirect_uri,
            code_verifier=old_state.code_verifier,
            created_at=old_state.created_at,
            expires_at=datetime.now(UTC) - timedelta(minutes=1),
        )
        manager._pending_states[state] = expired_state

        with pytest.raises(OAuthError, match="expired"):
            await manager.complete_auth(OAuthProvider.GOOGLE, "code", state)

    @pytest.mark.asyncio
    async def test_complete_auth_removes_used_state(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """complete_auth should remove state after successful use."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)
        _, state = manager.initiate_auth(OAuthProvider.GOOGLE, "http://localhost:8080/callback")

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "access_token": "token",
            "token_type": "Bearer",
            "expires_in": 3600,
        }

        with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            await manager.complete_auth(OAuthProvider.GOOGLE, "code", state)

        assert state not in manager._pending_states


class TestOAuthManagerRefreshTokens:
    """Tests for OAuthManager.refresh_tokens."""

    @pytest.mark.asyncio
    async def test_refresh_tokens_returns_new_tokens(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """refresh_tokens should return new access token."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
        }

        with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            tokens = await manager.refresh_tokens(OAuthProvider.GOOGLE, "old-refresh-token")

        assert tokens.access_token == "new-access-token"

    @pytest.mark.asyncio
    async def test_refresh_tokens_preserves_refresh_token_if_not_returned(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """refresh_tokens should preserve old refresh token if not in response."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
        }

        with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            tokens = await manager.refresh_tokens(OAuthProvider.GOOGLE, "old-refresh-token")

        assert tokens.refresh_token == "old-refresh-token"


class TestOAuthManagerRevokeTokens:
    """Tests for OAuthManager.revoke_tokens."""

    @pytest.mark.asyncio
    async def test_revoke_google_tokens_calls_revocation_endpoint(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """revoke_tokens should call Google revocation endpoint."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)

        mock_response = MagicMock()
        mock_response.status_code = 200

        with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            await manager.revoke_tokens(OAuthProvider.GOOGLE, "access-token")

        mock_post.assert_called_once()
        call_args = mock_post.call_args
        assert "oauth2.googleapis.com/revoke" in call_args[0][0]

    @pytest.mark.asyncio
    async def test_revoke_outlook_tokens_handles_no_revocation_endpoint(
        self, calendar_settings: CalendarSettings
    ) -> None:
        """revoke_tokens should succeed silently for Outlook (no revocation endpoint)."""
        from noteflow.infrastructure.calendar import OAuthManager

        manager = OAuthManager(calendar_settings)

        # Should not raise - Outlook doesn't have a revocation endpoint
        await manager.revoke_tokens(OAuthProvider.OUTLOOK, "access-token")
1
tests/infrastructure/ner/__init__.py
Normal file
@@ -0,0 +1 @@
"""NER infrastructure tests."""
127
tests/infrastructure/ner/test_engine.py
Normal file
@@ -0,0 +1,127 @@
"""Tests for NER engine (spaCy wrapper)."""

from __future__ import annotations

import pytest

from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.infrastructure.ner import NerEngine


@pytest.fixture(scope="module")
def ner_engine() -> NerEngine:
    """Create NER engine (module-scoped to avoid repeated model loads)."""
    engine = NerEngine(model_name="en_core_web_sm")
    return engine


class TestNerEngineBasics:
    """Basic NER engine functionality tests."""

    def test_is_ready_before_load(self) -> None:
        """Engine is not ready before first use."""
        engine = NerEngine()
        assert not engine.is_ready()

    def test_is_ready_after_extract(self, ner_engine: NerEngine) -> None:
        """Engine is ready after extraction triggers lazy load."""
        ner_engine.extract("Hello, John.")
        assert ner_engine.is_ready()


class TestEntityExtraction:
    """Tests for entity extraction from text."""

    @pytest.mark.parametrize(
        ("text", "expected_category", "expected_text"),
        [
            ("I met John Smith yesterday.", EntityCategory.PERSON, "John Smith"),
            ("Apple Inc. announced new products.", EntityCategory.COMPANY, "Apple Inc."),
            ("We visited New York City.", EntityCategory.LOCATION, "New York City"),
            ("The meeting is scheduled for Monday.", EntityCategory.DATE, "Monday"),
        ],
    )
    def test_extract_entity_types(
        self,
        ner_engine: NerEngine,
        text: str,
        expected_category: EntityCategory,
        expected_text: str,
    ) -> None:
        """Extract entities of various types."""
        entities = ner_engine.extract(text)
        matching = [e for e in entities if e.category == expected_category]
        assert matching, f"No {expected_category.value} entities found in: {text}"
        texts = [e.text for e in matching]
        assert expected_text in texts, f"Expected '{expected_text}' not found in {texts}"

    def test_extract_returns_list(self, ner_engine: NerEngine) -> None:
        """Extract returns a list of NamedEntity objects."""
        entities = ner_engine.extract("John works at Google.")
        assert isinstance(entities, list)
        assert all(isinstance(e, NamedEntity) for e in entities)

    def test_extract_empty_text_returns_empty(self, ner_engine: NerEngine) -> None:
        """Empty text returns empty list."""
        entities = ner_engine.extract("")
        assert entities == []

    def test_extract_no_entities_returns_empty(self, ner_engine: NerEngine) -> None:
        """Text with no entities returns empty list."""
        entities = ner_engine.extract("The quick brown fox.")
        # May still find entities depending on model, but should not crash
        assert isinstance(entities, list)


class TestSegmentExtraction:
    """Tests for extraction from multiple segments."""

    def test_extract_from_segments_tracks_segment_ids(self, ner_engine: NerEngine) -> None:
        """Segment extraction tracks which segments contain each entity."""
        segments = [
            (0, "John talked to Mary."),
            (1, "Mary went to the store."),
            (2, "Then John called Mary again."),
        ]
        entities = ner_engine.extract_from_segments(segments)

        # Mary should appear in multiple segments
        mary_entities = [e for e in entities if "mary" in e.normalized_text]
        assert mary_entities, "Mary entity not found"
        mary = mary_entities[0]
        assert len(mary.segment_ids) >= 2, "Mary should appear in multiple segments"

    def test_extract_from_segments_deduplicates(self, ner_engine: NerEngine) -> None:
        """Same entity in multiple segments is deduplicated."""
        segments = [
            (0, "John Smith is here."),
            (1, "John Smith left."),
        ]
        entities = ner_engine.extract_from_segments(segments)

        # Should have one John Smith entity (deduplicated by normalized text)
        john_entities = [e for e in entities if "john" in e.normalized_text]
        assert len(john_entities) == 1, "John Smith should be deduplicated"
        assert len(john_entities[0].segment_ids) == 2, "Should track both segments"

    def test_extract_from_segments_empty_returns_empty(self, ner_engine: NerEngine) -> None:
        """Empty segments list returns empty entities."""
        entities = ner_engine.extract_from_segments([])
        assert entities == []


class TestEntityNormalization:
    """Tests for entity text normalization."""

    def test_normalized_text_is_lowercase(self, ner_engine: NerEngine) -> None:
        """Normalized text should be lowercase."""
        entities = ner_engine.extract("John SMITH went to NYC.")
        for entity in entities:
            assert entity.normalized_text == entity.normalized_text.lower()

    def test_confidence_is_set(self, ner_engine: NerEngine) -> None:
        """Entities should have confidence score."""
        entities = ner_engine.extract("Microsoft Corporation is based in Seattle.")
        assert entities, "Should find entities"
        for entity in entities:
            assert 0.0 <= entity.confidence <= 1.0, "Confidence should be between 0 and 1"
@@ -2,11 +2,17 @@

from __future__ import annotations

from unittest.mock import MagicMock
from uuid import uuid4

import pytest

from noteflow.domain import entities
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.asr import dto
from noteflow.infrastructure.converters import AsrConverter, OrmConverter
from noteflow.infrastructure.converters.ner_converters import NerConverter


class TestAsrConverter:
@@ -125,3 +131,161 @@ class TestOrmConverterToOrmKwargs:
        assert result["end_time"] == 0.987654321
        assert result["probability"] == 0.111111
        assert result["word_index"] == 5


class TestNerConverterToOrmKwargs:
    """Tests for NerConverter.to_orm_kwargs."""

    def test_converts_domain_to_orm_kwargs(self) -> None:
        """Convert domain NamedEntity to ORM kwargs dict."""
        meeting_id = MeetingId(uuid4())
        entity = NamedEntity(
            id=uuid4(),
            meeting_id=meeting_id,
            text="Acme Corp",
            normalized_text="acme corp",
            category=EntityCategory.COMPANY,
            segment_ids=[0, 1, 2],
            confidence=0.95,
            is_pinned=True,
        )

        result = NerConverter.to_orm_kwargs(entity)

        assert result["id"] == entity.id, "ID should be preserved"
        assert result["meeting_id"] == meeting_id, "Meeting ID should be preserved"
        assert result["text"] == "Acme Corp", "Text should be preserved"
        assert result["normalized_text"] == "acme corp", "Normalized text should be preserved"
        assert result["category"] == "company", "Category should be string value"
        assert result["segment_ids"] == [0, 1, 2], "Segment IDs should be preserved"
        assert result["confidence"] == 0.95, "Confidence should be preserved"
        assert result["is_pinned"] is True, "is_pinned should be preserved"

    def test_converts_category_to_string_value(self) -> None:
        """Category enum is converted to its string value."""
        entity = NamedEntity(
            text="John",
            category=EntityCategory.PERSON,
            confidence=0.8,
        )

        result = NerConverter.to_orm_kwargs(entity)

        assert result["category"] == "person"
        assert isinstance(result["category"], str)

    @pytest.mark.parametrize(
        ("category", "expected_value"),
        [
            (EntityCategory.PERSON, "person"),
            (EntityCategory.COMPANY, "company"),
            (EntityCategory.PRODUCT, "product"),
            (EntityCategory.LOCATION, "location"),
            (EntityCategory.DATE, "date"),
            (EntityCategory.OTHER, "other"),
        ],
    )
    def test_all_category_values_convert(
        self, category: EntityCategory, expected_value: str
    ) -> None:
        """All category enum values convert to correct string."""
        entity = NamedEntity(text="Test", category=category, confidence=0.5)
        result = NerConverter.to_orm_kwargs(entity)
        assert result["category"] == expected_value


class TestNerConverterOrmToDomain:
    """Tests for NerConverter.orm_to_domain."""

    @pytest.fixture
    def mock_orm_model(self) -> MagicMock:
        """Create mock ORM model with typical values."""
        model = MagicMock()
        model.id = uuid4()
        model.meeting_id = uuid4()
        model.text = "Google Inc."
        model.normalized_text = "google inc."
        model.category = "company"
        model.segment_ids = [1, 3, 5]
        model.confidence = 0.92
        model.is_pinned = False
        return model

    def test_converts_orm_to_domain_entity(self, mock_orm_model: MagicMock) -> None:
        """Convert ORM model to domain NamedEntity."""
        result = NerConverter.orm_to_domain(mock_orm_model)

        assert isinstance(result, NamedEntity), "Should return NamedEntity instance"
        assert result.id == mock_orm_model.id, "ID should match ORM model"
        assert result.meeting_id == MeetingId(mock_orm_model.meeting_id), "Meeting ID should match"
        assert result.text == "Google Inc.", "Text should match ORM model"
        assert result.normalized_text == "google inc.", "Normalized text should match"
        assert result.category == EntityCategory.COMPANY, "Category should convert to enum"
        assert result.segment_ids == [1, 3, 5], "Segment IDs should match"
        assert result.confidence == 0.92, "Confidence should match"
        assert result.is_pinned is False, "is_pinned should match"

    def test_sets_db_id_from_orm_id(self, mock_orm_model: MagicMock) -> None:
        """Domain entity db_id is set from ORM id."""
        result = NerConverter.orm_to_domain(mock_orm_model)
        assert result.db_id == mock_orm_model.id

    def test_converts_category_string_to_enum(self, mock_orm_model: MagicMock) -> None:
        """Category string from ORM is converted to EntityCategory enum."""
        mock_orm_model.category = "person"
        result = NerConverter.orm_to_domain(mock_orm_model)
        assert result.category == EntityCategory.PERSON
        assert isinstance(result.category, EntityCategory)

    def test_handles_none_segment_ids(self, mock_orm_model: MagicMock) -> None:
        """Null segment_ids in ORM becomes empty list in domain."""
        mock_orm_model.segment_ids = None
        result = NerConverter.orm_to_domain(mock_orm_model)
        assert result.segment_ids == []

    def test_handles_empty_segment_ids(self, mock_orm_model: MagicMock) -> None:
        """Empty segment_ids in ORM becomes empty list in domain."""
        mock_orm_model.segment_ids = []
        result = NerConverter.orm_to_domain(mock_orm_model)
        assert result.segment_ids == []


class TestNerConverterRoundTrip:
    """Tests for round-trip conversion fidelity."""

    def test_domain_to_orm_to_domain_preserves_values(self) -> None:
        """Round-trip conversion preserves all field values."""
        original = NamedEntity(
            id=uuid4(),
            meeting_id=MeetingId(uuid4()),
            text="Microsoft Corporation",
            normalized_text="microsoft corporation",
            category=EntityCategory.COMPANY,
            segment_ids=[0, 5, 10],
            confidence=0.88,
            is_pinned=True,
        )

        # Convert to ORM kwargs and simulate ORM model
        orm_kwargs = NerConverter.to_orm_kwargs(original)
        mock_orm = MagicMock()
        mock_orm.id = orm_kwargs["id"]
        mock_orm.meeting_id = orm_kwargs["meeting_id"]
        mock_orm.text = orm_kwargs["text"]
        mock_orm.normalized_text = orm_kwargs["normalized_text"]
        mock_orm.category = orm_kwargs["category"]
        mock_orm.segment_ids = orm_kwargs["segment_ids"]
        mock_orm.confidence = orm_kwargs["confidence"]
        mock_orm.is_pinned = orm_kwargs["is_pinned"]

        # Convert back to domain
        result = NerConverter.orm_to_domain(mock_orm)

        assert result.id == original.id, "ID preserved through round-trip"
        assert result.meeting_id == original.meeting_id, "Meeting ID preserved"
        assert result.text == original.text, "Text preserved"
        assert result.normalized_text == original.normalized_text, "Normalized text preserved"
        assert result.category == original.category, "Category preserved"
        assert result.segment_ids == original.segment_ids, "Segment IDs preserved"
        assert result.confidence == original.confidence, "Confidence preserved"
        assert result.is_pinned == original.is_pinned, "is_pinned preserved"
@@ -5,6 +5,7 @@ Tests the complete export workflow with database:
- gRPC ExportTranscript with database
- Markdown export format
- HTML export format
- PDF export format
- File export with correct extensions
- Error handling for export operations
"""
@@ -25,6 +26,22 @@ from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork


def _weasyprint_available() -> bool:
    """Check if weasyprint is available (may fail due to missing system libraries)."""
    try:
        import weasyprint as _  # noqa: F401

        return True
    except (ImportError, OSError):
        return False


requires_weasyprint = pytest.mark.skipif(
    not _weasyprint_available(),
    reason="weasyprint not available (missing system libraries: pango, gobject)",
)

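The availability probe here deliberately catches `OSError` as well as `ImportError`, because weasyprint imports native libraries (pango, gobject) at module load. A minimal stdlib sketch of the same guard pattern (the module names in the comment are illustrative, not part of this commit):

```python
def native_dep_available(module_name: str) -> bool:
    """True if the module imports cleanly.

    Native-backed packages (weasyprint, lxml, ...) can raise OSError rather
    than ImportError when their shared system libraries are missing, so the
    probe catches both instead of only ImportError.
    """
    try:
        __import__(module_name)
    except (ImportError, OSError):
        return False
    return True


# With pytest, the predicate feeds a collection-time skip marker, e.g.:
# requires_dep = pytest.mark.skipif(not native_dep_available("weasyprint"), ...)
```

The predicate runs once at collection time, so an expensive or failing import is attempted only once per session rather than per test.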
if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

@@ -114,6 +131,85 @@ class TestExportServiceDatabase:
        assert "<html" in content.lower() or "<!doctype" in content.lower()
        assert "HTML content test" in content

    @pytest.mark.slow
    @requires_weasyprint
    async def test_export_pdf_from_database(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Test exporting meeting as PDF from database with full content verification."""

        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Export PDF Test")
            meeting.start_recording()
            meeting.begin_stopping()
            meeting.stop_recording()
            await uow.meetings.create(meeting)

            segments = [
                Segment(
                    segment_id=0,
                    text="First speaker talks here.",
                    start_time=0.0,
                    end_time=3.0,
                    speaker_id="Alice",
                ),
                Segment(
                    segment_id=1,
                    text="Second speaker responds.",
                    start_time=3.0,
                    end_time=6.0,
                    speaker_id="Bob",
                ),
            ]
            for segment in segments:
                await uow.segments.add(meeting.id, segment)
            await uow.commit()

        export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
        content = await export_service.export_transcript(meeting.id, ExportFormat.PDF)

        assert isinstance(content, bytes), "PDF export should return bytes"
        assert content.startswith(b"%PDF-"), "PDF should have valid magic bytes"
        assert len(content) > 1000, "PDF should have substantial content"

    @pytest.mark.slow
    @requires_weasyprint
    async def test_export_to_file_creates_pdf_file(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Test export_to_file writes valid binary PDF file to disk."""

        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="PDF File Export Test")
            await uow.meetings.create(meeting)

            segment = Segment(
                segment_id=0,
                text="Content for PDF file.",
                start_time=0.0,
                end_time=2.0,
                speaker_id="Speaker",
            )
            await uow.segments.add(meeting.id, segment)
            await uow.commit()

        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = Path(tmpdir) / "transcript.pdf"

            export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
            result_path = await export_service.export_to_file(
                meeting.id,
                output_path,
                ExportFormat.PDF,
            )

            assert result_path.exists(), "PDF file should be created"
            assert result_path.suffix == ".pdf", "File should have .pdf extension"

            file_bytes = result_path.read_bytes()
            assert file_bytes.startswith(b"%PDF-"), "File should contain valid PDF"
            assert len(file_bytes) > 500, "PDF file should have content"

    async def test_export_to_file_creates_file(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
@@ -254,6 +350,44 @@ class TestExportGrpcServicer:
        assert "HTML export content" in result.content
        assert result.file_extension == ".html"

    @pytest.mark.slow
    @requires_weasyprint
    async def test_export_transcript_pdf_via_grpc(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Test ExportTranscript RPC with PDF format returns valid base64-encoded PDF."""
        import base64

        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="gRPC PDF Export Test")
            await uow.meetings.create(meeting)

            segment = Segment(
                segment_id=0,
                text="PDF via gRPC content.",
                start_time=0.0,
                end_time=5.0,
                speaker_id="TestSpeaker",
            )
            await uow.segments.add(meeting.id, segment)
            await uow.commit()

        servicer = NoteFlowServicer(session_factory=session_factory)

        request = noteflow_pb2.ExportTranscriptRequest(
            meeting_id=str(meeting.id),
            format=noteflow_pb2.EXPORT_FORMAT_PDF,
        )
        result = await servicer.ExportTranscript(request, MockContext())

        assert result.content, "PDF export should return content"
        assert result.file_extension == ".pdf", "File extension should be .pdf"

        # gRPC returns base64-encoded PDF; verify it decodes to valid PDF
        pdf_bytes = base64.b64decode(result.content)
        assert pdf_bytes.startswith(b"%PDF-"), "Decoded content should be valid PDF"
        assert len(pdf_bytes) > 500, "PDF should have substantial content"

    async def test_export_transcript_nonexistent_meeting(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
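The round-trip exercised in the hunk above — binary PDF bytes base64-encoded into a string field on the wire, then decoded and sanity-checked by magic bytes on the receiver — can be sketched in isolation (the payload bytes here are a stand-in, not real output from the export service):

```python
import base64

# Hypothetical payload standing in for a rendered PDF; only the leading
# magic bytes matter for the validity check the test performs.
raw_pdf = b"%PDF-1.7\nrest of document bytes"

# Sender side (as the gRPC layer does): base64-encode into a string field.
wire_content = base64.b64encode(raw_pdf).decode("ascii")

# Receiver side: decode and validate the magic bytes before trusting the payload.
decoded = base64.b64decode(wire_content)
assert decoded.startswith(b"%PDF-"), "decoded content is not a PDF"
```

Checking `%PDF-` catches the common failure mode of an error page or empty string being shipped in place of the document, without needing a full PDF parser in the test.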
@@ -406,13 +540,13 @@ class TestExportFormats:
         export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
         formats = export_service.get_supported_formats()
 
-        assert len(formats) >= 2
+        assert len(formats) >= 3, "Should support at least markdown, html, and pdf"
 
         format_names = [f[0] for f in formats]
         extensions = [f[1] for f in formats]
 
-        assert ".md" in extensions
-        assert ".html" in extensions
+        assert ".md" in extensions, "Markdown format should be supported"
+        assert ".html" in extensions, "HTML format should be supported"
+        assert ".pdf" in extensions, "PDF format should be supported"
 
     async def test_infer_format_markdown(
         self, session_factory: async_sessionmaker[AsyncSession]
@@ -438,6 +572,15 @@ class TestExportFormats:
        fmt = export_service._infer_format_from_extension(".htm")
        assert fmt == ExportFormat.HTML

    async def test_infer_format_pdf(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Test format inference from .pdf extension."""
        export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))

        fmt = export_service._infer_format_from_extension(".pdf")
        assert fmt == ExportFormat.PDF, "Should infer PDF format from .pdf extension"

    async def test_infer_format_unknown_raises(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
@@ -446,3 +589,29 @@ class TestExportFormats:

        with pytest.raises(ValueError, match="Cannot infer format"):
            export_service._infer_format_from_extension(".txt")

    @pytest.mark.slow
    @requires_weasyprint
    async def test_export_pdf_returns_bytes(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Test PDF export returns bytes with valid PDF magic bytes."""

        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="PDF Export Test")
            await uow.meetings.create(meeting)

            segment = Segment(
                segment_id=0,
                text="PDF content test.",
                start_time=0.0,
                end_time=2.0,
            )
            await uow.segments.add(meeting.id, segment)
            await uow.commit()

        export_service = ExportService(SqlAlchemyUnitOfWork(session_factory))
        content = await export_service.export_transcript(meeting.id, ExportFormat.PDF)

        assert isinstance(content, bytes), "PDF export should return bytes"
        assert content.startswith(b"%PDF-"), "PDF should have valid magic bytes"

407
tests/integration/test_e2e_ner.py
Normal file
@@ -0,0 +1,407 @@
"""End-to-end integration tests for NER extraction.

Tests the complete NER workflow with database persistence:
- Entity extraction from meeting segments
- Persistence and retrieval
- Caching behavior
- Force refresh functionality
- Pin entity operations
"""

from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from uuid import uuid4

import grpc
import pytest

from noteflow.application.services.ner_service import ExtractionResult, NerService
from noteflow.domain.entities import Meeting, Segment
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId
from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


@pytest.fixture(autouse=True)
def mock_feature_flags() -> MagicMock:
    """Mock feature flags to enable NER for all tests."""
    mock_settings = MagicMock()
    mock_settings.feature_flags = MagicMock(ner_enabled=True)
    with patch(
        "noteflow.application.services.ner_service.get_settings",
        return_value=mock_settings,
    ):
        yield mock_settings


class MockContext:
    """Mock gRPC context for testing."""

    def __init__(self) -> None:
        """Initialize mock context."""
        self.aborted = False
        self.abort_code: grpc.StatusCode | None = None
        self.abort_details: str | None = None

    async def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Record abort and raise to simulate gRPC behavior."""
        self.aborted = True
        self.abort_code = code
        self.abort_details = details
        raise grpc.RpcError()


class MockNerEngine:
    """Mock NER engine that returns controlled entities."""

    def __init__(self, entities: list[NamedEntity] | None = None) -> None:
        """Initialize with predefined entities."""
        self._entities = entities or []
        self._ready = False
        self.extract_call_count = 0
        self.extract_from_segments_call_count = 0

    def extract(self, text: str) -> list[NamedEntity]:
        """Extract entities from text (mock)."""
        self._ready = True
        self.extract_call_count += 1
        return self._entities

    def extract_from_segments(
        self, segments: list[tuple[int, str]]
    ) -> list[NamedEntity]:
        """Extract entities from segments (mock)."""
        self._ready = True
        self.extract_from_segments_call_count += 1
        return self._entities

    def is_ready(self) -> bool:
        """Check if engine is ready."""
        return self._ready


def _create_test_entity(
    text: str,
    category: EntityCategory = EntityCategory.PERSON,
    segment_ids: list[int] | None = None,
    confidence: float = 0.9,
) -> NamedEntity:
    """Create a test entity."""
    return NamedEntity.create(
        text=text,
        category=category,
        segment_ids=segment_ids or [0],
        confidence=confidence,
    )


@pytest.mark.integration
class TestNerExtractionFlow:
    """Integration tests for entity extraction workflow."""

    async def test_extract_entities_persists_to_database(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Extracted entities are persisted to database."""
        # Create meeting with segments
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="NER Test Meeting")
            await uow.meetings.create(meeting)

            for i in range(3):
                segment = Segment(
                    segment_id=i,
                    text=f"Segment {i} mentioning John Smith.",
                    start_time=float(i * 10),
                    end_time=float((i + 1) * 10),
                )
                await uow.segments.add(meeting.id, segment)
            await uow.commit()
            meeting_id = meeting.id

        # Create mock engine that returns test entities
        mock_engine = MockNerEngine(
            entities=[
                _create_test_entity("John Smith", EntityCategory.PERSON, [0, 1, 2]),
            ]
        )
        mock_engine._ready = True

        # Create service and extract
        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        result = await service.extract_entities(meeting_id)

        assert result.total_count == 1, "Should extract exactly one entity"
        assert not result.cached, "First extraction should not be cached"

        # Verify persistence
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            entities = await uow.entities.get_by_meeting(meeting_id)
            assert len(entities) == 1, "Should persist exactly one entity"
            assert entities[0].text == "John Smith", "Entity text should match"
            assert entities[0].category == EntityCategory.PERSON, "Category should be PERSON"
            assert entities[0].segment_ids == [0, 1, 2], "Segment IDs should match"

    async def test_extract_entities_returns_cached_on_second_call(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Second extraction returns cached entities without re-extraction."""
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Cache Test")
            await uow.meetings.create(meeting)
            await uow.segments.add(
                meeting.id, Segment(0, "John mentioned Acme Corp.", 0.0, 5.0)
            )
            await uow.commit()
            meeting_id = meeting.id

        mock_engine = MockNerEngine(
            entities=[
                _create_test_entity("John", EntityCategory.PERSON, [0]),
                _create_test_entity("Acme Corp", EntityCategory.COMPANY, [0]),
            ]
        )
        mock_engine._ready = True

        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        # First extraction
        result1 = await service.extract_entities(meeting_id)
        assert result1.total_count == 2, "First extraction should find 2 entities"
        assert not result1.cached, "First extraction should not be cached"
        initial_count = mock_engine.extract_from_segments_call_count

        # Second extraction should use cache
        result2 = await service.extract_entities(meeting_id)
        assert result2.total_count == 2, "Second extraction should return same count"
        assert result2.cached, "Second extraction should come from cache"
        assert mock_engine.extract_from_segments_call_count == initial_count, (
            "Engine should not be called again when using cache"
        )

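The cache-or-extract contract these tests pin down — extract once, serve stored results afterwards, re-run only on explicit refresh — can be sketched independently of the service (names here are illustrative, not the NerService API):

```python
from dataclasses import dataclass, field


@dataclass
class CachedExtractor:
    """Extract once per key, then serve stored results; force_refresh re-runs."""

    engine_calls: int = 0
    _store: dict[str, list[str]] = field(default_factory=dict)

    def _run_engine(self, meeting_id: str) -> list[str]:
        # Stand-in for the expensive NER engine pass.
        self.engine_calls += 1
        return ["John", "Acme Corp"]

    def extract(self, meeting_id: str, force_refresh: bool = False) -> tuple[list[str], bool]:
        # Second return value mirrors the `cached` flag asserted in the tests.
        if not force_refresh and meeting_id in self._store:
            return self._store[meeting_id], True
        entities = self._run_engine(meeting_id)
        self._store[meeting_id] = entities
        return entities, False
```

Counting engine calls, as the tests do with `extract_from_segments_call_count`, is what proves the cache path was taken rather than merely returning equal data.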
    async def test_extract_entities_force_refresh_re_extracts(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Force refresh re-extracts and replaces cached entities."""
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Force Refresh Test")
            await uow.meetings.create(meeting)
            await uow.segments.add(
                meeting.id, Segment(0, "Testing force refresh.", 0.0, 5.0)
            )
            await uow.commit()
            meeting_id = meeting.id

        # First extraction
        mock_engine = MockNerEngine(
            entities=[_create_test_entity("Test", EntityCategory.OTHER, [0])]
        )
        mock_engine._ready = True

        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        await service.extract_entities(meeting_id)
        initial_count = mock_engine.extract_from_segments_call_count

        # Force refresh should re-extract
        result = await service.extract_entities(meeting_id, force_refresh=True)
        assert not result.cached
        assert mock_engine.extract_from_segments_call_count == initial_count + 1


@pytest.mark.integration
class TestNerPersistence:
    """Integration tests for entity persistence operations."""

    async def test_entities_persist_across_service_instances(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Entities persist in database across service instances."""
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Persistence Test")
            await uow.meetings.create(meeting)
            await uow.segments.add(meeting.id, Segment(0, "Hello world.", 0.0, 5.0))
            await uow.commit()
            meeting_id = meeting.id

        # Extract with first service instance
        mock_engine = MockNerEngine(
            entities=[_create_test_entity("World", EntityCategory.OTHER, [0])]
        )
        mock_engine._ready = True
        service1 = NerService(
            mock_engine, lambda: SqlAlchemyUnitOfWork(session_factory)
        )
        await service1.extract_entities(meeting_id)

        # Create new service instance (simulating server restart)
        service2 = NerService(
            MockNerEngine(), lambda: SqlAlchemyUnitOfWork(session_factory)
        )

        # Should get cached result without extraction
        result = await service2.get_entities(meeting_id)
        assert len(result) == 1
        assert result[0].text == "World"

    async def test_clear_entities_removes_all_for_meeting(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Clear entities removes all entities for meeting."""
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Clear Test")
            await uow.meetings.create(meeting)
            await uow.segments.add(meeting.id, Segment(0, "Test content.", 0.0, 5.0))
            await uow.commit()
            meeting_id = meeting.id

        mock_engine = MockNerEngine(
            entities=[
                _create_test_entity("Entity1", EntityCategory.PERSON, [0]),
                _create_test_entity("Entity2", EntityCategory.COMPANY, [0]),
            ]
        )
        mock_engine._ready = True

        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        await service.extract_entities(meeting_id)
        deleted_count = await service.clear_entities(meeting_id)
        assert deleted_count == 2

        # Verify entities are gone
        entities = await service.get_entities(meeting_id)
        assert len(entities) == 0


@pytest.mark.integration
class TestNerPinning:
    """Integration tests for entity pinning operations."""

    async def test_pin_entity_persists_pinned_state(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Pin entity updates and persists is_pinned flag."""
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Pin Test")
            await uow.meetings.create(meeting)
            await uow.segments.add(meeting.id, Segment(0, "John Doe test.", 0.0, 5.0))
            await uow.commit()
            meeting_id = meeting.id

        mock_engine = MockNerEngine(
            entities=[_create_test_entity("John Doe", EntityCategory.PERSON, [0])]
        )
        mock_engine._ready = True

        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        await service.extract_entities(meeting_id)

        # Get entity ID
        entities = await service.get_entities(meeting_id)
        entity_id = entities[0].id

        # Pin entity
        result = await service.pin_entity(entity_id, is_pinned=True)
        assert result is True

        # Verify persistence
        entities = await service.get_entities(meeting_id)
        assert entities[0].is_pinned is True

        # Unpin
        await service.pin_entity(entity_id, is_pinned=False)
        entities = await service.get_entities(meeting_id)
        assert entities[0].is_pinned is False

    async def test_pin_entity_nonexistent_returns_false(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Pin entity returns False for nonexistent entity."""
        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(MockNerEngine(), uow_factory)

        result = await service.pin_entity(uuid4(), is_pinned=True)
        assert result is False


@pytest.mark.integration
class TestNerEdgeCases:
    """Integration tests for edge cases."""

    async def test_extract_from_meeting_with_no_segments(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Extract from meeting with no segments returns empty result."""
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Empty Meeting")
            await uow.meetings.create(meeting)
            await uow.commit()
            meeting_id = meeting.id

        mock_engine = MockNerEngine()
        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        result = await service.extract_entities(meeting_id)

        assert result.total_count == 0
        assert result.entities == []
        assert not result.cached

    async def test_extract_from_nonexistent_meeting_raises(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Extract from nonexistent meeting raises ValueError."""
        mock_engine = MockNerEngine()
        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        nonexistent_id = MeetingId(uuid4())

        with pytest.raises(ValueError, match=str(nonexistent_id)):
            await service.extract_entities(nonexistent_id)

    async def test_has_entities_reflects_extraction_state(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """has_entities returns correct state before and after extraction."""
        async with SqlAlchemyUnitOfWork(session_factory) as uow:
            meeting = Meeting.create(title="Has Entities Test")
            await uow.meetings.create(meeting)
            await uow.segments.add(meeting.id, Segment(0, "Test.", 0.0, 5.0))
            await uow.commit()
            meeting_id = meeting.id

        mock_engine = MockNerEngine(
            entities=[_create_test_entity("Test", EntityCategory.OTHER, [0])]
        )
        mock_engine._ready = True

        uow_factory = lambda: SqlAlchemyUnitOfWork(session_factory)
        service = NerService(mock_engine, uow_factory)

        # Before extraction
        assert await service.has_entities(meeting_id) is False

        # After extraction
        await service.extract_entities(meeting_id)
        assert await service.has_entities(meeting_id) is True

        # After clearing
        await service.clear_entities(meeting_id)
        assert await service.has_entities(meeting_id) is False
14
uv.lock
generated
@@ -237,6 +237,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" },
]

[[package]]
name = "authlib"
version = "1.6.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "cryptography" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/9b/b1661026ff24bc641b76b78c5222d614776b0c085bcfdac9bd15a1cb4b35/authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e", size = 164894, upload-time = "2025-12-12T08:01:41.464Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" },
]

[[package]]
name = "av"
version = "16.0.1"
@@ -2212,6 +2224,7 @@ source = { editable = "." }
dependencies = [
    { name = "alembic" },
    { name = "asyncpg" },
    { name = "authlib" },
    { name = "cryptography" },
    { name = "diart" },
    { name = "faster-whisper" },
@@ -2298,6 +2311,7 @@ requires-dist = [
    { name = "alembic", specifier = ">=1.13" },
    { name = "anthropic", marker = "extra == 'summarization'", specifier = ">=0.75.0" },
    { name = "asyncpg", specifier = ">=0.29" },
    { name = "authlib", specifier = ">=1.6.6" },
    { name = "basedpyright", marker = "extra == 'dev'", specifier = ">=1.18" },
    { name = "cryptography", specifier = ">=42.0" },
    { name = "diart", specifier = ">=0.9.2" },
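The `hash = "sha256:..."` fields in these lockfile entries are what lets the installer reject a tampered or corrupted download. The check itself is a one-line digest comparison; a minimal sketch (the artifact bytes are illustrative, not a real wheel):

```python
import hashlib

# Illustrative artifact bytes; a real lockfile pins the sha256 of the
# actual sdist/wheel file contents fetched from the registry.
artifact = b"example-wheel-bytes"
pinned = hashlib.sha256(artifact).hexdigest()


def matches_pinned_hash(data: bytes, want_hex: str) -> bool:
    """Compare a downloaded artifact against a lockfile's sha256 pin."""
    return hashlib.sha256(data).hexdigest() == want_hex
```

A mismatch means the bytes on disk are not the bytes that were locked, so installation should abort rather than proceed.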
Block a user