refactor: add API initialization memoization and replace initializeAPI with getAPI across hooks

- Added an initializePromise singleton so concurrent callers share a single API initialization attempt
- Wrapped the initializeAPI body in an IIFE that caches the promise and clears it on error so a later call can retry (see the sketch below)
- Replaced all initializeAPI() calls with synchronous getAPI() in hooks (useAudioDevices, useDiarization, usePostProcessing, useRecordingAppPolicy)
- Updated the corresponding test mocks from async initializeAPI resolvers to synchronous getAPI return values
2026-01-16 01:34:33 +00:00
parent 2b908b5d64
commit 1d2bc25024
33 changed files with 1092 additions and 346 deletions
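
For reference, a condensed sketch of the memoization pattern introduced in the api module below. The adapter selection, event bridge, and connection logic are elided behind setUpAdapterAndConnect, a hypothetical stand-in that is not part of this commit:

declare function setUpAdapterAndConnect(): Promise<NoteFlowAPI>; // hypothetical stand-in for the elided body

let initializePromise: Promise<NoteFlowAPI> | null = null;

export async function initializeAPI(): Promise<NoteFlowAPI> {
  // Concurrent callers reuse the in-flight (or completed) initialization.
  if (initializePromise) {
    return initializePromise;
  }
  initializePromise = (async () => {
    return await setUpAdapterAndConnect();
  })();
  try {
    return await initializePromise;
  } catch (error) {
    // Clear the cached promise so a later call can retry after a failed init.
    initializePromise = null;
    throw error;
  }
}

// Hooks then drop the await and read the instance synchronously, e.g.:
//   const api = getAPI();
//   const status = await api.getDiarizationJobStatus(jobId);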

View File

@@ -31,6 +31,7 @@ import { startReconnection } from './reconnection';
import { initializeTauriAPI } from './tauri-adapter';
const log = debug('NoteFlowAPI');
let initializePromise: Promise<NoteFlowAPI> | null = null;
// ============================================================================
// API Initialization
@@ -46,65 +47,78 @@ const log = debug('NoteFlowAPI');
* Sprint GAP-009: Event bridge starts before connection to capture early events.
*/
export async function initializeAPI(): Promise<NoteFlowAPI> {
// Always try Tauri first - initializeTauriAPI tests the API and throws if unavailable
if (initializePromise) {
return initializePromise;
}
initializePromise = (async () => {
// Always try Tauri first - initializeTauriAPI tests the API and throws if unavailable
try {
const tauriAPI = await initializeTauriAPI();
setAPIInstance(tauriAPI);
try {
const { invoke } = await import('@tauri-apps/api/core');
window.__NOTEFLOW_TEST_INVOKE__ = invoke;
} catch (error) {
log('Test invoke binding unavailable (expected in non-Tauri contexts)', {
error: error instanceof Error ? error.message : String(error),
});
}
// Sprint GAP-009: Start event bridge before connection to capture early events
// (e.g., connection errors, early warnings). Non-critical if it fails.
await startTauriEventBridge().catch((error) => {
addClientLog({
level: 'warning',
source: 'api',
message: 'Event bridge initialization failed - continuing without early events',
details: error instanceof Error ? error.message : String(error),
metadata: { context: 'api_event_bridge_init' },
});
});
// Attempt to connect to the gRPC server
try {
const preferredUrl = preferences.getServerUrl();
await tauriAPI.connect(preferredUrl || undefined);
setConnectionMode('connected');
await preferences.initialize();
startReconnection();
// Sprint GAP-007: Log successful connection
log('Adapter: Tauri | Mode: connected', { server: preferredUrl || 'default' });
return tauriAPI;
} catch (connectError) {
// Connection failed - fall back to cached mode but keep Tauri adapter
const message = extractErrorMessage(connectError, 'Connection failed');
setConnectionMode('cached', message);
await preferences.initialize();
startReconnection();
// Sprint GAP-007: Log cached mode fallback
addClientLog({
level: 'warning',
source: 'api',
message: 'Adapter fallback to cached mode',
details: message,
metadata: { context: 'api_cached_mode_fallback' },
});
return tauriAPI; // Keep Tauri adapter for reconnection attempts
}
} catch (_tauriError) {
// Tauri unavailable - use mock API (we're in a browser)
setConnectionMode('mock');
setAPIInstance(mockAPI);
// Sprint GAP-007: Log mock mode
log('Adapter: Mock | Mode: mock | Environment: Browser');
return mockAPI;
}
})();
try {
const tauriAPI = await initializeTauriAPI();
setAPIInstance(tauriAPI);
try {
const { invoke } = await import('@tauri-apps/api/core');
window.__NOTEFLOW_TEST_INVOKE__ = invoke;
} catch (error) {
log('Test invoke binding unavailable (expected in non-Tauri contexts)', {
error: error instanceof Error ? error.message : String(error),
});
}
// Sprint GAP-009: Start event bridge before connection to capture early events
// (e.g., connection errors, early warnings). Non-critical if it fails.
await startTauriEventBridge().catch((error) => {
addClientLog({
level: 'warning',
source: 'api',
message: 'Event bridge initialization failed - continuing without early events',
details: error instanceof Error ? error.message : String(error),
metadata: { context: 'api_event_bridge_init' },
});
});
// Attempt to connect to the gRPC server
try {
const preferredUrl = preferences.getServerUrl();
await tauriAPI.connect(preferredUrl || undefined);
setConnectionMode('connected');
await preferences.initialize();
startReconnection();
// Sprint GAP-007: Log successful connection
log('Adapter: Tauri | Mode: connected', { server: preferredUrl || 'default' });
return tauriAPI;
} catch (connectError) {
// Connection failed - fall back to cached mode but keep Tauri adapter
const message = extractErrorMessage(connectError, 'Connection failed');
setConnectionMode('cached', message);
await preferences.initialize();
startReconnection();
// Sprint GAP-007: Log cached mode fallback
addClientLog({
level: 'warning',
source: 'api',
message: 'Adapter fallback to cached mode',
details: message,
metadata: { context: 'api_cached_mode_fallback' },
});
return tauriAPI; // Keep Tauri adapter for reconnection attempts
}
} catch (_tauriError) {
// Tauri unavailable - use mock API (we're in a browser)
setConnectionMode('mock');
setAPIInstance(mockAPI);
// Sprint GAP-007: Log mock mode
log('Adapter: Mock | Mode: mock | Environment: Browser');
return mockAPI;
return await initializePromise;
} catch (error) {
initializePromise = null;
throw error;
}
}

View File

@@ -136,6 +136,10 @@ export function shouldAutoStartProcessing(
// If processing is already complete or in progress, don't restart
const { summary, entities, diarization } = processingStatus;
const anyFailed =
summary.status === 'failed' ||
entities.status === 'failed' ||
diarization.status === 'failed';
const allTerminal =
['completed', 'failed', 'skipped'].includes(summary.status) &&
['completed', 'failed', 'skipped'].includes(entities.status) &&
@@ -147,7 +151,7 @@ export function shouldAutoStartProcessing(
diarization.status === 'running';
// If all done or any running, don't auto-start
if (allTerminal || anyRunning) {
if (allTerminal || anyRunning || anyFailed) {
return false;
}

View File

@@ -1,6 +1,6 @@
import { act, renderHook } from '@testing-library/react';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { initializeAPI } from '@/api';
import { getAPI } from '@/api';
import { TauriCommands } from '@/api/constants';
import { isTauriEnvironment } from '@/api/tauri-adapter';
import { toast } from '@/hooks/use-toast';
@@ -9,7 +9,7 @@ import { useTauriEvent } from '@/lib/tauri-events';
import { useAudioDevices } from './use-audio-devices';
vi.mock('@/api', () => ({
initializeAPI: vi.fn(),
getAPI: vi.fn(),
}));
vi.mock('@/api/tauri-adapter', async () => {
@@ -172,7 +172,7 @@ describe('useAudioDevices', () => {
it('loads devices in tauri mode and syncs selection', async () => {
vi.mocked(isTauriEnvironment).mockReturnValue(true);
vi.mocked(initializeAPI).mockResolvedValue({
vi.mocked(getAPI).mockReturnValue({
getPreferences: vi.fn().mockResolvedValue({
audio_devices: { input_device_id: '', output_device_id: '' },
}),
@@ -188,7 +188,7 @@ describe('useAudioDevices', () => {
system_gain: 1,
}),
selectAudioDevice: vi.fn(),
} as unknown as Awaited<ReturnType<typeof initializeAPI>>);
} as unknown as ReturnType<typeof getAPI>);
const { result } = renderHook(() => useAudioDevices({ showToasts: false }));
@@ -218,7 +218,7 @@ describe('useAudioDevices', () => {
audio_devices: { input_device_id: 'input:0:Mic', output_device_id: 'output:1:Speakers' },
});
const selectAudioDevice = vi.fn();
vi.mocked(initializeAPI).mockResolvedValue({
vi.mocked(getAPI).mockReturnValue({
getPreferences: vi.fn().mockResolvedValue({
audio_devices: { input_device_id: 'input:0:Mic', output_device_id: 'output:1:Speakers' },
}),
@@ -234,7 +234,7 @@ describe('useAudioDevices', () => {
system_gain: 1,
}),
selectAudioDevice,
} as unknown as Awaited<ReturnType<typeof initializeAPI>>);
} as unknown as ReturnType<typeof getAPI>);
const { result } = renderHook(() => useAudioDevices({ showToasts: false }));
@@ -387,7 +387,7 @@ describe('useAudioDevices', () => {
});
const selectAudioDevice = vi.fn();
vi.mocked(initializeAPI).mockResolvedValue({
vi.mocked(getAPI).mockReturnValue({
getPreferences: vi.fn().mockImplementation(() =>
Promise.resolve({ audio_devices: prefsState.audio_devices })
),
@@ -403,7 +403,7 @@ describe('useAudioDevices', () => {
system_gain: 1,
}),
selectAudioDevice,
} as unknown as Awaited<ReturnType<typeof initializeAPI>>);
} as unknown as ReturnType<typeof getAPI>);
const first = renderHook(() => useAudioDevices({ showToasts: false }));
@@ -644,10 +644,10 @@ describe('useAudioDevices', () => {
it('sets selected devices and syncs tauri selection', async () => {
vi.mocked(isTauriEnvironment).mockReturnValue(true);
const selectAudioDevice = vi.fn();
vi.mocked(initializeAPI).mockResolvedValue({
vi.mocked(getAPI).mockReturnValue({
listAudioDevices: vi.fn().mockResolvedValue([]),
selectAudioDevice,
} as unknown as Awaited<ReturnType<typeof initializeAPI>>);
} as unknown as ReturnType<typeof getAPI>);
const { result } = renderHook(() => useAudioDevices({ showToasts: false }));

View File

@@ -1,5 +1,5 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { initializeAPI } from '@/api';
import { getAPI } from '@/api';
import { clientLog } from '@/lib/client-log-events';
import { addClientLog } from '@/lib/client-logs';
import { isTauriEnvironment } from '@/api/tauri-adapter';
@@ -133,7 +133,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
try {
if (isTauriEnvironment()) {
const api = await initializeAPI();
const api = getAPI();
let storedPrefs = preferences.get();
try {
const tauriPrefs = await api.getPreferences();
@@ -308,7 +308,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
});
if (isTauriEnvironment()) {
try {
const api = await initializeAPI();
const api = getAPI();
await api.selectAudioDevice(deviceId, true);
} catch (err) {
const errorMessage = err instanceof Error ? err.message : String(err);
@@ -340,7 +340,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
});
if (isTauriEnvironment()) {
try {
const api = await initializeAPI();
const api = getAPI();
await api.selectAudioDevice(deviceId, false);
} catch (err) {
const errorMessage = err instanceof Error ? err.message : String(err);
@@ -364,7 +364,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
preferences.setSystemAudioDevice(deviceId);
if (isTauriEnvironment()) {
try {
const api = await initializeAPI();
const api = getAPI();
await api.setSystemAudioDevice(deviceId);
} catch (err) {
const errorMessage = err instanceof Error ? err.message : String(err);
@@ -388,7 +388,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
preferences.setDualCaptureEnabled(enabled);
if (isTauriEnvironment()) {
try {
const api = await initializeAPI();
const api = getAPI();
await api.setDualCaptureEnabled(enabled);
if (enabled && !selectedSystemDevice && loopbackDevices.length > 0) {
const firstDevice = loopbackDevices[0].deviceId;
@@ -418,7 +418,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
preferences.setAudioMixLevels(newMicGain, newSystemGain);
if (isTauriEnvironment()) {
try {
const api = await initializeAPI();
const api = getAPI();
await api.setAudioMixLevels(newMicGain, newSystemGain);
} catch (err) {
if (showToasts) {

View File

@@ -19,7 +19,7 @@ import { useDiarization } from './use-diarization';
// Mock the API module
vi.mock('@/api', () => ({
initializeAPI: vi.fn(),
getAPI: vi.fn(),
}));
// Mock the toast hook
@@ -47,9 +47,7 @@ describe('useDiarization', () => {
beforeEach(() => {
vi.useFakeTimers();
vi.mocked(api.initializeAPI).mockResolvedValue(
mockAPI as unknown as Awaited<ReturnType<typeof api.initializeAPI>>
);
vi.mocked(api.getAPI).mockReturnValue(mockAPI as unknown as ReturnType<typeof api.getAPI>);
vi.clearAllMocks();
});

View File

@@ -6,7 +6,7 @@
*/
import { useCallback, useEffect, useRef, useState } from 'react';
import { initializeAPI } from '@/api';
import { getAPI } from '@/api';
import type { DiarizationJobStatus, JobStatus } from '@/api/types';
import { toast } from '@/hooks/use-toast';
import { PollingConfig } from '@/lib/config';
@@ -151,7 +151,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
}
try {
const api = await initializeAPI();
const api = getAPI();
const status = await api.getDiarizationJobStatus(jobId);
if (!isMountedRef.current) {
@@ -289,7 +289,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
retryCountRef.current = 0;
try {
const api = await initializeAPI();
const api = getAPI();
const status = await api.refineSpeakers(meetingId, numSpeakers);
if (!isMountedRef.current) {
@@ -345,7 +345,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
stopPolling();
try {
const api = await initializeAPI();
const api = getAPI();
const result = await api.cancelDiarization(jobId);
if (!isMountedRef.current) {
@@ -396,7 +396,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
}
try {
const api = await initializeAPI();
const api = getAPI();
const activeJobs = await api.getActiveDiarizationJobs();
if (!isMountedRef.current) {

View File

@@ -23,20 +23,20 @@ import { toast } from '@/hooks/use-toast';
import { usePostProcessing } from './use-post-processing';
// Use vi.hoisted to define mocks that are available in vi.mock
const { mockAPI, mockInitializeAPI } = vi.hoisted(() => {
const { mockAPI, mockGetAPI } = vi.hoisted(() => {
const mockAPI = {
generateSummary: vi.fn(),
extractEntities: vi.fn(),
refineSpeakers: vi.fn(),
getDiarizationJobStatus: vi.fn(),
};
const mockInitializeAPI = vi.fn(() => Promise.resolve(mockAPI));
return { mockAPI, mockInitializeAPI };
const mockGetAPI = vi.fn(() => mockAPI);
return { mockAPI, mockGetAPI };
});
// Mock the API module
vi.mock('@/api', () => ({
initializeAPI: mockInitializeAPI,
getAPI: mockGetAPI,
}));
// Mock the toast hook

View File

@@ -11,9 +11,10 @@
*/
import { useCallback, useEffect, useRef, useState } from 'react';
import { initializeAPI } from '@/api';
import { getAPI } from '@/api';
import type { DiarizationJobStatus, MeetingState, ProcessingStatus } from '@/api/types';
import { toast } from '@/hooks/use-toast';
import { setEntitiesFromExtraction } from '@/lib/entity-store';
import { errorMessageFrom, toastError } from '@/lib/error-reporting';
import { usePostProcessingEvents } from './post-processing/events';
import type {
@@ -40,7 +41,6 @@ export type {
UsePostProcessingReturn,
} from './post-processing/state';
/**
* Hook for orchestrating post-processing after recording stops
*/
@@ -116,7 +116,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
}
try {
const api = await initializeAPI();
const api = getAPI();
const status: DiarizationJobStatus = await api.getDiarizationJobStatus(jobId);
if (!isMountedRef.current) {
@@ -201,7 +201,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
});
try {
const api = await initializeAPI();
const api = getAPI();
await api.generateSummary(meetingId, false);
if (!isMountedRef.current) {
@@ -253,8 +253,9 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
});
try {
const api = await initializeAPI();
await api.extractEntities(meetingId, false);
const api = getAPI();
const response = await api.extractEntities(meetingId, false);
setEntitiesFromExtraction(response.entities);
if (!isMountedRef.current) {
return;
@@ -305,7 +306,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
});
try {
const api = await initializeAPI();
const api = getAPI();
const response = await api.refineSpeakers(meetingId, numSpeakers);
if (!isMountedRef.current) {

View File

@@ -1,5 +1,5 @@
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { initializeAPI } from '@/api';
import { getAPI } from '@/api';
import { isTauriEnvironment } from '@/api/tauri-adapter';
import type {
AppMatcher,
@@ -216,7 +216,7 @@ export function useRecordingAppPolicy(): RecordingAppPolicyState {
setIsLoadingApps(true);
try {
const api = await initializeAPI();
const api = getAPI();
const page = reset ? 0 : currentPageRef.current;
const response: ListInstalledAppsResponse = await api.listInstalledApps({
commonOnly,

View File

@@ -19,7 +19,7 @@ import { debug } from '@/lib/debug';
const log = debug('Preferences:Replace');
import { clonePreferences } from './core';
import { arePreferencesEqual, clonePreferences } from './core';
import { defaultPreferences } from './constants';
import { mergeIntegrationsWithDefaults, resetIntegrationsForServerSwitch } from './integrations';
import {
@@ -78,10 +78,6 @@ export const preferences = {
},
replace(prefs: UserPreferences): void {
log('Replacing preferences', {
input_device_id: prefs.audio_devices?.input_device_id,
output_device_id: prefs.audio_devices?.output_device_id,
});
const mergedAudioDevices = isRecord(prefs.audio_devices)
? { ...defaultPreferences.audio_devices, ...prefs.audio_devices }
: { ...defaultPreferences.audio_devices };
@@ -92,7 +88,16 @@ export const preferences = {
audio_devices: mergedAudioDevices,
integrations: mergedIntegrations,
};
savePreferences(applyLocalOverrides(merged), persistPreferencesToTauri);
const next = applyLocalOverrides(merged);
const current = loadPreferences();
if (arePreferencesEqual(current, next)) {
return;
}
log('Replacing preferences', {
input_device_id: next.audio_devices?.input_device_id,
output_device_id: next.audio_devices?.output_device_id,
});
savePreferences(next, persistPreferencesToTauri);
},
resetToDefaults(): void {
@@ -462,9 +467,7 @@ export const preferences = {
setServerAddressOverride(hasAddress, trimmedHost, trimmedPort);
});
},
// SYNC SCHEDULER
isSyncSchedulerPaused(): boolean {
return loadPreferences().sync_scheduler_paused;
},
@@ -474,9 +477,7 @@ export const preferences = {
prefs.sync_scheduler_paused = paused;
});
},
// SYNC HISTORY
getSyncHistory(): SyncHistoryEvent[] {
return loadPreferences().sync_history || [];
},

View File

@@ -145,3 +145,10 @@ export function normalizePreferences(prefs: UserPreferences): UserPreferences {
ai_config: normalizeAiConfig(prefs.ai_config),
};
}
/**
* Compare preferences payloads after normalization.
*/
export function arePreferencesEqual(a: UserPreferences, b: UserPreferences): boolean {
return JSON.stringify(normalizePreferences(a)) === JSON.stringify(normalizePreferences(b));
}

View File

@@ -7,7 +7,7 @@ import type { Integration, UserPreferences } from '@/api/types';
import { addClientLog } from '@/lib/client-logs';
import { debug } from '@/lib/debug';
import { isTauriRuntime } from './core';
import { arePreferencesEqual, isTauriRuntime } from './core';
import { loadPreferences } from './storage';
import { emitValidationEvent } from './validation-events';
@@ -241,8 +241,10 @@ export async function hydratePreferencesFromTauri(
});
// Only call replacePrefs if we actually need to update something
// This avoids unnecessary writes
replacePrefs(mergedPrefs);
// This avoids unnecessary writes and persistence churn.
if (!arePreferencesEqual(localPrefs, mergedPrefs)) {
replacePrefs(mergedPrefs);
}
hasHydratedFromTauri = true;
notifyHydrationComplete();

View File

@@ -2,7 +2,7 @@
* Hook for meeting detail data and state.
*/
import { useEffect, useMemo, useState } from 'react';
import { useEffect, useMemo, useRef, useState } from 'react';
import { getAPI } from '@/api';
import type { Meeting } from '@/api/types';
@@ -25,19 +25,14 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) {
const [selectedSegment, setSelectedSegment] = useState<number | null>(null);
const [speakerNameMap, setSpeakerNameMap] = useState<Map<string, string>>(new Map());
const { mode } = useConnectionState();
const autoStartedRef = useRef<Set<string>>(new Set());
const autoStartLogRef = useRef<Set<string>>(new Set());
// Entity extraction hook
const {
state: entityState,
extract: extractEntities,
isExtracting,
} = useEntityExtraction({
meetingId,
meetingTitle: meeting?.title,
meetingState: meeting?.state,
autoExtract: true,
});
const { entities } = entityState;
const failedStep = meeting?.processing_status
? (['summary', 'entities', 'diarization'] as const).find(
(step) => meeting.processing_status?.[step]?.status === 'failed'
) ?? null
: null;
// Post-processing orchestration hook
const {
@@ -54,6 +49,26 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) {
}
},
});
const hasLocalPostProcessing = processingState.meetingId === meetingId;
const shouldAutoExtract =
!!meeting &&
!!meetingId &&
meeting.state === 'completed' &&
!hasLocalPostProcessing &&
!shouldAutoStart(meeting.state, meeting.processing_status);
// Entity extraction hook
const {
state: entityState,
extract: extractEntities,
isExtracting,
} = useEntityExtraction({
meetingId,
meetingTitle: meeting?.title,
meetingState: meeting?.state,
autoExtract: shouldAutoExtract,
});
const { entities } = entityState;
// Summary metadata
const summaryMeta = useMemo(() => {
@@ -128,9 +143,35 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) {
// Auto-start post-processing
useEffect(() => {
if (meeting && meetingId && shouldAutoStart(meeting.state, meeting.processing_status)) {
void startProcessing(meetingId);
if (!meeting || !meetingId) {
return;
}
if (failedStep && !autoStartLogRef.current.has(meetingId)) {
autoStartLogRef.current.add(meetingId);
addClientLog({
level: 'info',
source: 'app',
message: 'Post-processing auto-start skipped due to failed step',
metadata: {
meeting_id: meetingId,
failed_step: failedStep,
},
});
}
if (!shouldAutoStart(meeting.state, meeting.processing_status)) {
return;
}
if (autoStartedRef.current.has(meetingId)) {
return;
}
autoStartedRef.current.add(meetingId);
addClientLog({
level: 'info',
source: 'app',
message: 'Post-processing auto-start triggered',
metadata: { meeting_id: meetingId },
});
void startProcessing(meetingId);
}, [meeting, meetingId, shouldAutoStart, startProcessing]);
const handleGenerateSummary = async () => {

View File

@@ -0,0 +1,41 @@
"""Helpers for normalizing extracted entities."""
from __future__ import annotations
from collections.abc import Sequence
from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.infrastructure.logging import get_logger
logger = get_logger(__name__)
def dedupe_entities(entities: Sequence[NamedEntity]) -> list[NamedEntity]:
"""Collapse duplicate entities by normalized text.
Merges segment IDs and keeps the highest-confidence category.
"""
deduped: dict[str, NamedEntity] = {}
for entity in entities:
key = entity.normalized_text or entity.text.lower().strip()
if not key:
continue
existing = deduped.get(key)
if existing is None:
entity.normalized_text = key
deduped[key] = entity
continue
existing.merge_segments(entity.segment_ids)
if entity.confidence > existing.confidence:
existing.confidence = entity.confidence
existing.category = entity.category
if len(entity.text) > len(existing.text):
existing.text = entity.text
if len(deduped) != len(entities):
logger.info(
"ner_entities_deduped",
total=len(entities),
deduped=len(deduped),
)
return list(deduped.values())

View File

@@ -7,9 +7,11 @@ Orchestrates NER extraction, caching, and persistence following hexagonal archit
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING
from noteflow.application.services.ner._dedupe import dedupe_entities
from noteflow.config.constants import ERROR_MSG_MEETING_PREFIX
from noteflow.config.settings import get_feature_flags
from noteflow.domain.entities.named_entity import NamedEntity
@@ -17,7 +19,7 @@ from noteflow.infrastructure.logging import get_logger, log_timing
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Callable
from uuid import UUID
from noteflow.domain.ports.ner import NerPort
@@ -211,6 +213,7 @@ class NerService:
entities = await self._extraction_helper.extract(segments)
for entity in entities:
entity.meeting_id = meeting_id
entities = dedupe_entities(entities)
await self._persist_entities(meeting_id, entities, force_refresh)
log_memory_snapshot(

View File

@@ -14,6 +14,13 @@ PROVIDER: Final[str] = "provider"
UNKNOWN: Final[str] = "unknown"
DEVICE: Final[Literal["device"]] = "device"
SEGMENT_IDS: Final[Literal["segment_ids"]] = "segment_ids"
MEETING_ID: Final[Literal["meeting_id"]] = "meeting_id"
TEXT: Final[Literal["text"]] = "text"
NORMALIZED_TEXT: Final[Literal["normalized_text"]] = "normalized_text"
CATEGORY: Final[Literal["category"]] = "category"
CONFIDENCE: Final[Literal["confidence"]] = "confidence"
IS_PINNED: Final[Literal["is_pinned"]] = "is_pinned"
UPDATED_AT: Final[Literal["updated_at"]] = "updated_at"
PROJECT_ID: Final[str] = "project_id"
PROJECT_IDS: Final[str] = "project_ids"
CALENDAR: Final[str] = "calendar"

View File

@@ -9,7 +9,7 @@ from uuid import UUID, uuid4
from noteflow.domain.constants.fields import DATE as ENTITY_DATE
from noteflow.domain.constants.fields import LOCATION as ENTITY_LOCATION
from noteflow.domain.constants.fields import SEGMENT_IDS
from noteflow.domain.constants.fields import CATEGORY, CONFIDENCE, SEGMENT_IDS, TEXT
if TYPE_CHECKING:
from noteflow.domain.value_objects import MeetingId
@@ -116,10 +116,10 @@ class NamedEntity:
ValueError: If text is empty or confidence is out of range.
"""
# Validate required text
text = kwargs["text"]
category = kwargs["category"]
text = kwargs[TEXT]
category = kwargs[CATEGORY]
segment_ids: list[int] = kwargs[SEGMENT_IDS]
confidence = kwargs["confidence"]
confidence = kwargs[CONFIDENCE]
meeting_id = kwargs.get("meeting_id")
stripped_text = text.strip()

View File

@@ -0,0 +1,166 @@
"""Helpers for updating meeting post-processing status."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, cast
from uuid import UUID
from noteflow.domain.entities.processing import (
ProcessingStatus,
ProcessingStepState,
ProcessingStepStatus,
)
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from noteflow.domain.entities import Meeting
logger = get_logger(__name__)
ProcessingStepName = Literal["summary", "entities", "diarization"]
def _parse_meeting_id(meeting_id: str) -> MeetingId | None:
try:
return MeetingId(UUID(meeting_id))
except ValueError:
return None
@dataclass(frozen=True, slots=True)
class ProcessingStatusUpdate:
step: ProcessingStepName
status: ProcessingStepStatus
error_message: str | None = None
def _build_step_state(
current: ProcessingStepState,
update: ProcessingStatusUpdate,
) -> ProcessingStepState:
if update.status == ProcessingStepStatus.RUNNING:
return ProcessingStepState.running()
if update.status == ProcessingStepStatus.COMPLETED:
return ProcessingStepState.completed(started_at=current.started_at)
if update.status == ProcessingStepStatus.FAILED:
return ProcessingStepState.failed(
update.error_message or "processing failed",
started_at=current.started_at,
)
if update.status == ProcessingStepStatus.SKIPPED:
return ProcessingStepState.skipped()
return ProcessingStepState.pending()
def _apply_step_update(
current: ProcessingStatus,
step: ProcessingStepName,
new_state: ProcessingStepState,
) -> ProcessingStatus:
if step == "summary":
return ProcessingStatus(
summary=new_state,
entities=current.entities,
diarization=current.diarization,
queued_at=current.queued_at,
)
if step == "entities":
return ProcessingStatus(
summary=current.summary,
entities=new_state,
diarization=current.diarization,
queued_at=current.queued_at,
)
return ProcessingStatus(
summary=current.summary,
entities=current.entities,
diarization=new_state,
queued_at=current.queued_at,
)
async def _commit_processing_update(
repo: UnitOfWork,
meeting: Meeting,
meeting_id: str,
update: ProcessingStatusUpdate,
) -> bool:
try:
await repo.meetings.update(meeting)
await repo.commit()
return True
except ValueError as exc:
await repo.rollback()
if "modified concurrently" in str(exc):
return False
logger.warning(
"processing_status_update_failed",
meeting_id=meeting_id,
step=update.step,
error=str(exc),
)
return True
async def _apply_processing_update(
repo_provider: Callable[[], object],
parsed_id: MeetingId,
meeting_id: str,
update: ProcessingStatusUpdate,
) -> bool:
async with cast(UnitOfWork, repo_provider()) as repo:
meeting = await repo.meetings.get(parsed_id)
if meeting is None:
logger.warning(
"processing_status_meeting_missing",
meeting_id=meeting_id,
step=update.step,
)
return True
current_status = meeting.processing_status or ProcessingStatus.create_pending()
current_step = getattr(current_status, update.step)
new_state = _build_step_state(current_step, update)
meeting.processing_status = _apply_step_update(current_status, update.step, new_state)
return await _commit_processing_update(repo, meeting, meeting_id, update)
async def update_processing_status(
repo_provider: Callable[[], object],
meeting_id: str,
update: ProcessingStatusUpdate,
*,
max_attempts: int = 3,
) -> None:
"""Update a meeting's processing status with retry on version conflicts."""
parsed_id = _parse_meeting_id(meeting_id)
if parsed_id is None:
logger.debug(
"processing_status_meeting_id_invalid",
meeting_id=meeting_id,
step=update.step,
)
return
for attempt in range(max_attempts):
completed = await _apply_processing_update(
repo_provider,
parsed_id,
meeting_id,
update,
)
if completed:
return
if attempt >= max_attempts - 1:
logger.warning(
"processing_status_update_failed",
meeting_id=meeting_id,
step=update.step,
error="retry_limit_exceeded",
)
return

View File

@@ -14,7 +14,7 @@ from opentelemetry.trace import Span
from ._types import DIARIZATION_TIMEOUT_SECONDS
if TYPE_CHECKING:
from ._jobs import DiarizationJobContext
from ._job_validation import DiarizationJobContext
def _set_diarization_span_attributes(span: Span, ctx: DiarizationJobContext) -> None:

View File

@@ -0,0 +1,143 @@
"""Diarization job validation and error response helpers."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
from uuid import UUID
import grpc
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingState
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from ...proto import noteflow_pb2
from .._types import GrpcStatusContext
from ..converters import parse_meeting_id
from ..errors._constants import INVALID_MEETING_ID_MESSAGE
if TYPE_CHECKING:
from noteflow.domain.entities import Meeting
from ..protocols import ServicerHost
logger = get_logger(__name__)
def job_status_name(status: int) -> str:
name_fn = cast(Callable[[int], str], noteflow_pb2.JobStatus.Name)
return name_fn(status)
@dataclass(frozen=True)
class GrpcErrorDetails:
"""gRPC error context and status for response helpers."""
context: GrpcStatusContext
grpc_code: grpc.StatusCode
@dataclass(frozen=True, slots=True)
class DiarizationJobContext:
"""Context for executing a diarization job.
Groups job-related parameters to reduce function signature complexity.
"""
host: ServicerHost
job_id: str
job: DiarizationJob
meeting_id: str
num_speakers: int | None
def create_diarization_error_response(
error_message: str,
status: noteflow_pb2.JobStatus | str = noteflow_pb2.JOB_STATUS_FAILED,
*,
error: GrpcErrorDetails | None = None,
job_id: str = "",
) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
"""Create error response for RefineSpeakerDiarization.
Consolidates the 6+ duplicated response construction patterns.
"""
if error is not None:
error.context.set_code(error.grpc_code)
error.context.set_details(error_message)
return noteflow_pb2.RefineSpeakerDiarizationResponse(
segments_updated=0,
speaker_ids=[],
error_message=error_message,
job_id=job_id,
status=status,
)
def validate_diarization_preconditions(
servicer: ServicerHost,
request: noteflow_pb2.RefineSpeakerDiarizationRequest,
context: GrpcStatusContext,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
"""Validate preconditions before starting diarization job."""
if not servicer.diarization_refinement_enabled:
return create_diarization_error_response("Diarization refinement disabled on server")
if servicer.diarization_engine is None:
return create_diarization_error_response(
"Diarization not enabled on server",
error=GrpcErrorDetails(
context=context,
grpc_code=grpc.StatusCode.UNAVAILABLE,
),
)
try:
UUID(request.meeting_id)
except ValueError:
return create_diarization_error_response(INVALID_MEETING_ID_MESSAGE)
return None
async def load_meeting_for_diarization(
repo: UnitOfWork,
meeting_id: str,
) -> tuple[Meeting | None, noteflow_pb2.RefineSpeakerDiarizationResponse | None]:
"""Fetch meeting and validate state for diarization refinement."""
meeting = await repo.meetings.get(parse_meeting_id(meeting_id))
if meeting is None:
return None, create_diarization_error_response("Meeting not found")
valid_states = (MeetingState.STOPPED, MeetingState.COMPLETED, MeetingState.ERROR)
if meeting.state not in valid_states:
return None, create_diarization_error_response(
f"Meeting must be stopped before refinement (state: {meeting.state.name.lower()})"
)
return meeting, None
async def check_active_diarization_job(
repo: UnitOfWork,
meeting_id: str,
context: GrpcStatusContext,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
"""Return error response if a diarization job is already active."""
if not repo.supports_diarization_jobs:
return None
active_job = await repo.diarization_jobs.get_active_for_meeting(meeting_id)
if active_job is None:
return None
return create_diarization_error_response(
f"Diarization already in progress (job: {active_job.job_id})",
status=job_status_name(active_job.status),
error=GrpcErrorDetails(
context=context,
grpc_code=grpc.StatusCode.ALREADY_EXISTS,
),
job_id=active_job.job_id,
)

View File

@@ -3,33 +3,39 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
from uuid import UUID, uuid4
from typing import TYPE_CHECKING
from uuid import uuid4
import grpc
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.entities.processing import ProcessingStepStatus
from noteflow.domain.utils import utc_now
from noteflow.domain.value_objects import MeetingState
from noteflow.infrastructure.logging import get_logger, log_state_transition
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from ...proto import noteflow_pb2
from .._types import GrpcStatusContext
from ..converters import parse_meeting_id
from ..errors._constants import INVALID_MEETING_ID_MESSAGE
from .._processing_status import ProcessingStatusUpdate, update_processing_status
from ._execution import execute_diarization
from ._job_validation import (
DiarizationJobContext,
GrpcErrorDetails,
check_active_diarization_job,
create_diarization_error_response,
job_status_name,
load_meeting_for_diarization,
validate_diarization_preconditions,
)
from ._status import JobStatusMixin
if TYPE_CHECKING:
from noteflow.domain.entities import Meeting
from ..protocols import ServicerHost
logger = get_logger(__name__)
DIARIZATION_DB_REQUIRED = "Diarization requires database support"
def _diarization_task_done_callback(
task: asyncio.Task[None],
@@ -50,90 +56,6 @@ def _diarization_task_done_callback(
)
def _job_status_name(status: int) -> str:
name_fn = cast(Callable[[int], str], noteflow_pb2.JobStatus.Name)
return name_fn(status)
@dataclass(frozen=True)
class GrpcErrorDetails:
"""gRPC error context and status for response helpers."""
context: GrpcStatusContext
grpc_code: grpc.StatusCode
@dataclass(frozen=True, slots=True)
class DiarizationJobContext:
"""Context for executing a diarization job.
Groups job-related parameters to reduce function signature complexity.
"""
host: ServicerHost
job_id: str
job: DiarizationJob
meeting_id: str
num_speakers: int | None
def create_diarization_error_response(
error_message: str,
status: noteflow_pb2.JobStatus | str = noteflow_pb2.JOB_STATUS_FAILED,
*,
error: GrpcErrorDetails | None = None,
job_id: str = "",
) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
"""Create error response for RefineSpeakerDiarization.
Consolidates the 6+ duplicated response construction patterns.
Args:
error_message: Error message describing the failure.
status: Job status code (default: JOB_STATUS_FAILED).
error: Optional gRPC error details for status code.
job_id: Optional job ID to include in response.
Returns:
Populated RefineSpeakerDiarizationResponse with error state.
"""
if error is not None:
error.context.set_code(error.grpc_code)
error.context.set_details(error_message)
return noteflow_pb2.RefineSpeakerDiarizationResponse(
segments_updated=0,
speaker_ids=[],
error_message=error_message,
job_id=job_id,
status=status,
)
def _validate_diarization_preconditions(
servicer: ServicerHost,
request: noteflow_pb2.RefineSpeakerDiarizationRequest,
context: GrpcStatusContext,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
"""Validate preconditions before starting diarization job."""
if not servicer.diarization_refinement_enabled:
return create_diarization_error_response("Diarization refinement disabled on server")
if servicer.diarization_engine is None:
return create_diarization_error_response(
"Diarization not enabled on server",
error=GrpcErrorDetails(
context=context,
grpc_code=grpc.StatusCode.UNAVAILABLE,
),
)
try:
UUID(request.meeting_id)
except ValueError:
return create_diarization_error_response(INVALID_MEETING_ID_MESSAGE)
return None
async def _create_and_persist_job(
job_id: str,
meeting_id: str,
@@ -165,47 +87,6 @@ async def _create_and_persist_job(
return True
async def _load_meeting_for_diarization(
repo: UnitOfWork,
meeting_id: str,
) -> tuple[Meeting | None, noteflow_pb2.RefineSpeakerDiarizationResponse | None]:
"""Fetch meeting and validate state for diarization refinement."""
meeting = await repo.meetings.get(parse_meeting_id(meeting_id))
if meeting is None:
return None, create_diarization_error_response("Meeting not found")
valid_states = (MeetingState.STOPPED, MeetingState.COMPLETED, MeetingState.ERROR)
if meeting.state not in valid_states:
return None, create_diarization_error_response(
f"Meeting must be stopped before refinement (state: {meeting.state.name.lower()})"
)
return meeting, None
async def _check_active_diarization_job(
repo: UnitOfWork,
meeting_id: str,
context: GrpcStatusContext,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
"""Return error response if a diarization job is already active."""
if not repo.supports_diarization_jobs:
return None
active_job = await repo.diarization_jobs.get_active_for_meeting(meeting_id)
if active_job is None:
return None
return create_diarization_error_response(
f"Diarization already in progress (job: {active_job.job_id})",
status=_job_status_name(active_job.status),
error=GrpcErrorDetails(
context=context,
grpc_code=grpc.StatusCode.ALREADY_EXISTS,
),
job_id=active_job.job_id,
)
async def _init_job_for_running(
host: ServicerHost,
job_id: str,
@@ -233,8 +114,8 @@ async def _init_job_for_running(
started_at=utc_now(),
)
await repo.commit()
transition_from = _job_status_name(old_status)
transition_to = _job_status_name(int(noteflow_pb2.JOB_STATUS_RUNNING))
transition_from = job_status_name(old_status)
transition_to = job_status_name(int(noteflow_pb2.JOB_STATUS_RUNNING))
log_state_transition(
"diarization_job",
job_id,
@@ -271,6 +152,40 @@ def _schedule_diarization_task(
num_speakers=num_speakers,
)
async def _prepare_diarization_job(
host: ServicerHost,
request: noteflow_pb2.RefineSpeakerDiarizationRequest,
context: GrpcStatusContext,
) -> tuple[str, float | None] | noteflow_pb2.RefineSpeakerDiarizationResponse:
async with host.create_repository_provider() as repo:
meeting, error = await load_meeting_for_diarization(repo, request.meeting_id)
if error is not None:
return error
active_error = await check_active_diarization_job(repo, request.meeting_id, context)
if active_error is not None:
return active_error
job_id = str(uuid4())
persisted = await _create_and_persist_job(
job_id,
request.meeting_id,
meeting.duration_seconds if meeting else None,
repo,
)
if not persisted:
return create_diarization_error_response(
DIARIZATION_DB_REQUIRED,
error=GrpcErrorDetails(
context=context,
grpc_code=grpc.StatusCode.FAILED_PRECONDITION,
),
)
return job_id, meeting.duration_seconds if meeting else None
class JobsMixin(JobStatusMixin):
"""Mixin providing diarization job management."""
@@ -280,35 +195,42 @@ class JobsMixin(JobStatusMixin):
context: GrpcStatusContext,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
"""Start a new diarization refinement job."""
if error := _validate_diarization_preconditions(self, request, context):
if error := validate_diarization_preconditions(self, request, context):
await update_processing_status(
self.create_repository_provider,
request.meeting_id,
ProcessingStatusUpdate(
step="diarization",
status=ProcessingStepStatus.FAILED,
error_message=error.error_message,
),
)
return error
async with self.create_repository_provider() as repo:
meeting, error = await _load_meeting_for_diarization(repo, request.meeting_id)
if error is not None:
return error
active_error = await _check_active_diarization_job(repo, request.meeting_id, context)
if active_error is not None:
return active_error
job_id = str(uuid4())
persisted = await _create_and_persist_job(
job_id,
request.meeting_id,
meeting.duration_seconds if meeting else None,
repo,
)
if not persisted:
return create_diarization_error_response(
"Diarization requires database support",
error=GrpcErrorDetails(
context=context,
grpc_code=grpc.StatusCode.FAILED_PRECONDITION,
job_result = await _prepare_diarization_job(self, request, context)
if isinstance(job_result, noteflow_pb2.RefineSpeakerDiarizationResponse):
if job_result.error_message == DIARIZATION_DB_REQUIRED:
await update_processing_status(
self.create_repository_provider,
request.meeting_id,
ProcessingStatusUpdate(
step="diarization",
status=ProcessingStepStatus.FAILED,
error_message=DIARIZATION_DB_REQUIRED,
),
)
return job_result
job_id, _duration = job_result
num_speakers = request.num_speakers or None
await update_processing_status(
self.create_repository_provider,
request.meeting_id,
ProcessingStatusUpdate(
step="diarization",
status=ProcessingStepStatus.RUNNING,
),
)
_schedule_diarization_task(self, job_id, num_speakers, request.meeting_id)
return noteflow_pb2.RefineSpeakerDiarizationResponse(

View File

@@ -5,11 +5,13 @@ from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, cast
from noteflow.domain.entities.processing import ProcessingStepStatus
from noteflow.infrastructure.logging import get_logger, log_state_transition
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from ...proto import noteflow_pb2
from ..errors import ERR_CANCELLED_BY_USER
from .._processing_status import ProcessingStatusUpdate, update_processing_status
from ._types import DIARIZATION_TIMEOUT_SECONDS
if TYPE_CHECKING:
@@ -47,6 +49,15 @@ class JobStatusMixin:
speaker_ids=speaker_ids,
)
await repo.commit()
if job is not None:
await update_processing_status(
self.create_repository_provider,
job.meeting_id,
ProcessingStatusUpdate(
step="diarization",
status=ProcessingStepStatus.COMPLETED,
),
)
log_state_transition(
"diarization_job",
job_id,
@@ -81,6 +92,16 @@ class JobStatusMixin:
error_message=error_msg,
)
await repo.commit()
if meeting_id is not None:
await update_processing_status(
self.create_repository_provider,
meeting_id,
ProcessingStatusUpdate(
step="diarization",
status=ProcessingStepStatus.FAILED,
error_message=error_msg,
),
)
log_state_transition(
"diarization_job",
job_id,
@@ -109,6 +130,16 @@ class JobStatusMixin:
error_message=ERR_CANCELLED_BY_USER,
)
await repo.commit()
if meeting_id is not None:
await update_processing_status(
self.create_repository_provider,
meeting_id,
ProcessingStatusUpdate(
step="diarization",
status=ProcessingStepStatus.SKIPPED,
error_message=ERR_CANCELLED_BY_USER,
),
)
log_state_transition(
"diarization_job",
job_id,
@@ -138,6 +169,16 @@ class JobStatusMixin:
error_message=str(exc),
)
await repo.commit()
if meeting_id is not None:
await update_processing_status(
self.create_repository_provider,
meeting_id,
ProcessingStatusUpdate(
step="diarization",
status=ProcessingStepStatus.FAILED,
error_message=str(exc),
),
)
log_state_transition(
"diarization_job",
job_id,

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
from typing import TYPE_CHECKING, cast
from noteflow.domain.entities.processing import ProcessingStepStatus
from noteflow.infrastructure.logging import get_logger
from ..proto import noteflow_pb2
from ._types import GrpcContext
from ._model_status import log_model_status
from ._processing_status import ProcessingStatusUpdate, update_processing_status
from .converters import entity_to_proto, parse_meeting_id_or_abort
from .errors import (
ENTITY_ENTITY,
@@ -25,7 +27,8 @@ from .protocols import EntitiesRepositoryProvider
if TYPE_CHECKING:
from collections.abc import Callable
from noteflow.application.services.ner import NerService
from noteflow.application.services.ner import ExtractionResult, NerService
from noteflow.domain.value_objects import MeetingId
logger = get_logger(__name__)
@@ -41,6 +44,77 @@ class EntitiesMixin:
ner_service: NerService | None
create_repository_provider: Callable[..., object]
async def _mark_entities_step(
self,
meeting_id_str: str,
status: ProcessingStepStatus,
error_message: str | None = None,
) -> None:
await update_processing_status(
self.create_repository_provider,
meeting_id_str,
ProcessingStatusUpdate(
step="entities",
status=status,
error_message=error_message,
),
)
async def _run_entity_extraction(
self,
meeting_id: MeetingId,
meeting_id_str: str,
force_refresh: bool,
context: GrpcContext,
) -> ExtractionResult:
ner_service = await require_ner_service(self.ner_service, context)
try:
return await ner_service.extract_entities(
meeting_id=meeting_id,
force_refresh=force_refresh,
)
except ValueError:
await abort_not_found(context, ENTITY_MEETING, meeting_id_str)
raise
async def _handle_entities_failure(
self,
meeting_id_str: str,
error: Exception,
context: GrpcContext | None = None,
) -> None:
await self._mark_entities_step(
meeting_id_str,
ProcessingStepStatus.FAILED,
str(error),
)
if isinstance(error, RuntimeError) and context is not None:
await abort_failed_precondition(context, str(error))
async def _extract_entities_with_status(
self,
meeting_id: MeetingId,
meeting_id_str: str,
force_refresh: bool,
context: GrpcContext,
) -> ExtractionResult:
await self._mark_entities_step(meeting_id_str, ProcessingStepStatus.RUNNING)
try:
result = await self._run_entity_extraction(
meeting_id,
meeting_id_str,
force_refresh,
context,
)
except ValueError:
raise
except Exception as exc:
await self._handle_entities_failure(meeting_id_str, exc, context)
raise
await self._mark_entities_step(meeting_id_str, ProcessingStepStatus.COMPLETED)
return result
async def ExtractEntities(
self,
request: noteflow_pb2.ExtractEntitiesRequest,
@@ -53,21 +127,12 @@ class EntitiesMixin:
"""
log_model_status(self, "ner_extract_start", meeting_id=request.meeting_id)
meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
ner_service = await require_ner_service(self.ner_service, context)
try:
result = await ner_service.extract_entities(
meeting_id=meeting_id,
force_refresh=request.force_refresh,
)
except ValueError:
# Meeting not found
await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
raise # Unreachable: abort raises
except RuntimeError as e:
# Feature disabled
await abort_failed_precondition(context, str(e))
raise # Unreachable: abort raises
result = await self._extract_entities_with_status(
meeting_id,
request.meeting_id,
request.force_refresh,
context,
)
# Convert to proto
proto_entities = [entity_to_proto(entity) for entity in result.entities]

View File

@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID
from noteflow.domain.entities.processing import ProcessingStatus, ProcessingStepState
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingId, MeetingState
from noteflow.infrastructure.logging import get_logger, log_state_transition
@@ -133,6 +134,7 @@ async def _complete_without_summary(
"Post-processing: no segments, completing without summary",
meeting_id=meeting_id,
)
_set_summary_processing_status(meeting, ProcessingStepState.skipped())
_complete_meeting(meeting, meeting_id)
await repo.meetings.update(meeting)
await repo.commit()
@@ -140,6 +142,7 @@ async def _complete_without_summary(
async def _save_summary_and_complete(context: _SummaryCompletionContext) -> Summary:
saved_summary = await context.repo.summaries.save(context.summary)
_set_summary_processing_status(context.meeting, ProcessingStepState.completed())
_complete_meeting(context.meeting, context.meeting_id)
await context.repo.meetings.update(context.meeting)
await context.repo.commit()
@@ -159,6 +162,16 @@ def _complete_meeting(meeting: Meeting, meeting_id: str) -> None:
log_state_transition("meeting", meeting_id, previous_state, meeting.state)
def _set_summary_processing_status(meeting: Meeting, step_state: ProcessingStepState) -> None:
current = meeting.processing_status or ProcessingStatus.create_pending()
meeting.processing_status = ProcessingStatus(
summary=step_state,
entities=current.entities,
diarization=current.diarization,
queued_at=current.queued_at,
)
async def _trigger_summary_webhook(
webhook_service: WebhookService | None,
meeting: Meeting,

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast
from uuid import UUID
from noteflow.config.constants import DEFAULT_MEETING_TITLE
from noteflow.domain.entities import Meeting
from noteflow.domain.entities import Meeting, ProcessingStatus
from noteflow.domain.entities.meeting import MeetingCreateParams
from noteflow.domain.identity import OperationContext
from noteflow.domain.value_objects import MeetingId, MeetingState
@@ -91,6 +91,8 @@ async def _stop_meeting_and_persist(context: _StopMeetingContext) -> Meeting:
previous_state,
context.context,
)
if context.meeting.processing_status is None:
context.meeting.processing_status = ProcessingStatus.create_pending()
await context.repo.meetings.update(context.meeting)
if context.repo.supports_diarization_jobs:

View File

@@ -2,11 +2,13 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
from sqlalchemy.exc import IntegrityError
from noteflow.domain.entities import Segment, Summary
from noteflow.domain.entities.processing import ProcessingStepStatus
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger
@@ -15,6 +17,7 @@ from ...proto import noteflow_pb2
from ...startup.startup import auto_enable_cloud_llm
from .._types import GrpcContext
from .._model_status import log_model_status
from .._processing_status import ProcessingStatusUpdate, update_processing_status
from ..converters import parse_meeting_id_or_abort, summary_to_proto
from ..errors import ENTITY_MEETING, abort_not_found
from ._summary_generation import generate_placeholder_summary, summarize_or_placeholder
@@ -35,6 +38,16 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
@dataclass(frozen=True, slots=True)
class _SummaryGenerationContext:
meeting_id: MeetingId
meeting_id_str: str
meeting: Meeting
segments: list[Segment]
style_prompt: str | None
force_regenerate: bool
class SummarizationGenerationMixin:
"""Generate summaries and handle summary webhooks."""
@@ -59,12 +72,13 @@ class SummarizationGenerationMixin:
await self._prepare_summary_request(request, context)
)
if existing and not request.force_regenerate:
await self._mark_summary_step(request.meeting_id, ProcessingStepStatus.COMPLETED)
return summary_to_proto(existing)
await self._ensure_cloud_provider()
async with cast(UnitOfWork, self.create_repository_provider()) as repo:
style_prompt = await resolve_template_prompt(
style_prompt = await self._resolve_style_prompt(
TemplateResolutionInputs(
request=request,
meeting=meeting,
@@ -76,23 +90,75 @@ class SummarizationGenerationMixin:
summarization_service=self.summarization_service,
)
)
summary = await summarize_or_placeholder(
self.summarization_service,
meeting_id,
segments,
style_prompt,
)
saved, trigger_webhook = await self._save_summary(
meeting=meeting,
summary=summary,
force_regenerate=request.force_regenerate,
saved, trigger_webhook = await self._generate_summary_with_status(
_SummaryGenerationContext(
meeting_id=meeting_id,
meeting_id_str=request.meeting_id,
meeting=meeting,
segments=segments,
style_prompt=style_prompt,
force_regenerate=request.force_regenerate,
)
)
if trigger_webhook:
await self._trigger_summary_webhook(meeting, saved)
return summary_to_proto(saved)
async def _mark_summary_step(
self,
meeting_id_str: str,
status: ProcessingStepStatus,
error_message: str | None = None,
) -> None:
await update_processing_status(
self.create_repository_provider,
meeting_id_str,
ProcessingStatusUpdate(
step="summary",
status=status,
error_message=error_message,
),
)
async def _resolve_style_prompt(self, inputs: TemplateResolutionInputs) -> str | None:
return await resolve_template_prompt(inputs)
async def _generate_summary_with_status(
self,
context: _SummaryGenerationContext,
) -> tuple[Summary, bool]:
meeting_id = context.meeting_id
meeting_id_str = context.meeting_id_str
meeting = context.meeting
segments = context.segments
style_prompt = context.style_prompt
force_regenerate = context.force_regenerate
await self._mark_summary_step(meeting_id_str, ProcessingStepStatus.RUNNING)
try:
summary = await summarize_or_placeholder(
self.summarization_service,
meeting_id,
segments,
style_prompt,
)
saved, trigger_webhook = await self._save_summary(
meeting=meeting,
summary=summary,
force_regenerate=force_regenerate,
)
await self._mark_summary_step(meeting_id_str, ProcessingStepStatus.COMPLETED)
return saved, trigger_webhook
except Exception as exc:
await self._mark_summary_step(
meeting_id_str,
ProcessingStepStatus.FAILED,
str(exc),
)
raise
async def _prepare_summary_request(
self,
request: noteflow_pb2.GenerateSummaryRequest,

View File

@@ -4,7 +4,15 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from noteflow.domain.constants.fields import SEGMENT_IDS
from noteflow.domain.constants.fields import (
CATEGORY,
CONFIDENCE,
IS_PINNED,
MEETING_ID,
NORMALIZED_TEXT,
SEGMENT_IDS,
TEXT,
)
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.value_objects import MeetingId
@@ -60,11 +68,11 @@ class NerConverter:
"""
return {
"id": entity.id,
"meeting_id": entity.meeting_id,
"text": entity.text,
"normalized_text": entity.normalized_text,
"category": entity.category.value,
MEETING_ID: entity.meeting_id,
TEXT: entity.text,
NORMALIZED_TEXT: entity.normalized_text,
CATEGORY: entity.category.value,
SEGMENT_IDS: entity.segment_ids,
"confidence": entity.confidence,
"is_pinned": entity.is_pinned,
CONFIDENCE: entity.confidence,
IS_PINNED: entity.is_pinned,
}

View File

@@ -2,7 +2,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from collections.abc import Mapping
from datetime import datetime
from typing import TYPE_CHECKING, cast
from noteflow.domain.constants.fields import END_TIME, START_TIME
from noteflow.domain.entities import (
@@ -16,13 +18,17 @@ from noteflow.domain.entities import (
from noteflow.domain.entities import (
WordTiming as DomainWordTiming,
)
from noteflow.domain.entities.processing import (
ProcessingStatus,
ProcessingStepState,
ProcessingStepStatus,
)
from noteflow.domain.value_objects import (
AnnotationId,
AnnotationType,
MeetingId,
MeetingState,
)
if TYPE_CHECKING:
from noteflow.infrastructure.persistence.models import (
ActionItemModel,
@@ -34,10 +40,108 @@ if TYPE_CHECKING:
WordTimingModel,
)
class OrmConverter:
"""Convert between ORM models and domain entities."""
@staticmethod
def _parse_datetime(value: object) -> datetime | None:
"""Parse ISO datetime strings stored in JSON payloads."""
if not isinstance(value, str):
return None
try:
return datetime.fromisoformat(value)
except ValueError:
return None
@staticmethod
def _parse_processing_step_status(value: object) -> ProcessingStepStatus:
"""Parse a processing step status from JSON payload values."""
if isinstance(value, ProcessingStepStatus):
return value
if isinstance(value, str):
try:
return ProcessingStepStatus(value)
except ValueError:
return ProcessingStepStatus.PENDING
return ProcessingStepStatus.PENDING
@staticmethod
def _processing_step_state_from_payload(
payload: Mapping[str, object] | None,
) -> ProcessingStepState:
"""Convert JSON payload to ProcessingStepState."""
if payload is None:
return ProcessingStepState.pending()
status = OrmConverter._parse_processing_step_status(payload.get("status"))
error_message_value = payload.get("error_message")
error_message = error_message_value if isinstance(error_message_value, str) else ""
return ProcessingStepState(
status=status,
error_message=error_message,
started_at=OrmConverter._parse_datetime(payload.get("started_at")),
completed_at=OrmConverter._parse_datetime(payload.get("completed_at")),
)
@staticmethod
def _coerce_step_payload(value: object) -> Mapping[str, object] | None:
        """Return the value as a mapping payload, or None if it is not a mapping."""
        if isinstance(value, Mapping):
return cast(Mapping[str, object], value)
return None
@staticmethod
def processing_status_from_payload(
payload: Mapping[str, object] | None,
) -> ProcessingStatus | None:
"""Convert JSON payload to ProcessingStatus."""
if payload is None:
return None
summary = OrmConverter._processing_step_state_from_payload(
OrmConverter._coerce_step_payload(payload.get("summary")),
)
entities = OrmConverter._processing_step_state_from_payload(
OrmConverter._coerce_step_payload(payload.get("entities")),
)
diarization = OrmConverter._processing_step_state_from_payload(
OrmConverter._coerce_step_payload(payload.get("diarization")),
)
return ProcessingStatus(
summary=summary,
entities=entities,
diarization=diarization,
queued_at=OrmConverter._parse_datetime(payload.get("queued_at")),
)
@staticmethod
def _processing_step_state_to_payload(
state: ProcessingStepState,
) -> dict[str, object]:
"""Convert ProcessingStepState to JSON payload."""
payload: dict[str, object] = {
"status": state.status.value,
"error_message": state.error_message,
}
if state.started_at is not None:
payload["started_at"] = state.started_at.isoformat()
if state.completed_at is not None:
payload["completed_at"] = state.completed_at.isoformat()
return payload
@staticmethod
def processing_status_to_payload(
status: ProcessingStatus | None,
) -> dict[str, object] | None:
"""Convert ProcessingStatus to JSON payload."""
if status is None:
return None
payload: dict[str, object] = {
"summary": OrmConverter._processing_step_state_to_payload(status.summary),
"entities": OrmConverter._processing_step_state_to_payload(status.entities),
"diarization": OrmConverter._processing_step_state_to_payload(status.diarization),
}
if status.queued_at is not None:
payload["queued_at"] = status.queued_at.isoformat()
return payload
# --- WordTiming ---
@staticmethod
@@ -113,6 +217,7 @@ class OrmConverter:
wrapped_dek=model.wrapped_dek,
asset_path=model.asset_path,
version=model.version,
processing_status=OrmConverter.processing_status_from_payload(model.processing_status),
)
# --- Segment ---

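A quick round trip of the new helpers, useful for checking that what lands in the JSONB column deserializes back to the same step states. The entity constructors are assumed to accept exactly the fields the converter reads above; OrmConverter is the class defined in this file.

from datetime import datetime, timezone

from noteflow.domain.entities.processing import (
    ProcessingStatus,
    ProcessingStepState,
    ProcessingStepStatus,
)

status = ProcessingStatus(
    summary=ProcessingStepState(
        status=ProcessingStepStatus.COMPLETED,
        error_message="",
        started_at=datetime(2026, 1, 15, tzinfo=timezone.utc),
        completed_at=datetime(2026, 1, 15, 0, 5, tzinfo=timezone.utc),
    ),
    entities=ProcessingStepState.pending(),
    diarization=ProcessingStepState.pending(),
    queued_at=datetime(2026, 1, 15, tzinfo=timezone.utc),
)

payload = OrmConverter.processing_status_to_payload(status)      # JSON-safe dict for the JSONB column
restored = OrmConverter.processing_status_from_payload(payload)  # back to domain entities
assert restored is not None
assert restored.summary.status is ProcessingStepStatus.COMPLETED
assert restored.entities.started_at is None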
View File

@@ -0,0 +1,37 @@
"""add_meeting_processing_status
Revision ID: t4u5v6w7x8y9
Revises: s3t4u5v6w7x8
Create Date: 2026-01-15 00:00:00.000000
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "t4u5v6w7x8y9"
down_revision: str | Sequence[str] | None = "s3t4u5v6w7x8"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add processing_status column to meetings."""
op.execute(
"""
ALTER TABLE noteflow.meetings
ADD COLUMN processing_status JSONB;
"""
)
def downgrade() -> None:
"""Remove processing_status column from meetings."""
op.execute(
"""
ALTER TABLE noteflow.meetings
DROP COLUMN processing_status;
"""
)

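The migration is written as raw SQL via op.execute; an equivalent using Alembic's schema operations (a sketch, same schema and column) would be:

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import JSONB


def upgrade() -> None:
    op.add_column(
        "meetings",
        sa.Column("processing_status", JSONB, nullable=True),
        schema="noteflow",
    )


def downgrade() -> None:
    op.drop_column("meetings", "processing_status", schema="noteflow")

Either form is applied with the usual alembic upgrade head and reverted with alembic downgrade -1.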
View File

@@ -17,7 +17,7 @@ from sqlalchemy import (
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .._base import DEFAULT_USER_ID, DEFAULT_WORKSPACE_ID, EMBEDDING_DIM, Base
@@ -115,6 +115,10 @@ class MeetingModel(UuidPrimaryKeyMixin, CreatedAtMixin, MetadataMixin, Base):
Text,
nullable=True,
)
processing_status: Mapped[dict[str, object] | None] = mapped_column(
JSONB,
nullable=True,
)
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
# Soft delete support
deleted_at: Mapped[datetime | None] = mapped_column(

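For reference, the JSONB document this column holds is the payload produced by OrmConverter.processing_status_to_payload, shaped roughly like the dict below (the status strings assume lowercase enum values, which this diff does not confirm):

example_processing_status = {
    "summary": {
        "status": "completed",
        "error_message": "",
        "started_at": "2026-01-15T00:00:00+00:00",
        "completed_at": "2026-01-15T00:05:00+00:00",
    },
    "entities": {"status": "pending", "error_message": ""},
    "diarization": {"status": "running", "error_message": ""},
    "queued_at": "2026-01-15T00:00:00+00:00",
}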
View File

@@ -9,6 +9,8 @@ from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text, Unique
from sqlalchemy.dialects.postgresql import ARRAY, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from noteflow.domain.constants.fields import NORMALIZED_TEXT
from .._base import Base
from .._mixins import CreatedAtMixin, UpdatedAtMixin, UuidPrimaryKeyMixin
from .._strings import (
@@ -31,7 +33,7 @@ class NamedEntityModel(UuidPrimaryKeyMixin, CreatedAtMixin, UpdatedAtMixin, Base
__tablename__ = "named_entities"
__table_args__: tuple[UniqueConstraint, dict[str, str]] = (
UniqueConstraint(
"meeting_id", "normalized_text", name="uq_named_entities_meeting_text"
"meeting_id", NORMALIZED_TEXT, name="uq_named_entities_meeting_text"
),
{"schema": "noteflow"},
)

View File

@@ -6,8 +6,19 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy import delete, func, select
from sqlalchemy.dialects.postgresql import insert
from noteflow.domain.constants.fields import (
CATEGORY,
CONFIDENCE,
IS_PINNED,
MEETING_ID,
NORMALIZED_TEXT,
SEGMENT_IDS,
TEXT,
UPDATED_AT,
)
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.infrastructure.converters.ner_converters import NerConverter
from noteflow.infrastructure.logging import get_logger
@@ -43,7 +54,7 @@ class SqlAlchemyEntityRepository(
async def save(self, entity: NamedEntity) -> NamedEntity:
"""Save or update a named entity.
Uses merge to handle both insert and update cases.
Uses an upsert on (meeting_id, normalized_text) to avoid duplicate key errors.
Args:
entity: The entity to save.
@@ -52,10 +63,24 @@ class SqlAlchemyEntityRepository(
Saved entity with db_id populated.
"""
kwargs = NerConverter.to_orm_kwargs(entity)
model = NamedEntityModel(**kwargs)
merged = await self._session.merge(model)
await self._session.flush()
entity.db_id = merged.id
stmt = insert(NamedEntityModel).values(**kwargs)
excluded = stmt.excluded
stmt = stmt.on_conflict_do_update(
index_elements=[MEETING_ID, NORMALIZED_TEXT],
set_={
TEXT: excluded.text,
NORMALIZED_TEXT: excluded.normalized_text,
CATEGORY: excluded.category,
SEGMENT_IDS: excluded.segment_ids,
CONFIDENCE: excluded.confidence,
IS_PINNED: NamedEntityModel.is_pinned | excluded.is_pinned,
UPDATED_AT: func.now(),
},
).returning(NamedEntityModel.id)
result = await self._session.execute(stmt)
row = result.first()
if row is not None:
entity.db_id = row[0]
logger.info(
"entity_saved",
entity_id=str(entity.id),
@@ -67,8 +92,7 @@ class SqlAlchemyEntityRepository(
async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
"""Save multiple entities efficiently.
Uses individual merges to handle upsert semantics with the unique
constraint on (meeting_id, normalized_text).
        Uses a single multi-row upsert on (meeting_id, normalized_text) to avoid duplicate key errors.
Args:
entities: List of entities to save.
@@ -76,13 +100,38 @@ class SqlAlchemyEntityRepository(
Returns:
Saved entities with db_ids populated.
"""
for entity in entities:
kwargs = NerConverter.to_orm_kwargs(entity)
model = NamedEntityModel(**kwargs)
merged = await self._session.merge(model)
entity.db_id = merged.id
if not entities:
return entities
await self._session.flush()
payload = [NerConverter.to_orm_kwargs(entity) for entity in entities]
stmt = insert(NamedEntityModel).values(payload)
excluded = stmt.excluded
stmt = stmt.on_conflict_do_update(
index_elements=[MEETING_ID, NORMALIZED_TEXT],
set_={
TEXT: excluded.text,
NORMALIZED_TEXT: excluded.normalized_text,
CATEGORY: excluded.category,
SEGMENT_IDS: excluded.segment_ids,
CONFIDENCE: excluded.confidence,
IS_PINNED: NamedEntityModel.is_pinned | excluded.is_pinned,
UPDATED_AT: func.now(),
},
).returning(
NamedEntityModel.id,
NamedEntityModel.meeting_id,
NamedEntityModel.normalized_text,
)
result = await self._session.execute(stmt)
rows = result.all()
if rows:
ids_by_key = {
(str(row.meeting_id), row.normalized_text): row.id
for row in rows
}
for entity in entities:
key = (str(entity.meeting_id), entity.normalized_text)
entity.db_id = ids_by_key.get(key, entity.db_id)
if entities:
logger.info(
"entities_batch_saved",

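In practice, the ON CONFLICT clause means re-saving an entity for an existing (meeting_id, normalized_text) pair updates the row instead of raising a duplicate-key error, and a previously pinned row stays pinned because the stored flag is OR-ed with the incoming one. A usage sketch, with the NamedEntity construction details assumed rather than taken from this diff:

from uuid import uuid4

from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity

entity = NamedEntity(                          # constructor fields are illustrative assumptions
    id=uuid4(),
    meeting_id=meeting_id,                     # existing MeetingId, assumed in scope
    text="ACME Corp.",
    normalized_text="acme corp",
    category=EntityCategory.ORGANIZATION,      # enum member name assumed
    segment_ids=[3],
    confidence=0.95,
    is_pinned=False,
)
saved = await repo.save(entity)                # repo: SqlAlchemyEntityRepository, assumed in scope
# If a row for (meeting_id, "acme corp") already existed:
#   - no duplicate-key error; the existing row is updated in place
#   - saved.db_id carries the existing row's id
#   - is_pinned remains True if it was already True (stored OR incoming)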
View File

@@ -45,6 +45,7 @@ class SqlAlchemyMeetingRepository(BaseRepository):
metadata_=meeting.metadata,
wrapped_dek=meeting.wrapped_dek,
asset_path=meeting.asset_path,
processing_status=OrmConverter.processing_status_to_payload(meeting.processing_status),
version=meeting.version,
)
self._session.add(model)
@@ -96,6 +97,9 @@ class SqlAlchemyMeetingRepository(BaseRepository):
model.metadata_ = dict(meeting.metadata)
model.wrapped_dek = meeting.wrapped_dek
model.asset_path = meeting.asset_path
model.processing_status = OrmConverter.processing_status_to_payload(
meeting.processing_status
)
model.version += 1
meeting.version = model.version