From 1d2bc250248d13c241a720739b74c662ed1f0898 Mon Sep 17 00:00:00 2001
From: Travis Vasceannie
Date: Fri, 16 Jan 2026 01:34:33 +0000
Subject: [PATCH] refactor: add API initialization memoization and replace
 initializeAPI with getAPI across hooks

- Added initializePromise singleton to prevent concurrent API initialization attempts and ensure single initialization
- Wrapped initializeAPI logic in IIFE that caches promise and clears on error for retry capability
- Replaced all initializeAPI() calls with synchronous getAPI() in hooks (useAudioDevices, useDiarization, usePostProcessing, useRecordingAppPolicy)
- Updated corresponding test mocks from async initializeAPI (mockResolvedValue) to synchronous getAPI (mockReturnValue)
---
Note: an illustrative sketch of the memoization pattern is appended after
the diff.

 client/src/api/index.ts                       | 130 +++++-----
 client/src/hooks/post-processing/state.ts     |   6 +-
 client/src/hooks/use-audio-devices.test.ts    |  20 +-
 client/src/hooks/use-audio-devices.ts         |  14 +-
 client/src/hooks/use-diarization.test.ts      |   6 +-
 client/src/hooks/use-diarization.ts           |  10 +-
 client/src/hooks/use-post-processing.test.ts  |   8 +-
 client/src/hooks/use-post-processing.ts       |  15 +-
 client/src/hooks/use-recording-app-policy.ts  |   4 +-
 client/src/lib/preferences/api.ts             |  21 +-
 client/src/lib/preferences/core.ts            |   7 +
 client/src/lib/preferences/tauri.ts           |   8 +-
 .../meeting-detail/use-meeting-detail.ts      |  71 ++++--
 .../application/services/ner/_dedupe.py       |  41 +++
 .../application/services/ner/service.py       |   5 +-
 src/noteflow/domain/constants/fields.py       |   7 +
 src/noteflow/domain/entities/named_entity.py  |   8 +-
 .../grpc/mixins/_processing_status.py         | 166 ++++++++++++
 .../grpc/mixins/diarization/_execution.py     |   2 +-
 .../mixins/diarization/_job_validation.py     | 143 +++++++++++
 src/noteflow/grpc/mixins/diarization/_jobs.py | 240 ++++++------------
 .../grpc/mixins/diarization/_status.py        |  41 +++
 src/noteflow/grpc/mixins/entities.py          |  97 +++++--
 .../grpc/mixins/meeting/_post_processing.py   |  13 +
 .../grpc/mixins/meeting/meeting_mixin.py      |   4 +-
 .../mixins/summarization/_generation_mixin.py |  90 ++++++-
 .../converters/ner_converters.py              |  22 +-
 .../converters/orm_converters.py              | 111 +++++++-
 ...5v6w7x8y9_add_meeting_processing_status.py |  37 +++
 .../persistence/models/core/meeting.py        |   6 +-
 .../models/entities/named_entity.py           |   4 +-
 .../persistence/repositories/entity_repo.py   |  77 +++++-
 .../persistence/repositories/meeting_repo.py  |   4 +
 33 files changed, 1092 insertions(+), 346 deletions(-)
 create mode 100644 src/noteflow/application/services/ner/_dedupe.py
 create mode 100644 src/noteflow/grpc/mixins/_processing_status.py
 create mode 100644 src/noteflow/grpc/mixins/diarization/_job_validation.py
 create mode 100644 src/noteflow/infrastructure/persistence/migrations/versions/t4u5v6w7x8y9_add_meeting_processing_status.py

diff --git a/client/src/api/index.ts b/client/src/api/index.ts
index 704e54a..31aa4e7 100644
--- a/client/src/api/index.ts
+++ b/client/src/api/index.ts
@@ -31,6 +31,7 @@ import { startReconnection } from './reconnection';
 import { initializeTauriAPI } from './tauri-adapter';
 const log = debug('NoteFlowAPI');
+let initializePromise: Promise | null = null;
 // ============================================================================
 // API Initialization
 // ============================================================================
@@ -46,65 +47,78 @@ const log = debug('NoteFlowAPI');
  * Sprint GAP-009: Event bridge starts before connection to capture early events.
*/ export async function initializeAPI(): Promise { - // Always try Tauri first - initializeTauriAPI tests the API and throws if unavailable + if (initializePromise) { + return initializePromise; + } + + initializePromise = (async () => { + // Always try Tauri first - initializeTauriAPI tests the API and throws if unavailable + try { + const tauriAPI = await initializeTauriAPI(); + setAPIInstance(tauriAPI); + + try { + const { invoke } = await import('@tauri-apps/api/core'); + window.__NOTEFLOW_TEST_INVOKE__ = invoke; + } catch (error) { + log('Test invoke binding unavailable (expected in non-Tauri contexts)', { + error: error instanceof Error ? error.message : String(error), + }); + } + + // Sprint GAP-009: Start event bridge before connection to capture early events + // (e.g., connection errors, early warnings). Non-critical if it fails. + await startTauriEventBridge().catch((error) => { + addClientLog({ + level: 'warning', + source: 'api', + message: 'Event bridge initialization failed - continuing without early events', + details: error instanceof Error ? error.message : String(error), + metadata: { context: 'api_event_bridge_init' }, + }); + }); + + // Attempt to connect to the gRPC server + try { + const preferredUrl = preferences.getServerUrl(); + await tauriAPI.connect(preferredUrl || undefined); + setConnectionMode('connected'); + await preferences.initialize(); + startReconnection(); + // Sprint GAP-007: Log successful connection + log('Adapter: Tauri | Mode: connected', { server: preferredUrl || 'default' }); + return tauriAPI; + } catch (connectError) { + // Connection failed - fall back to cached mode but keep Tauri adapter + const message = extractErrorMessage(connectError, 'Connection failed'); + setConnectionMode('cached', message); + await preferences.initialize(); + startReconnection(); + // Sprint GAP-007: Log cached mode fallback + addClientLog({ + level: 'warning', + source: 'api', + message: 'Adapter fallback to cached mode', + details: message, + metadata: { context: 'api_cached_mode_fallback' }, + }); + return tauriAPI; // Keep Tauri adapter for reconnection attempts + } + } catch (_tauriError) { + // Tauri unavailable - use mock API (we're in a browser) + setConnectionMode('mock'); + setAPIInstance(mockAPI); + // Sprint GAP-007: Log mock mode + log('Adapter: Mock | Mode: mock | Environment: Browser'); + return mockAPI; + } + })(); + try { - const tauriAPI = await initializeTauriAPI(); - setAPIInstance(tauriAPI); - - try { - const { invoke } = await import('@tauri-apps/api/core'); - window.__NOTEFLOW_TEST_INVOKE__ = invoke; - } catch (error) { - log('Test invoke binding unavailable (expected in non-Tauri contexts)', { - error: error instanceof Error ? error.message : String(error), - }); - } - - // Sprint GAP-009: Start event bridge before connection to capture early events - // (e.g., connection errors, early warnings). Non-critical if it fails. - await startTauriEventBridge().catch((error) => { - addClientLog({ - level: 'warning', - source: 'api', - message: 'Event bridge initialization failed - continuing without early events', - details: error instanceof Error ? 
error.message : String(error), - metadata: { context: 'api_event_bridge_init' }, - }); - }); - - // Attempt to connect to the gRPC server - try { - const preferredUrl = preferences.getServerUrl(); - await tauriAPI.connect(preferredUrl || undefined); - setConnectionMode('connected'); - await preferences.initialize(); - startReconnection(); - // Sprint GAP-007: Log successful connection - log('Adapter: Tauri | Mode: connected', { server: preferredUrl || 'default' }); - return tauriAPI; - } catch (connectError) { - // Connection failed - fall back to cached mode but keep Tauri adapter - const message = extractErrorMessage(connectError, 'Connection failed'); - setConnectionMode('cached', message); - await preferences.initialize(); - startReconnection(); - // Sprint GAP-007: Log cached mode fallback - addClientLog({ - level: 'warning', - source: 'api', - message: 'Adapter fallback to cached mode', - details: message, - metadata: { context: 'api_cached_mode_fallback' }, - }); - return tauriAPI; // Keep Tauri adapter for reconnection attempts - } - } catch (_tauriError) { - // Tauri unavailable - use mock API (we're in a browser) - setConnectionMode('mock'); - setAPIInstance(mockAPI); - // Sprint GAP-007: Log mock mode - log('Adapter: Mock | Mode: mock | Environment: Browser'); - return mockAPI; + return await initializePromise; + } catch (error) { + initializePromise = null; + throw error; } } diff --git a/client/src/hooks/post-processing/state.ts b/client/src/hooks/post-processing/state.ts index bbe14fc..de87eee 100644 --- a/client/src/hooks/post-processing/state.ts +++ b/client/src/hooks/post-processing/state.ts @@ -136,6 +136,10 @@ export function shouldAutoStartProcessing( // If processing is already complete or in progress, don't restart const { summary, entities, diarization } = processingStatus; + const anyFailed = + summary.status === 'failed' || + entities.status === 'failed' || + diarization.status === 'failed'; const allTerminal = ['completed', 'failed', 'skipped'].includes(summary.status) && ['completed', 'failed', 'skipped'].includes(entities.status) && @@ -147,7 +151,7 @@ export function shouldAutoStartProcessing( diarization.status === 'running'; // If all done or any running, don't auto-start - if (allTerminal || anyRunning) { + if (allTerminal || anyRunning || anyFailed) { return false; } diff --git a/client/src/hooks/use-audio-devices.test.ts b/client/src/hooks/use-audio-devices.test.ts index 1e12da6..aa5cdc2 100644 --- a/client/src/hooks/use-audio-devices.test.ts +++ b/client/src/hooks/use-audio-devices.test.ts @@ -1,6 +1,6 @@ import { act, renderHook } from '@testing-library/react'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import { initializeAPI } from '@/api'; +import { getAPI } from '@/api'; import { TauriCommands } from '@/api/constants'; import { isTauriEnvironment } from '@/api/tauri-adapter'; import { toast } from '@/hooks/use-toast'; @@ -9,7 +9,7 @@ import { useTauriEvent } from '@/lib/tauri-events'; import { useAudioDevices } from './use-audio-devices'; vi.mock('@/api', () => ({ - initializeAPI: vi.fn(), + getAPI: vi.fn(), })); vi.mock('@/api/tauri-adapter', async () => { @@ -172,7 +172,7 @@ describe('useAudioDevices', () => { it('loads devices in tauri mode and syncs selection', async () => { vi.mocked(isTauriEnvironment).mockReturnValue(true); - vi.mocked(initializeAPI).mockResolvedValue({ + vi.mocked(getAPI).mockReturnValue({ getPreferences: vi.fn().mockResolvedValue({ audio_devices: { input_device_id: '', output_device_id: '' }, 
      }),
@@ -188,7 +188,7 @@ describe('useAudioDevices', () => {
        system_gain: 1,
      }),
      selectAudioDevice: vi.fn(),
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);

     const { result } = renderHook(() => useAudioDevices({ showToasts: false }));

@@ -218,7 +218,7 @@ describe('useAudioDevices', () => {
       audio_devices: { input_device_id: 'input:0:Mic', output_device_id: 'output:1:Speakers' },
     });
     const selectAudioDevice = vi.fn();
-    vi.mocked(initializeAPI).mockResolvedValue({
+    vi.mocked(getAPI).mockReturnValue({
       getPreferences: vi.fn().mockResolvedValue({
         audio_devices: { input_device_id: 'input:0:Mic', output_device_id: 'output:1:Speakers' },
       }),
@@ -234,7 +234,7 @@ describe('useAudioDevices', () => {
        system_gain: 1,
      }),
      selectAudioDevice,
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);

     const { result } = renderHook(() => useAudioDevices({ showToasts: false }));

@@ -387,7 +387,7 @@ describe('useAudioDevices', () => {
     });

     const selectAudioDevice = vi.fn();
-    vi.mocked(initializeAPI).mockResolvedValue({
+    vi.mocked(getAPI).mockReturnValue({
       getPreferences: vi.fn().mockImplementation(() =>
         Promise.resolve({ audio_devices: prefsState.audio_devices })
       ),
@@ -403,7 +403,7 @@ describe('useAudioDevices', () => {
        system_gain: 1,
      }),
      selectAudioDevice,
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);

     const first = renderHook(() => useAudioDevices({ showToasts: false }));

@@ -644,10 +644,10 @@ describe('useAudioDevices', () => {
   it('sets selected devices and syncs tauri selection', async () => {
     vi.mocked(isTauriEnvironment).mockReturnValue(true);
     const selectAudioDevice = vi.fn();
-    vi.mocked(initializeAPI).mockResolvedValue({
+    vi.mocked(getAPI).mockReturnValue({
       listAudioDevices: vi.fn().mockResolvedValue([]),
       selectAudioDevice,
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);

     const { result } = renderHook(() => useAudioDevices({ showToasts: false }));

diff --git a/client/src/hooks/use-audio-devices.ts b/client/src/hooks/use-audio-devices.ts
index 4aeb2af..c5758d2 100644
--- a/client/src/hooks/use-audio-devices.ts
+++ b/client/src/hooks/use-audio-devices.ts
@@ -1,5 +1,5 @@
 import { useCallback, useEffect, useRef, useState } from 'react';
-import { initializeAPI } from '@/api';
+import { getAPI } from '@/api';
 import { clientLog } from '@/lib/client-log-events';
 import { addClientLog } from '@/lib/client-logs';
 import { isTauriEnvironment } from '@/api/tauri-adapter';
@@ -133,7 +133,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
     try {
       if (isTauriEnvironment()) {
-        const api = await initializeAPI();
+        const api = getAPI();
         let storedPrefs = preferences.get();
         try {
           const tauriPrefs = await api.getPreferences();
@@ -308,7 +308,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
       });
       if (isTauriEnvironment()) {
         try {
-          const api = await initializeAPI();
+          const api = getAPI();
           await api.selectAudioDevice(deviceId, true);
         } catch (err) {
           const errorMessage = err instanceof Error ? err.message : String(err);
@@ -340,7 +340,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
       });
       if (isTauriEnvironment()) {
         try {
-          const api = await initializeAPI();
+          const api = getAPI();
           await api.selectAudioDevice(deviceId, false);
         } catch (err) {
           const errorMessage = err instanceof Error ? err.message : String(err);
@@ -364,7 +364,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
     preferences.setSystemAudioDevice(deviceId);
     if (isTauriEnvironment()) {
       try {
-        const api = await initializeAPI();
+        const api = getAPI();
         await api.setSystemAudioDevice(deviceId);
       } catch (err) {
         const errorMessage = err instanceof Error ? err.message : String(err);
@@ -388,7 +388,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
     preferences.setDualCaptureEnabled(enabled);
     if (isTauriEnvironment()) {
       try {
-        const api = await initializeAPI();
+        const api = getAPI();
         await api.setDualCaptureEnabled(enabled);
         if (enabled && !selectedSystemDevice && loopbackDevices.length > 0) {
           const firstDevice = loopbackDevices[0].deviceId;
@@ -418,7 +418,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
     preferences.setAudioMixLevels(newMicGain, newSystemGain);
     if (isTauriEnvironment()) {
       try {
-        const api = await initializeAPI();
+        const api = getAPI();
         await api.setAudioMixLevels(newMicGain, newSystemGain);
       } catch (err) {
         if (showToasts) {
diff --git a/client/src/hooks/use-diarization.test.ts b/client/src/hooks/use-diarization.test.ts
index 8cb3c2f..f2ed5e9 100644
--- a/client/src/hooks/use-diarization.test.ts
+++ b/client/src/hooks/use-diarization.test.ts
@@ -19,7 +19,7 @@ import { useDiarization } from './use-diarization';

 // Mock the API module
 vi.mock('@/api', () => ({
-  initializeAPI: vi.fn(),
+  getAPI: vi.fn(),
 }));

 // Mock the toast hook
@@ -47,9 +47,7 @@ describe('useDiarization', () => {
   beforeEach(() => {
     vi.useFakeTimers();
-    vi.mocked(api.initializeAPI).mockResolvedValue(
-      mockAPI as unknown as Awaited<ReturnType<typeof api.initializeAPI>>
-    );
+    vi.mocked(api.getAPI).mockReturnValue(mockAPI as unknown as ReturnType<typeof api.getAPI>);
     vi.clearAllMocks();
   });

diff --git a/client/src/hooks/use-diarization.ts b/client/src/hooks/use-diarization.ts
index e19e586..5da9a24 100644
--- a/client/src/hooks/use-diarization.ts
+++ b/client/src/hooks/use-diarization.ts
@@ -6,7 +6,7 @@
  */

 import { useCallback, useEffect, useRef, useState } from 'react';
-import { initializeAPI } from '@/api';
+import { getAPI } from '@/api';
 import type { DiarizationJobStatus, JobStatus } from '@/api/types';
 import { toast } from '@/hooks/use-toast';
 import { PollingConfig } from '@/lib/config';
@@ -151,7 +151,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     }

     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const status = await api.getDiarizationJobStatus(jobId);

       if (!isMountedRef.current) {
@@ -289,7 +289,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     retryCountRef.current = 0;

     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const status = await api.refineSpeakers(meetingId, numSpeakers);

       if (!isMountedRef.current) {
@@ -345,7 +345,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     stopPolling();

     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const result = await api.cancelDiarization(jobId);

       if (!isMountedRef.current) {
@@ -396,7 +396,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     }

     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const activeJobs = await api.getActiveDiarizationJobs();

       if (!isMountedRef.current) {
diff --git a/client/src/hooks/use-post-processing.test.ts b/client/src/hooks/use-post-processing.test.ts
index 286c000..5a9b4b4 100644
---
a/client/src/hooks/use-post-processing.test.ts +++ b/client/src/hooks/use-post-processing.test.ts @@ -23,20 +23,20 @@ import { toast } from '@/hooks/use-toast'; import { usePostProcessing } from './use-post-processing'; // Use vi.hoisted to define mocks that are available in vi.mock -const { mockAPI, mockInitializeAPI } = vi.hoisted(() => { +const { mockAPI, mockGetAPI } = vi.hoisted(() => { const mockAPI = { generateSummary: vi.fn(), extractEntities: vi.fn(), refineSpeakers: vi.fn(), getDiarizationJobStatus: vi.fn(), }; - const mockInitializeAPI = vi.fn(() => Promise.resolve(mockAPI)); - return { mockAPI, mockInitializeAPI }; + const mockGetAPI = vi.fn(() => mockAPI); + return { mockAPI, mockGetAPI }; }); // Mock the API module vi.mock('@/api', () => ({ - initializeAPI: mockInitializeAPI, + getAPI: mockGetAPI, })); // Mock the toast hook diff --git a/client/src/hooks/use-post-processing.ts b/client/src/hooks/use-post-processing.ts index 2134c8d..cb7dd7d 100644 --- a/client/src/hooks/use-post-processing.ts +++ b/client/src/hooks/use-post-processing.ts @@ -11,9 +11,10 @@ */ import { useCallback, useEffect, useRef, useState } from 'react'; -import { initializeAPI } from '@/api'; +import { getAPI } from '@/api'; import type { DiarizationJobStatus, MeetingState, ProcessingStatus } from '@/api/types'; import { toast } from '@/hooks/use-toast'; +import { setEntitiesFromExtraction } from '@/lib/entity-store'; import { errorMessageFrom, toastError } from '@/lib/error-reporting'; import { usePostProcessingEvents } from './post-processing/events'; import type { @@ -40,7 +41,6 @@ export type { UsePostProcessingReturn, } from './post-processing/state'; - /** * Hook for orchestrating post-processing after recording stops */ @@ -116,7 +116,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo } try { - const api = await initializeAPI(); + const api = getAPI(); const status: DiarizationJobStatus = await api.getDiarizationJobStatus(jobId); if (!isMountedRef.current) { @@ -201,7 +201,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo }); try { - const api = await initializeAPI(); + const api = getAPI(); await api.generateSummary(meetingId, false); if (!isMountedRef.current) { @@ -253,8 +253,9 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo }); try { - const api = await initializeAPI(); - await api.extractEntities(meetingId, false); + const api = getAPI(); + const response = await api.extractEntities(meetingId, false); + setEntitiesFromExtraction(response.entities); if (!isMountedRef.current) { return; @@ -305,7 +306,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo }); try { - const api = await initializeAPI(); + const api = getAPI(); const response = await api.refineSpeakers(meetingId, numSpeakers); if (!isMountedRef.current) { diff --git a/client/src/hooks/use-recording-app-policy.ts b/client/src/hooks/use-recording-app-policy.ts index ad97393..3c2799c 100644 --- a/client/src/hooks/use-recording-app-policy.ts +++ b/client/src/hooks/use-recording-app-policy.ts @@ -1,5 +1,5 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { initializeAPI } from '@/api'; +import { getAPI } from '@/api'; import { isTauriEnvironment } from '@/api/tauri-adapter'; import type { AppMatcher, @@ -216,7 +216,7 @@ export function useRecordingAppPolicy(): RecordingAppPolicyState { setIsLoadingApps(true); try { - const api = await initializeAPI(); + const api = 
getAPI(); const page = reset ? 0 : currentPageRef.current; const response: ListInstalledAppsResponse = await api.listInstalledApps({ commonOnly, diff --git a/client/src/lib/preferences/api.ts b/client/src/lib/preferences/api.ts index edff461..01b4e8a 100644 --- a/client/src/lib/preferences/api.ts +++ b/client/src/lib/preferences/api.ts @@ -19,7 +19,7 @@ import { debug } from '@/lib/debug'; const log = debug('Preferences:Replace'); -import { clonePreferences } from './core'; +import { arePreferencesEqual, clonePreferences } from './core'; import { defaultPreferences } from './constants'; import { mergeIntegrationsWithDefaults, resetIntegrationsForServerSwitch } from './integrations'; import { @@ -78,10 +78,6 @@ export const preferences = { }, replace(prefs: UserPreferences): void { - log('Replacing preferences', { - input_device_id: prefs.audio_devices?.input_device_id, - output_device_id: prefs.audio_devices?.output_device_id, - }); const mergedAudioDevices = isRecord(prefs.audio_devices) ? { ...defaultPreferences.audio_devices, ...prefs.audio_devices } : { ...defaultPreferences.audio_devices }; @@ -92,7 +88,16 @@ export const preferences = { audio_devices: mergedAudioDevices, integrations: mergedIntegrations, }; - savePreferences(applyLocalOverrides(merged), persistPreferencesToTauri); + const next = applyLocalOverrides(merged); + const current = loadPreferences(); + if (arePreferencesEqual(current, next)) { + return; + } + log('Replacing preferences', { + input_device_id: next.audio_devices?.input_device_id, + output_device_id: next.audio_devices?.output_device_id, + }); + savePreferences(next, persistPreferencesToTauri); }, resetToDefaults(): void { @@ -462,9 +467,7 @@ export const preferences = { setServerAddressOverride(hasAddress, trimmedHost, trimmedPort); }); }, - // SYNC SCHEDULER - isSyncSchedulerPaused(): boolean { return loadPreferences().sync_scheduler_paused; }, @@ -474,9 +477,7 @@ export const preferences = { prefs.sync_scheduler_paused = paused; }); }, - // SYNC HISTORY - getSyncHistory(): SyncHistoryEvent[] { return loadPreferences().sync_history || []; }, diff --git a/client/src/lib/preferences/core.ts b/client/src/lib/preferences/core.ts index 53a8140..a4eea0c 100644 --- a/client/src/lib/preferences/core.ts +++ b/client/src/lib/preferences/core.ts @@ -145,3 +145,10 @@ export function normalizePreferences(prefs: UserPreferences): UserPreferences { ai_config: normalizeAiConfig(prefs.ai_config), }; } + +/** + * Compare preferences payloads after normalization. + */ +export function arePreferencesEqual(a: UserPreferences, b: UserPreferences): boolean { + return JSON.stringify(normalizePreferences(a)) === JSON.stringify(normalizePreferences(b)); +} diff --git a/client/src/lib/preferences/tauri.ts b/client/src/lib/preferences/tauri.ts index eae42ba..2e48d69 100644 --- a/client/src/lib/preferences/tauri.ts +++ b/client/src/lib/preferences/tauri.ts @@ -7,7 +7,7 @@ import type { Integration, UserPreferences } from '@/api/types'; import { addClientLog } from '@/lib/client-logs'; import { debug } from '@/lib/debug'; -import { isTauriRuntime } from './core'; +import { arePreferencesEqual, isTauriRuntime } from './core'; import { loadPreferences } from './storage'; import { emitValidationEvent } from './validation-events'; @@ -241,8 +241,10 @@ export async function hydratePreferencesFromTauri( }); // Only call replacePrefs if we actually need to update something - // This avoids unnecessary writes - replacePrefs(mergedPrefs); + // This avoids unnecessary writes and persistence churn. 
+ if (!arePreferencesEqual(localPrefs, mergedPrefs)) { + replacePrefs(mergedPrefs); + } hasHydratedFromTauri = true; notifyHydrationComplete(); diff --git a/client/src/pages/meeting-detail/use-meeting-detail.ts b/client/src/pages/meeting-detail/use-meeting-detail.ts index 9fb45de..88905d0 100644 --- a/client/src/pages/meeting-detail/use-meeting-detail.ts +++ b/client/src/pages/meeting-detail/use-meeting-detail.ts @@ -2,7 +2,7 @@ * Hook for meeting detail data and state. */ -import { useEffect, useMemo, useState } from 'react'; +import { useEffect, useMemo, useRef, useState } from 'react'; import { getAPI } from '@/api'; import type { Meeting } from '@/api/types'; @@ -25,19 +25,14 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) { const [selectedSegment, setSelectedSegment] = useState(null); const [speakerNameMap, setSpeakerNameMap] = useState>(new Map()); const { mode } = useConnectionState(); + const autoStartedRef = useRef>(new Set()); + const autoStartLogRef = useRef>(new Set()); - // Entity extraction hook - const { - state: entityState, - extract: extractEntities, - isExtracting, - } = useEntityExtraction({ - meetingId, - meetingTitle: meeting?.title, - meetingState: meeting?.state, - autoExtract: true, - }); - const { entities } = entityState; + const failedStep = meeting?.processing_status + ? (['summary', 'entities', 'diarization'] as const).find( + (step) => meeting.processing_status?.[step]?.status === 'failed' + ) ?? null + : null; // Post-processing orchestration hook const { @@ -54,6 +49,26 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) { } }, }); + const hasLocalPostProcessing = processingState.meetingId === meetingId; + const shouldAutoExtract = + !!meeting && + !!meetingId && + meeting.state === 'completed' && + !hasLocalPostProcessing && + !shouldAutoStart(meeting.state, meeting.processing_status); + + // Entity extraction hook + const { + state: entityState, + extract: extractEntities, + isExtracting, + } = useEntityExtraction({ + meetingId, + meetingTitle: meeting?.title, + meetingState: meeting?.state, + autoExtract: shouldAutoExtract, + }); + const { entities } = entityState; // Summary metadata const summaryMeta = useMemo(() => { @@ -128,9 +143,35 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) { // Auto-start post-processing useEffect(() => { - if (meeting && meetingId && shouldAutoStart(meeting.state, meeting.processing_status)) { - void startProcessing(meetingId); + if (!meeting || !meetingId) { + return; } + if (failedStep && !autoStartLogRef.current.has(meetingId)) { + autoStartLogRef.current.add(meetingId); + addClientLog({ + level: 'info', + source: 'app', + message: 'Post-processing auto-start skipped due to failed step', + metadata: { + meeting_id: meetingId, + failed_step: failedStep, + }, + }); + } + if (!shouldAutoStart(meeting.state, meeting.processing_status)) { + return; + } + if (autoStartedRef.current.has(meetingId)) { + return; + } + autoStartedRef.current.add(meetingId); + addClientLog({ + level: 'info', + source: 'app', + message: 'Post-processing auto-start triggered', + metadata: { meeting_id: meetingId }, + }); + void startProcessing(meetingId); }, [meeting, meetingId, shouldAutoStart, startProcessing]); const handleGenerateSummary = async () => { diff --git a/src/noteflow/application/services/ner/_dedupe.py b/src/noteflow/application/services/ner/_dedupe.py new file mode 100644 index 0000000..6d1593d --- /dev/null +++ b/src/noteflow/application/services/ner/_dedupe.py 
@@ -0,0 +1,41 @@ +"""Helpers for normalizing extracted entities.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from noteflow.domain.entities.named_entity import NamedEntity +from noteflow.infrastructure.logging import get_logger + +logger = get_logger(__name__) + + +def dedupe_entities(entities: Sequence[NamedEntity]) -> list[NamedEntity]: + """Collapse duplicate entities by normalized text. + + Merges segment IDs and keeps the highest-confidence category. + """ + deduped: dict[str, NamedEntity] = {} + for entity in entities: + key = entity.normalized_text or entity.text.lower().strip() + if not key: + continue + existing = deduped.get(key) + if existing is None: + entity.normalized_text = key + deduped[key] = entity + continue + + existing.merge_segments(entity.segment_ids) + if entity.confidence > existing.confidence: + existing.confidence = entity.confidence + existing.category = entity.category + if len(entity.text) > len(existing.text): + existing.text = entity.text + if len(deduped) != len(entities): + logger.info( + "ner_entities_deduped", + total=len(entities), + deduped=len(deduped), + ) + return list(deduped.values()) diff --git a/src/noteflow/application/services/ner/service.py b/src/noteflow/application/services/ner/service.py index f30e036..edd3b8f 100644 --- a/src/noteflow/application/services/ner/service.py +++ b/src/noteflow/application/services/ner/service.py @@ -7,9 +7,11 @@ Orchestrates NER extraction, caching, and persistence following hexagonal archit from __future__ import annotations import asyncio +from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING +from noteflow.application.services.ner._dedupe import dedupe_entities from noteflow.config.constants import ERROR_MSG_MEETING_PREFIX from noteflow.config.settings import get_feature_flags from noteflow.domain.entities.named_entity import NamedEntity @@ -17,7 +19,7 @@ from noteflow.infrastructure.logging import get_logger, log_timing from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable from uuid import UUID from noteflow.domain.ports.ner import NerPort @@ -211,6 +213,7 @@ class NerService: entities = await self._extraction_helper.extract(segments) for entity in entities: entity.meeting_id = meeting_id + entities = dedupe_entities(entities) await self._persist_entities(meeting_id, entities, force_refresh) log_memory_snapshot( diff --git a/src/noteflow/domain/constants/fields.py b/src/noteflow/domain/constants/fields.py index d55a94a..c849e86 100644 --- a/src/noteflow/domain/constants/fields.py +++ b/src/noteflow/domain/constants/fields.py @@ -14,6 +14,13 @@ PROVIDER: Final[str] = "provider" UNKNOWN: Final[str] = "unknown" DEVICE: Final[Literal["device"]] = "device" SEGMENT_IDS: Final[Literal["segment_ids"]] = "segment_ids" +MEETING_ID: Final[Literal["meeting_id"]] = "meeting_id" +TEXT: Final[Literal["text"]] = "text" +NORMALIZED_TEXT: Final[Literal["normalized_text"]] = "normalized_text" +CATEGORY: Final[Literal["category"]] = "category" +CONFIDENCE: Final[Literal["confidence"]] = "confidence" +IS_PINNED: Final[Literal["is_pinned"]] = "is_pinned" +UPDATED_AT: Final[Literal["updated_at"]] = "updated_at" PROJECT_ID: Final[str] = "project_id" PROJECT_IDS: Final[str] = "project_ids" CALENDAR: Final[str] = "calendar" diff --git a/src/noteflow/domain/entities/named_entity.py 
b/src/noteflow/domain/entities/named_entity.py index 6ef2379..56b7578 100644 --- a/src/noteflow/domain/entities/named_entity.py +++ b/src/noteflow/domain/entities/named_entity.py @@ -9,7 +9,7 @@ from uuid import UUID, uuid4 from noteflow.domain.constants.fields import DATE as ENTITY_DATE from noteflow.domain.constants.fields import LOCATION as ENTITY_LOCATION -from noteflow.domain.constants.fields import SEGMENT_IDS +from noteflow.domain.constants.fields import CATEGORY, CONFIDENCE, SEGMENT_IDS, TEXT if TYPE_CHECKING: from noteflow.domain.value_objects import MeetingId @@ -116,10 +116,10 @@ class NamedEntity: ValueError: If text is empty or confidence is out of range. """ # Validate required text - text = kwargs["text"] - category = kwargs["category"] + text = kwargs[TEXT] + category = kwargs[CATEGORY] segment_ids: list[int] = kwargs[SEGMENT_IDS] - confidence = kwargs["confidence"] + confidence = kwargs[CONFIDENCE] meeting_id = kwargs.get("meeting_id") stripped_text = text.strip() diff --git a/src/noteflow/grpc/mixins/_processing_status.py b/src/noteflow/grpc/mixins/_processing_status.py new file mode 100644 index 0000000..73946d5 --- /dev/null +++ b/src/noteflow/grpc/mixins/_processing_status.py @@ -0,0 +1,166 @@ +"""Helpers for updating meeting post-processing status.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal, cast +from uuid import UUID + +from noteflow.domain.entities.processing import ( + ProcessingStatus, + ProcessingStepState, + ProcessingStepStatus, +) +from noteflow.domain.ports.unit_of_work import UnitOfWork +from noteflow.domain.value_objects import MeetingId +from noteflow.infrastructure.logging import get_logger + +if TYPE_CHECKING: + from noteflow.domain.entities import Meeting + +logger = get_logger(__name__) + +ProcessingStepName = Literal["summary", "entities", "diarization"] + + +def _parse_meeting_id(meeting_id: str) -> MeetingId | None: + try: + return MeetingId(UUID(meeting_id)) + except ValueError: + return None + + +@dataclass(frozen=True, slots=True) +class ProcessingStatusUpdate: + step: ProcessingStepName + status: ProcessingStepStatus + error_message: str | None = None + + +def _build_step_state( + current: ProcessingStepState, + update: ProcessingStatusUpdate, +) -> ProcessingStepState: + if update.status == ProcessingStepStatus.RUNNING: + return ProcessingStepState.running() + if update.status == ProcessingStepStatus.COMPLETED: + return ProcessingStepState.completed(started_at=current.started_at) + if update.status == ProcessingStepStatus.FAILED: + return ProcessingStepState.failed( + update.error_message or "processing failed", + started_at=current.started_at, + ) + if update.status == ProcessingStepStatus.SKIPPED: + return ProcessingStepState.skipped() + return ProcessingStepState.pending() + + +def _apply_step_update( + current: ProcessingStatus, + step: ProcessingStepName, + new_state: ProcessingStepState, +) -> ProcessingStatus: + if step == "summary": + return ProcessingStatus( + summary=new_state, + entities=current.entities, + diarization=current.diarization, + queued_at=current.queued_at, + ) + if step == "entities": + return ProcessingStatus( + summary=current.summary, + entities=new_state, + diarization=current.diarization, + queued_at=current.queued_at, + ) + return ProcessingStatus( + summary=current.summary, + entities=current.entities, + diarization=new_state, + queued_at=current.queued_at, + ) + + +async def 
_commit_processing_update( + repo: UnitOfWork, + meeting: Meeting, + meeting_id: str, + update: ProcessingStatusUpdate, +) -> bool: + try: + await repo.meetings.update(meeting) + await repo.commit() + return True + except ValueError as exc: + await repo.rollback() + if "modified concurrently" in str(exc): + return False + logger.warning( + "processing_status_update_failed", + meeting_id=meeting_id, + step=update.step, + error=str(exc), + ) + return True + + +async def _apply_processing_update( + repo_provider: Callable[[], object], + parsed_id: MeetingId, + meeting_id: str, + update: ProcessingStatusUpdate, +) -> bool: + async with cast(UnitOfWork, repo_provider()) as repo: + meeting = await repo.meetings.get(parsed_id) + if meeting is None: + logger.warning( + "processing_status_meeting_missing", + meeting_id=meeting_id, + step=update.step, + ) + return True + + current_status = meeting.processing_status or ProcessingStatus.create_pending() + current_step = getattr(current_status, update.step) + new_state = _build_step_state(current_step, update) + meeting.processing_status = _apply_step_update(current_status, update.step, new_state) + + return await _commit_processing_update(repo, meeting, meeting_id, update) + + +async def update_processing_status( + repo_provider: Callable[[], object], + meeting_id: str, + update: ProcessingStatusUpdate, + *, + max_attempts: int = 3, +) -> None: + """Update a meeting's processing status with retry on version conflicts.""" + parsed_id = _parse_meeting_id(meeting_id) + if parsed_id is None: + logger.debug( + "processing_status_meeting_id_invalid", + meeting_id=meeting_id, + step=update.step, + ) + return + + for attempt in range(max_attempts): + completed = await _apply_processing_update( + repo_provider, + parsed_id, + meeting_id, + update, + ) + if completed: + return + if attempt >= max_attempts - 1: + logger.warning( + "processing_status_update_failed", + meeting_id=meeting_id, + step=update.step, + error="retry_limit_exceeded", + ) + return diff --git a/src/noteflow/grpc/mixins/diarization/_execution.py b/src/noteflow/grpc/mixins/diarization/_execution.py index 9c89d0a..c065db7 100644 --- a/src/noteflow/grpc/mixins/diarization/_execution.py +++ b/src/noteflow/grpc/mixins/diarization/_execution.py @@ -14,7 +14,7 @@ from opentelemetry.trace import Span from ._types import DIARIZATION_TIMEOUT_SECONDS if TYPE_CHECKING: - from ._jobs import DiarizationJobContext + from ._job_validation import DiarizationJobContext def _set_diarization_span_attributes(span: Span, ctx: DiarizationJobContext) -> None: diff --git a/src/noteflow/grpc/mixins/diarization/_job_validation.py b/src/noteflow/grpc/mixins/diarization/_job_validation.py new file mode 100644 index 0000000..dce4309 --- /dev/null +++ b/src/noteflow/grpc/mixins/diarization/_job_validation.py @@ -0,0 +1,143 @@ +"""Diarization job validation and error response helpers.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, cast +from uuid import UUID + +import grpc + +from noteflow.domain.ports.unit_of_work import UnitOfWork +from noteflow.domain.value_objects import MeetingState +from noteflow.infrastructure.logging import get_logger +from noteflow.infrastructure.persistence.repositories import DiarizationJob + +from ...proto import noteflow_pb2 +from .._types import GrpcStatusContext +from ..converters import parse_meeting_id +from ..errors._constants import INVALID_MEETING_ID_MESSAGE + +if TYPE_CHECKING: + 
from noteflow.domain.entities import Meeting + + from ..protocols import ServicerHost + +logger = get_logger(__name__) + + +def job_status_name(status: int) -> str: + name_fn = cast(Callable[[int], str], noteflow_pb2.JobStatus.Name) + return name_fn(status) + + +@dataclass(frozen=True) +class GrpcErrorDetails: + """gRPC error context and status for response helpers.""" + + context: GrpcStatusContext + grpc_code: grpc.StatusCode + + +@dataclass(frozen=True, slots=True) +class DiarizationJobContext: + """Context for executing a diarization job. + + Groups job-related parameters to reduce function signature complexity. + """ + + host: ServicerHost + job_id: str + job: DiarizationJob + meeting_id: str + num_speakers: int | None + + +def create_diarization_error_response( + error_message: str, + status: noteflow_pb2.JobStatus | str = noteflow_pb2.JOB_STATUS_FAILED, + *, + error: GrpcErrorDetails | None = None, + job_id: str = "", +) -> noteflow_pb2.RefineSpeakerDiarizationResponse: + """Create error response for RefineSpeakerDiarization. + + Consolidates the 6+ duplicated response construction patterns. + """ + if error is not None: + error.context.set_code(error.grpc_code) + error.context.set_details(error_message) + return noteflow_pb2.RefineSpeakerDiarizationResponse( + segments_updated=0, + speaker_ids=[], + error_message=error_message, + job_id=job_id, + status=status, + ) + + +def validate_diarization_preconditions( + servicer: ServicerHost, + request: noteflow_pb2.RefineSpeakerDiarizationRequest, + context: GrpcStatusContext, +) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None: + """Validate preconditions before starting diarization job.""" + if not servicer.diarization_refinement_enabled: + return create_diarization_error_response("Diarization refinement disabled on server") + + if servicer.diarization_engine is None: + return create_diarization_error_response( + "Diarization not enabled on server", + error=GrpcErrorDetails( + context=context, + grpc_code=grpc.StatusCode.UNAVAILABLE, + ), + ) + + try: + UUID(request.meeting_id) + except ValueError: + return create_diarization_error_response(INVALID_MEETING_ID_MESSAGE) + return None + + +async def load_meeting_for_diarization( + repo: UnitOfWork, + meeting_id: str, +) -> tuple[Meeting | None, noteflow_pb2.RefineSpeakerDiarizationResponse | None]: + """Fetch meeting and validate state for diarization refinement.""" + meeting = await repo.meetings.get(parse_meeting_id(meeting_id)) + if meeting is None: + return None, create_diarization_error_response("Meeting not found") + + valid_states = (MeetingState.STOPPED, MeetingState.COMPLETED, MeetingState.ERROR) + if meeting.state not in valid_states: + return None, create_diarization_error_response( + f"Meeting must be stopped before refinement (state: {meeting.state.name.lower()})" + ) + return meeting, None + + +async def check_active_diarization_job( + repo: UnitOfWork, + meeting_id: str, + context: GrpcStatusContext, +) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None: + """Return error response if a diarization job is already active.""" + if not repo.supports_diarization_jobs: + return None + + active_job = await repo.diarization_jobs.get_active_for_meeting(meeting_id) + if active_job is None: + return None + + return create_diarization_error_response( + f"Diarization already in progress (job: {active_job.job_id})", + status=job_status_name(active_job.status), + error=GrpcErrorDetails( + context=context, + grpc_code=grpc.StatusCode.ALREADY_EXISTS, + ), + 
job_id=active_job.job_id, + ) diff --git a/src/noteflow/grpc/mixins/diarization/_jobs.py b/src/noteflow/grpc/mixins/diarization/_jobs.py index 9638f88..5b10ebe 100644 --- a/src/noteflow/grpc/mixins/diarization/_jobs.py +++ b/src/noteflow/grpc/mixins/diarization/_jobs.py @@ -3,33 +3,39 @@ from __future__ import annotations import asyncio -from collections.abc import Callable -from dataclasses import dataclass -from typing import TYPE_CHECKING, cast -from uuid import UUID, uuid4 +from typing import TYPE_CHECKING +from uuid import uuid4 import grpc from noteflow.domain.ports.unit_of_work import UnitOfWork +from noteflow.domain.entities.processing import ProcessingStepStatus from noteflow.domain.utils import utc_now -from noteflow.domain.value_objects import MeetingState from noteflow.infrastructure.logging import get_logger, log_state_transition from noteflow.infrastructure.persistence.repositories import DiarizationJob from ...proto import noteflow_pb2 from .._types import GrpcStatusContext -from ..converters import parse_meeting_id -from ..errors._constants import INVALID_MEETING_ID_MESSAGE +from .._processing_status import ProcessingStatusUpdate, update_processing_status from ._execution import execute_diarization +from ._job_validation import ( + DiarizationJobContext, + GrpcErrorDetails, + check_active_diarization_job, + create_diarization_error_response, + job_status_name, + load_meeting_for_diarization, + validate_diarization_preconditions, +) from ._status import JobStatusMixin if TYPE_CHECKING: - from noteflow.domain.entities import Meeting - from ..protocols import ServicerHost logger = get_logger(__name__) +DIARIZATION_DB_REQUIRED = "Diarization requires database support" + def _diarization_task_done_callback( task: asyncio.Task[None], @@ -50,90 +56,6 @@ def _diarization_task_done_callback( ) -def _job_status_name(status: int) -> str: - name_fn = cast(Callable[[int], str], noteflow_pb2.JobStatus.Name) - return name_fn(status) - - -@dataclass(frozen=True) -class GrpcErrorDetails: - """gRPC error context and status for response helpers.""" - - context: GrpcStatusContext - grpc_code: grpc.StatusCode - - -@dataclass(frozen=True, slots=True) -class DiarizationJobContext: - """Context for executing a diarization job. - - Groups job-related parameters to reduce function signature complexity. - """ - - host: ServicerHost - job_id: str - job: DiarizationJob - meeting_id: str - num_speakers: int | None - - -def create_diarization_error_response( - error_message: str, - status: noteflow_pb2.JobStatus | str = noteflow_pb2.JOB_STATUS_FAILED, - *, - error: GrpcErrorDetails | None = None, - job_id: str = "", -) -> noteflow_pb2.RefineSpeakerDiarizationResponse: - """Create error response for RefineSpeakerDiarization. - - Consolidates the 6+ duplicated response construction patterns. - - Args: - error_message: Error message describing the failure. - status: Job status code (default: JOB_STATUS_FAILED). - error: Optional gRPC error details for status code. - job_id: Optional job ID to include in response. - - Returns: - Populated RefineSpeakerDiarizationResponse with error state. 
- """ - if error is not None: - error.context.set_code(error.grpc_code) - error.context.set_details(error_message) - return noteflow_pb2.RefineSpeakerDiarizationResponse( - segments_updated=0, - speaker_ids=[], - error_message=error_message, - job_id=job_id, - status=status, - ) - - -def _validate_diarization_preconditions( - servicer: ServicerHost, - request: noteflow_pb2.RefineSpeakerDiarizationRequest, - context: GrpcStatusContext, -) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None: - """Validate preconditions before starting diarization job.""" - if not servicer.diarization_refinement_enabled: - return create_diarization_error_response("Diarization refinement disabled on server") - - if servicer.diarization_engine is None: - return create_diarization_error_response( - "Diarization not enabled on server", - error=GrpcErrorDetails( - context=context, - grpc_code=grpc.StatusCode.UNAVAILABLE, - ), - ) - - try: - UUID(request.meeting_id) - except ValueError: - return create_diarization_error_response(INVALID_MEETING_ID_MESSAGE) - return None - - async def _create_and_persist_job( job_id: str, meeting_id: str, @@ -165,47 +87,6 @@ async def _create_and_persist_job( return True -async def _load_meeting_for_diarization( - repo: UnitOfWork, - meeting_id: str, -) -> tuple[Meeting | None, noteflow_pb2.RefineSpeakerDiarizationResponse | None]: - """Fetch meeting and validate state for diarization refinement.""" - meeting = await repo.meetings.get(parse_meeting_id(meeting_id)) - if meeting is None: - return None, create_diarization_error_response("Meeting not found") - - valid_states = (MeetingState.STOPPED, MeetingState.COMPLETED, MeetingState.ERROR) - if meeting.state not in valid_states: - return None, create_diarization_error_response( - f"Meeting must be stopped before refinement (state: {meeting.state.name.lower()})" - ) - return meeting, None - - -async def _check_active_diarization_job( - repo: UnitOfWork, - meeting_id: str, - context: GrpcStatusContext, -) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None: - """Return error response if a diarization job is already active.""" - if not repo.supports_diarization_jobs: - return None - - active_job = await repo.diarization_jobs.get_active_for_meeting(meeting_id) - if active_job is None: - return None - - return create_diarization_error_response( - f"Diarization already in progress (job: {active_job.job_id})", - status=_job_status_name(active_job.status), - error=GrpcErrorDetails( - context=context, - grpc_code=grpc.StatusCode.ALREADY_EXISTS, - ), - job_id=active_job.job_id, - ) - - async def _init_job_for_running( host: ServicerHost, job_id: str, @@ -233,8 +114,8 @@ async def _init_job_for_running( started_at=utc_now(), ) await repo.commit() - transition_from = _job_status_name(old_status) - transition_to = _job_status_name(int(noteflow_pb2.JOB_STATUS_RUNNING)) + transition_from = job_status_name(old_status) + transition_to = job_status_name(int(noteflow_pb2.JOB_STATUS_RUNNING)) log_state_transition( "diarization_job", job_id, @@ -271,6 +152,40 @@ def _schedule_diarization_task( num_speakers=num_speakers, ) + +async def _prepare_diarization_job( + host: ServicerHost, + request: noteflow_pb2.RefineSpeakerDiarizationRequest, + context: GrpcStatusContext, +) -> tuple[str, float | None] | noteflow_pb2.RefineSpeakerDiarizationResponse: + async with host.create_repository_provider() as repo: + meeting, error = await load_meeting_for_diarization(repo, request.meeting_id) + if error is not None: + return error + + active_error = await 
check_active_diarization_job(repo, request.meeting_id, context) + if active_error is not None: + return active_error + + job_id = str(uuid4()) + persisted = await _create_and_persist_job( + job_id, + request.meeting_id, + meeting.duration_seconds if meeting else None, + repo, + ) + if not persisted: + return create_diarization_error_response( + DIARIZATION_DB_REQUIRED, + error=GrpcErrorDetails( + context=context, + grpc_code=grpc.StatusCode.FAILED_PRECONDITION, + ), + ) + + return job_id, meeting.duration_seconds if meeting else None + + class JobsMixin(JobStatusMixin): """Mixin providing diarization job management.""" @@ -280,35 +195,42 @@ class JobsMixin(JobStatusMixin): context: GrpcStatusContext, ) -> noteflow_pb2.RefineSpeakerDiarizationResponse: """Start a new diarization refinement job.""" - if error := _validate_diarization_preconditions(self, request, context): + if error := validate_diarization_preconditions(self, request, context): + await update_processing_status( + self.create_repository_provider, + request.meeting_id, + ProcessingStatusUpdate( + step="diarization", + status=ProcessingStepStatus.FAILED, + error_message=error.error_message, + ), + ) return error - async with self.create_repository_provider() as repo: - meeting, error = await _load_meeting_for_diarization(repo, request.meeting_id) - if error is not None: - return error - - active_error = await _check_active_diarization_job(repo, request.meeting_id, context) - if active_error is not None: - return active_error - - job_id = str(uuid4()) - persisted = await _create_and_persist_job( - job_id, - request.meeting_id, - meeting.duration_seconds if meeting else None, - repo, - ) - if not persisted: - return create_diarization_error_response( - "Diarization requires database support", - error=GrpcErrorDetails( - context=context, - grpc_code=grpc.StatusCode.FAILED_PRECONDITION, + job_result = await _prepare_diarization_job(self, request, context) + if isinstance(job_result, noteflow_pb2.RefineSpeakerDiarizationResponse): + if job_result.error_message == DIARIZATION_DB_REQUIRED: + await update_processing_status( + self.create_repository_provider, + request.meeting_id, + ProcessingStatusUpdate( + step="diarization", + status=ProcessingStepStatus.FAILED, + error_message=DIARIZATION_DB_REQUIRED, ), ) + return job_result + job_id, _duration = job_result num_speakers = request.num_speakers or None + await update_processing_status( + self.create_repository_provider, + request.meeting_id, + ProcessingStatusUpdate( + step="diarization", + status=ProcessingStepStatus.RUNNING, + ), + ) _schedule_diarization_task(self, job_id, num_speakers, request.meeting_id) return noteflow_pb2.RefineSpeakerDiarizationResponse( diff --git a/src/noteflow/grpc/mixins/diarization/_status.py b/src/noteflow/grpc/mixins/diarization/_status.py index d9e3497..008cc15 100644 --- a/src/noteflow/grpc/mixins/diarization/_status.py +++ b/src/noteflow/grpc/mixins/diarization/_status.py @@ -5,11 +5,13 @@ from __future__ import annotations from collections.abc import Callable from typing import TYPE_CHECKING, cast +from noteflow.domain.entities.processing import ProcessingStepStatus from noteflow.infrastructure.logging import get_logger, log_state_transition from noteflow.infrastructure.persistence.repositories import DiarizationJob from ...proto import noteflow_pb2 from ..errors import ERR_CANCELLED_BY_USER +from .._processing_status import ProcessingStatusUpdate, update_processing_status from ._types import DIARIZATION_TIMEOUT_SECONDS if TYPE_CHECKING: @@ -47,6 
+49,15 @@ class JobStatusMixin: speaker_ids=speaker_ids, ) await repo.commit() + if job is not None: + await update_processing_status( + self.create_repository_provider, + job.meeting_id, + ProcessingStatusUpdate( + step="diarization", + status=ProcessingStepStatus.COMPLETED, + ), + ) log_state_transition( "diarization_job", job_id, @@ -81,6 +92,16 @@ class JobStatusMixin: error_message=error_msg, ) await repo.commit() + if meeting_id is not None: + await update_processing_status( + self.create_repository_provider, + meeting_id, + ProcessingStatusUpdate( + step="diarization", + status=ProcessingStepStatus.FAILED, + error_message=error_msg, + ), + ) log_state_transition( "diarization_job", job_id, @@ -109,6 +130,16 @@ class JobStatusMixin: error_message=ERR_CANCELLED_BY_USER, ) await repo.commit() + if meeting_id is not None: + await update_processing_status( + self.create_repository_provider, + meeting_id, + ProcessingStatusUpdate( + step="diarization", + status=ProcessingStepStatus.SKIPPED, + error_message=ERR_CANCELLED_BY_USER, + ), + ) log_state_transition( "diarization_job", job_id, @@ -138,6 +169,16 @@ class JobStatusMixin: error_message=str(exc), ) await repo.commit() + if meeting_id is not None: + await update_processing_status( + self.create_repository_provider, + meeting_id, + ProcessingStatusUpdate( + step="diarization", + status=ProcessingStepStatus.FAILED, + error_message=str(exc), + ), + ) log_state_transition( "diarization_job", job_id, diff --git a/src/noteflow/grpc/mixins/entities.py b/src/noteflow/grpc/mixins/entities.py index efdac06..129d448 100644 --- a/src/noteflow/grpc/mixins/entities.py +++ b/src/noteflow/grpc/mixins/entities.py @@ -4,11 +4,13 @@ from __future__ import annotations from typing import TYPE_CHECKING, cast +from noteflow.domain.entities.processing import ProcessingStepStatus from noteflow.infrastructure.logging import get_logger from ..proto import noteflow_pb2 from ._types import GrpcContext from ._model_status import log_model_status +from ._processing_status import ProcessingStatusUpdate, update_processing_status from .converters import entity_to_proto, parse_meeting_id_or_abort from .errors import ( ENTITY_ENTITY, @@ -25,7 +27,8 @@ from .protocols import EntitiesRepositoryProvider if TYPE_CHECKING: from collections.abc import Callable - from noteflow.application.services.ner import NerService + from noteflow.application.services.ner import ExtractionResult, NerService + from noteflow.domain.value_objects import MeetingId logger = get_logger(__name__) @@ -41,6 +44,77 @@ class EntitiesMixin: ner_service: NerService | None create_repository_provider: Callable[..., object] + async def _mark_entities_step( + self, + meeting_id_str: str, + status: ProcessingStepStatus, + error_message: str | None = None, + ) -> None: + await update_processing_status( + self.create_repository_provider, + meeting_id_str, + ProcessingStatusUpdate( + step="entities", + status=status, + error_message=error_message, + ), + ) + + async def _run_entity_extraction( + self, + meeting_id: MeetingId, + meeting_id_str: str, + force_refresh: bool, + context: GrpcContext, + ) -> ExtractionResult: + ner_service = await require_ner_service(self.ner_service, context) + try: + return await ner_service.extract_entities( + meeting_id=meeting_id, + force_refresh=force_refresh, + ) + except ValueError: + await abort_not_found(context, ENTITY_MEETING, meeting_id_str) + raise + + async def _handle_entities_failure( + self, + meeting_id_str: str, + error: Exception, + context: GrpcContext | None = 
None, + ) -> None: + await self._mark_entities_step( + meeting_id_str, + ProcessingStepStatus.FAILED, + str(error), + ) + if isinstance(error, RuntimeError) and context is not None: + await abort_failed_precondition(context, str(error)) + + async def _extract_entities_with_status( + self, + meeting_id: MeetingId, + meeting_id_str: str, + force_refresh: bool, + context: GrpcContext, + ) -> ExtractionResult: + await self._mark_entities_step(meeting_id_str, ProcessingStepStatus.RUNNING) + try: + result = await self._run_entity_extraction( + meeting_id, + meeting_id_str, + force_refresh, + context, + ) + except ValueError: + raise + except Exception as exc: + await self._handle_entities_failure(meeting_id_str, exc, context) + raise + + await self._mark_entities_step(meeting_id_str, ProcessingStepStatus.COMPLETED) + return result + async def ExtractEntities( self, request: noteflow_pb2.ExtractEntitiesRequest, @@ -53,21 +127,12 @@ class EntitiesMixin: """ log_model_status(self, "ner_extract_start", meeting_id=request.meeting_id) meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context) - ner_service = await require_ner_service(self.ner_service, context) - - try: - result = await ner_service.extract_entities( - meeting_id=meeting_id, - force_refresh=request.force_refresh, - ) - except ValueError: - # Meeting not found - await abort_not_found(context, ENTITY_MEETING, request.meeting_id) - raise # Unreachable: abort raises - except RuntimeError as e: - # Feature disabled - await abort_failed_precondition(context, str(e)) - raise # Unreachable: abort raises + result = await self._extract_entities_with_status( + meeting_id, + request.meeting_id, + request.force_refresh, + context, + ) # Convert to proto proto_entities = [entity_to_proto(entity) for entity in result.entities] diff --git a/src/noteflow/grpc/mixins/meeting/_post_processing.py b/src/noteflow/grpc/mixins/meeting/_post_processing.py index 4a5c9e2..6e4f4a4 100644 --- a/src/noteflow/grpc/mixins/meeting/_post_processing.py +++ b/src/noteflow/grpc/mixins/meeting/_post_processing.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING from uuid import UUID +from noteflow.domain.entities.processing import ProcessingStatus, ProcessingStepState from noteflow.domain.ports.unit_of_work import UnitOfWork from noteflow.domain.value_objects import MeetingId, MeetingState from noteflow.infrastructure.logging import get_logger, log_state_transition @@ -133,6 +134,7 @@ async def _complete_without_summary( "Post-processing: no segments, completing without summary", meeting_id=meeting_id, ) + _set_summary_processing_status(meeting, ProcessingStepState.skipped()) _complete_meeting(meeting, meeting_id) await repo.meetings.update(meeting) await repo.commit() @@ -140,6 +142,7 @@ async def _complete_without_summary( async def _save_summary_and_complete(context: _SummaryCompletionContext) -> Summary: saved_summary = await context.repo.summaries.save(context.summary) + _set_summary_processing_status(context.meeting, ProcessingStepState.completed()) _complete_meeting(context.meeting, context.meeting_id) await context.repo.meetings.update(context.meeting) await context.repo.commit() @@ -159,6 +162,16 @@ def _complete_meeting(meeting: Meeting, meeting_id: str) -> None: log_state_transition("meeting", meeting_id, previous_state, meeting.state) +def _set_summary_processing_status(meeting: Meeting, step_state: ProcessingStepState) -> None: + current = meeting.processing_status or ProcessingStatus.create_pending() + 
diff --git a/src/noteflow/grpc/mixins/meeting/meeting_mixin.py b/src/noteflow/grpc/mixins/meeting/meeting_mixin.py
index 470caaa..d8d5a5e 100644
--- a/src/noteflow/grpc/mixins/meeting/meeting_mixin.py
+++ b/src/noteflow/grpc/mixins/meeting/meeting_mixin.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast
 from uuid import UUID
 
 from noteflow.config.constants import DEFAULT_MEETING_TITLE
-from noteflow.domain.entities import Meeting
+from noteflow.domain.entities import Meeting, ProcessingStatus
 from noteflow.domain.entities.meeting import MeetingCreateParams
 from noteflow.domain.identity import OperationContext
 from noteflow.domain.value_objects import MeetingId, MeetingState
@@ -91,6 +91,8 @@ async def _stop_meeting_and_persist(context: _StopMeetingContext) -> Meeting:
         previous_state,
         context.context,
     )
+    if context.meeting.processing_status is None:
+        context.meeting.processing_status = ProcessingStatus.create_pending()
    await context.repo.meetings.update(context.meeting)
 
     if context.repo.supports_diarization_jobs:
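The None guard matters because the new meetings.processing_status column is nullable (see the migration below): rows written before this change read back as None, and stopping such a meeting backfills a pending status exactly once. A minimal illustration of the backfill-once behavior, with a plain dict standing in for ProcessingStatus.create_pending():

    def backfill(current: dict | None) -> dict:
        """Return the existing status, or seed a default when none was persisted."""
        return current if current is not None else {"summary": "pending"}

    assert backfill(None) == {"summary": "pending"}                      # legacy row seeded
    assert backfill({"summary": "completed"}) == {"summary": "completed"}  # existing kept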
diff --git a/src/noteflow/grpc/mixins/summarization/_generation_mixin.py b/src/noteflow/grpc/mixins/summarization/_generation_mixin.py
index bb44d0a..dc08362 100644
--- a/src/noteflow/grpc/mixins/summarization/_generation_mixin.py
+++ b/src/noteflow/grpc/mixins/summarization/_generation_mixin.py
@@ -2,11 +2,13 @@
 
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, cast
 
 from sqlalchemy.exc import IntegrityError
 
 from noteflow.domain.entities import Segment, Summary
+from noteflow.domain.entities.processing import ProcessingStepStatus
 from noteflow.domain.ports.unit_of_work import UnitOfWork
 from noteflow.domain.value_objects import MeetingId
 from noteflow.infrastructure.logging import get_logger
@@ -15,6 +17,7 @@ from ...proto import noteflow_pb2
 from ...startup.startup import auto_enable_cloud_llm
 from .._types import GrpcContext
 from .._model_status import log_model_status
+from .._processing_status import ProcessingStatusUpdate, update_processing_status
 from ..converters import parse_meeting_id_or_abort, summary_to_proto
 from ..errors import ENTITY_MEETING, abort_not_found
 from ._summary_generation import generate_placeholder_summary, summarize_or_placeholder
@@ -35,6 +38,16 @@ if TYPE_CHECKING:
 
 logger = get_logger(__name__)
 
 
+@dataclass(frozen=True, slots=True)
+class _SummaryGenerationContext:
+    meeting_id: MeetingId
+    meeting_id_str: str
+    meeting: Meeting
+    segments: list[Segment]
+    style_prompt: str | None
+    force_regenerate: bool
+
+
 class SummarizationGenerationMixin:
     """Generate summaries and handle summary webhooks."""
 
@@ -59,12 +72,13 @@
             await self._prepare_summary_request(request, context)
         )
         if existing and not request.force_regenerate:
+            await self._mark_summary_step(request.meeting_id, ProcessingStepStatus.COMPLETED)
             return summary_to_proto(existing)
 
         await self._ensure_cloud_provider()
 
         async with cast(UnitOfWork, self.create_repository_provider()) as repo:
-            style_prompt = await resolve_template_prompt(
+            style_prompt = await self._resolve_style_prompt(
                 TemplateResolutionInputs(
                     request=request,
                     meeting=meeting,
@@ -76,23 +90,75 @@
                     summarization_service=self.summarization_service,
                 )
             )
 
-            summary = await summarize_or_placeholder(
-                self.summarization_service,
-                meeting_id,
-                segments,
-                style_prompt,
-            )
-
-            saved, trigger_webhook = await self._save_summary(
-                meeting=meeting,
-                summary=summary,
-                force_regenerate=request.force_regenerate,
+            saved, trigger_webhook = await self._generate_summary_with_status(
+                _SummaryGenerationContext(
+                    meeting_id=meeting_id,
+                    meeting_id_str=request.meeting_id,
+                    meeting=meeting,
+                    segments=segments,
+                    style_prompt=style_prompt,
+                    force_regenerate=request.force_regenerate,
+                )
             )
 
         if trigger_webhook:
             await self._trigger_summary_webhook(meeting, saved)
 
         return summary_to_proto(saved)
 
+    async def _mark_summary_step(
+        self,
+        meeting_id_str: str,
+        status: ProcessingStepStatus,
+        error_message: str | None = None,
+    ) -> None:
+        await update_processing_status(
+            self.create_repository_provider,
+            meeting_id_str,
+            ProcessingStatusUpdate(
+                step="summary",
+                status=status,
+                error_message=error_message,
+            ),
+        )
+
+    async def _resolve_style_prompt(self, inputs: TemplateResolutionInputs) -> str | None:
+        return await resolve_template_prompt(inputs)
+
+    async def _generate_summary_with_status(
+        self,
+        context: _SummaryGenerationContext,
+    ) -> tuple[Summary, bool]:
+        meeting_id = context.meeting_id
+        meeting_id_str = context.meeting_id_str
+        meeting = context.meeting
+        segments = context.segments
+        style_prompt = context.style_prompt
+        force_regenerate = context.force_regenerate
+
+        await self._mark_summary_step(meeting_id_str, ProcessingStepStatus.RUNNING)
+        try:
+            summary = await summarize_or_placeholder(
+                self.summarization_service,
+                meeting_id,
+                segments,
+                style_prompt,
+            )
+
+            saved, trigger_webhook = await self._save_summary(
+                meeting=meeting,
+                summary=summary,
+                force_regenerate=force_regenerate,
+            )
+            await self._mark_summary_step(meeting_id_str, ProcessingStepStatus.COMPLETED)
+            return saved, trigger_webhook
+        except Exception as exc:
+            await self._mark_summary_step(
+                meeting_id_str,
+                ProcessingStepStatus.FAILED,
+                str(exc),
+            )
+            raise
+
     async def _prepare_summary_request(
         self,
         request: noteflow_pb2.GenerateSummaryRequest,
diff --git a/src/noteflow/infrastructure/converters/ner_converters.py b/src/noteflow/infrastructure/converters/ner_converters.py
index 2221d08..dad79ee 100644
--- a/src/noteflow/infrastructure/converters/ner_converters.py
+++ b/src/noteflow/infrastructure/converters/ner_converters.py
@@ -4,7 +4,15 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from noteflow.domain.constants.fields import SEGMENT_IDS
+from noteflow.domain.constants.fields import (
+    CATEGORY,
+    CONFIDENCE,
+    IS_PINNED,
+    MEETING_ID,
+    NORMALIZED_TEXT,
+    SEGMENT_IDS,
+    TEXT,
+)
 from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
 from noteflow.domain.value_objects import MeetingId
 
@@ -60,11 +68,11 @@ class NerConverter:
         """
         return {
             "id": entity.id,
-            "meeting_id": entity.meeting_id,
-            "text": entity.text,
-            "normalized_text": entity.normalized_text,
-            "category": entity.category.value,
+            MEETING_ID: entity.meeting_id,
+            TEXT: entity.text,
+            NORMALIZED_TEXT: entity.normalized_text,
+            CATEGORY: entity.category.value,
             SEGMENT_IDS: entity.segment_ids,
-            "confidence": entity.confidence,
-            "is_pinned": entity.is_pinned,
+            CONFIDENCE: entity.confidence,
+            IS_PINNED: entity.is_pinned,
         }
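Swapping the string literals for shared field constants keeps NerConverter.to_orm_kwargs in lockstep with the repository upsert further down: both must spell the column names identically, and a constant turns a typo into an import-time NameError instead of a silently wrong dict key. Sketch (constant values assumed here; the real definitions live in noteflow.domain.constants.fields):

    # Assumed values, for illustration only.
    MEETING_ID = "meeting_id"
    NORMALIZED_TEXT = "normalized_text"

    orm_kwargs = {MEETING_ID: "7f3c...", NORMALIZED_TEXT: "acme corp"}
    conflict_key = [MEETING_ID, NORMALIZED_TEXT]  # same spelling by construction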
a/src/noteflow/infrastructure/converters/orm_converters.py
+++ b/src/noteflow/infrastructure/converters/orm_converters.py
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from collections.abc import Mapping
+from datetime import datetime
+from typing import TYPE_CHECKING, cast
 
 from noteflow.domain.constants.fields import END_TIME, START_TIME
 from noteflow.domain.entities import (
@@ -16,13 +18,18 @@
 from noteflow.domain.entities import (
     WordTiming as DomainWordTiming,
 )
+from noteflow.domain.entities.processing import (
+    ProcessingStatus,
+    ProcessingStepState,
+    ProcessingStepStatus,
+)
 from noteflow.domain.value_objects import (
     AnnotationId,
     AnnotationType,
     MeetingId,
     MeetingState,
 )
 
 if TYPE_CHECKING:
     from noteflow.infrastructure.persistence.models import (
         ActionItemModel,
@@ -34,10 +41,109 @@ if TYPE_CHECKING:
         WordTimingModel,
     )
 
 
 class OrmConverter:
     """Convert between ORM models and domain entities."""
 
+    @staticmethod
+    def _parse_datetime(value: object) -> datetime | None:
+        """Parse ISO datetime strings stored in JSON payloads."""
+        if not isinstance(value, str):
+            return None
+        try:
+            return datetime.fromisoformat(value)
+        except ValueError:
+            return None
+
+    @staticmethod
+    def _parse_processing_step_status(value: object) -> ProcessingStepStatus:
+        """Parse a processing step status from JSON payload values."""
+        if isinstance(value, ProcessingStepStatus):
+            return value
+        if isinstance(value, str):
+            try:
+                return ProcessingStepStatus(value)
+            except ValueError:
+                return ProcessingStepStatus.PENDING
+        return ProcessingStepStatus.PENDING
+
+    @staticmethod
+    def _processing_step_state_from_payload(
+        payload: Mapping[str, object] | None,
+    ) -> ProcessingStepState:
+        """Convert JSON payload to ProcessingStepState."""
+        if payload is None:
+            return ProcessingStepState.pending()
+        status = OrmConverter._parse_processing_step_status(payload.get("status"))
+        error_message_value = payload.get("error_message")
+        error_message = error_message_value if isinstance(error_message_value, str) else ""
+        return ProcessingStepState(
+            status=status,
+            error_message=error_message,
+            started_at=OrmConverter._parse_datetime(payload.get("started_at")),
+            completed_at=OrmConverter._parse_datetime(payload.get("completed_at")),
+        )
+
+    @staticmethod
+    def _coerce_step_payload(value: object) -> Mapping[str, object] | None:
+        if isinstance(value, Mapping):
+            return cast(Mapping[str, object], value)
+        return None
+
+    @staticmethod
+    def processing_status_from_payload(
+        payload: Mapping[str, object] | None,
+    ) -> ProcessingStatus | None:
+        """Convert JSON payload to ProcessingStatus."""
+        if payload is None:
+            return None
+        summary = OrmConverter._processing_step_state_from_payload(
+            OrmConverter._coerce_step_payload(payload.get("summary")),
+        )
+        entities = OrmConverter._processing_step_state_from_payload(
+            OrmConverter._coerce_step_payload(payload.get("entities")),
+        )
+        diarization = OrmConverter._processing_step_state_from_payload(
+            OrmConverter._coerce_step_payload(payload.get("diarization")),
+        )
+        return ProcessingStatus(
+            summary=summary,
+            entities=entities,
+            diarization=diarization,
+            queued_at=OrmConverter._parse_datetime(payload.get("queued_at")),
+        )
+
+    @staticmethod
+    def _processing_step_state_to_payload(
+        state: ProcessingStepState,
+    ) -> dict[str, object]:
+        """Convert ProcessingStepState to JSON payload."""
+        payload: dict[str, object] = {
+            "status": state.status.value,
+            "error_message": state.error_message,
+        }
+        if state.started_at is not None:
+            payload["started_at"] = state.started_at.isoformat()
+        if state.completed_at is not None:
+            payload["completed_at"] = state.completed_at.isoformat()
+        return payload
+
+    @staticmethod
+    def processing_status_to_payload(
+        status: ProcessingStatus | None,
+    ) -> dict[str, object] | None:
+        """Convert ProcessingStatus to JSON payload."""
+        if status is None:
+            return None
+        payload: dict[str, object] = {
+            "summary": OrmConverter._processing_step_state_to_payload(status.summary),
+            "entities": OrmConverter._processing_step_state_to_payload(status.entities),
+            "diarization": OrmConverter._processing_step_state_to_payload(status.diarization),
+        }
+        if status.queued_at is not None:
+            payload["queued_at"] = status.queued_at.isoformat()
+        return payload
+
     # --- WordTiming ---
 
     @staticmethod
@@ -113,6 +219,7 @@
             wrapped_dek=model.wrapped_dek,
             asset_path=model.asset_path,
             version=model.version,
+            processing_status=OrmConverter.processing_status_from_payload(model.processing_status),
         )
 
     # --- Segment ---
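The converters above make the JSONB payload forgiving on the way in (unknown statuses degrade to PENDING, missing timestamps become None) and minimal on the way out (None timestamps are omitted), so a well-formed payload round-trips. An illustrative example — field values are made up, and it assumes the ProcessingStepStatus enum values are the lowercase strings used elsewhere in this patch:

    from noteflow.infrastructure.converters.orm_converters import OrmConverter

    payload = {
        "summary": {
            "status": "completed",
            "error_message": "",
            "started_at": "2026-01-15T00:00:00+00:00",
            "completed_at": "2026-01-15T00:00:05+00:00",
        },
        "entities": {"status": "pending", "error_message": ""},
        "diarization": {"status": "failed", "error_message": "no audio"},
        "queued_at": "2026-01-15T00:00:00+00:00",
    }

    status = OrmConverter.processing_status_from_payload(payload)
    assert OrmConverter.processing_status_to_payload(status) == payload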
diff --git a/src/noteflow/infrastructure/persistence/migrations/versions/t4u5v6w7x8y9_add_meeting_processing_status.py b/src/noteflow/infrastructure/persistence/migrations/versions/t4u5v6w7x8y9_add_meeting_processing_status.py
new file mode 100644
index 0000000..64c735b
--- /dev/null
+++ b/src/noteflow/infrastructure/persistence/migrations/versions/t4u5v6w7x8y9_add_meeting_processing_status.py
@@ -0,0 +1,37 @@
+"""add_meeting_processing_status
+
+Revision ID: t4u5v6w7x8y9
+Revises: s3t4u5v6w7x8
+Create Date: 2026-01-15 00:00:00.000000
+
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "t4u5v6w7x8y9"
+down_revision: str | Sequence[str] | None = "s3t4u5v6w7x8"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    """Add processing_status column to meetings."""
+    op.execute(
+        """
+        ALTER TABLE noteflow.meetings
+        ADD COLUMN processing_status JSONB;
+        """
+    )
+
+
+def downgrade() -> None:
+    """Remove processing_status column from meetings."""
+    op.execute(
+        """
+        ALTER TABLE noteflow.meetings
+        DROP COLUMN processing_status;
+        """
+    )
diff --git a/src/noteflow/infrastructure/persistence/models/core/meeting.py b/src/noteflow/infrastructure/persistence/models/core/meeting.py
index 1a71e5c..7d0aba6 100644
--- a/src/noteflow/infrastructure/persistence/models/core/meeting.py
+++ b/src/noteflow/infrastructure/persistence/models/core/meeting.py
@@ -17,7 +17,7 @@ from sqlalchemy import (
     Text,
     UniqueConstraint,
 )
-from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.dialects.postgresql import JSONB, UUID
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from .._base import DEFAULT_USER_ID, DEFAULT_WORKSPACE_ID, EMBEDDING_DIM, Base
@@ -115,6 +115,10 @@ class MeetingModel(UuidPrimaryKeyMixin, CreatedAtMixin, MetadataMixin, Base):
         Text,
         nullable=True,
     )
+    processing_status: Mapped[dict[str, object] | None] = mapped_column(
+        JSONB,
+        nullable=True,
+    )
     version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
     # Soft delete support
     deleted_at: Mapped[datetime | None] = mapped_column(
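The migration uses raw SQL, which works fine; the same change through Alembic's schema API is equivalent and plays better with autogenerate diffing, if that is preferred:

    import sqlalchemy as sa
    from alembic import op
    from sqlalchemy.dialects.postgresql import JSONB

    def upgrade() -> None:
        op.add_column(
            "meetings",
            sa.Column("processing_status", JSONB(), nullable=True),
            schema="noteflow",
        )

    def downgrade() -> None:
        op.drop_column("meetings", "processing_status", schema="noteflow")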
a/src/noteflow/infrastructure/persistence/models/entities/named_entity.py
+++ b/src/noteflow/infrastructure/persistence/models/entities/named_entity.py
@@ -9,6 +9,8 @@ from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text, Unique
 from sqlalchemy.dialects.postgresql import ARRAY, UUID
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
+from noteflow.domain.constants.fields import NORMALIZED_TEXT
+
 from .._base import Base
 from .._mixins import CreatedAtMixin, UpdatedAtMixin, UuidPrimaryKeyMixin
 from .._strings import (
@@ -31,7 +33,7 @@ class NamedEntityModel(UuidPrimaryKeyMixin, CreatedAtMixin, UpdatedAtMixin, Base
     __tablename__ = "named_entities"
     __table_args__: tuple[UniqueConstraint, dict[str, str]] = (
         UniqueConstraint(
-            "meeting_id", "normalized_text", name="uq_named_entities_meeting_text"
+            "meeting_id", NORMALIZED_TEXT, name="uq_named_entities_meeting_text"
         ),
         {"schema": "noteflow"},
     )
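The uq_named_entities_meeting_text constraint above is the conflict target for the repository upsert that follows. Roughly the SQL that the new save() compiles to — a sketch with the column list abbreviated, not captured output. Note is_pinned OR excluded.is_pinned, which keeps a row pinned even when a re-extraction would otherwise overwrite it:

    # Sketch of the emitted statement (abbreviated):
    UPSERT_SQL = """
    INSERT INTO noteflow.named_entities (meeting_id, normalized_text, text, ...)
    VALUES (:meeting_id, :normalized_text, :text, ...)
    ON CONFLICT (meeting_id, normalized_text) DO UPDATE SET
        text = excluded.text,
        normalized_text = excluded.normalized_text,
        category = excluded.category,
        segment_ids = excluded.segment_ids,
        confidence = excluded.confidence,
        is_pinned = named_entities.is_pinned OR excluded.is_pinned,
        updated_at = now()
    RETURNING named_entities.id
    """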
diff --git a/src/noteflow/infrastructure/persistence/repositories/entity_repo.py b/src/noteflow/infrastructure/persistence/repositories/entity_repo.py
index 8153337..0237fef 100644
--- a/src/noteflow/infrastructure/persistence/repositories/entity_repo.py
+++ b/src/noteflow/infrastructure/persistence/repositories/entity_repo.py
@@ -6,8 +6,19 @@ from collections.abc import Sequence
 from typing import TYPE_CHECKING
 from uuid import UUID
 
-from sqlalchemy import delete, select
+from sqlalchemy import delete, func, select
+from sqlalchemy.dialects.postgresql import insert
 
+from noteflow.domain.constants.fields import (
+    CATEGORY,
+    CONFIDENCE,
+    IS_PINNED,
+    MEETING_ID,
+    NORMALIZED_TEXT,
+    SEGMENT_IDS,
+    TEXT,
+    UPDATED_AT,
+)
 from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
 from noteflow.infrastructure.converters.ner_converters import NerConverter
 from noteflow.infrastructure.logging import get_logger
@@ -43,7 +54,7 @@ class SqlAlchemyEntityRepository(
     async def save(self, entity: NamedEntity) -> NamedEntity:
         """Save or update a named entity.
 
-        Uses merge to handle both insert and update cases.
+        Uses an upsert on (meeting_id, normalized_text) to avoid duplicate key errors.
 
         Args:
             entity: The entity to save.
@@ -52,10 +63,24 @@
             Saved entity with db_id populated.
         """
         kwargs = NerConverter.to_orm_kwargs(entity)
-        model = NamedEntityModel(**kwargs)
-        merged = await self._session.merge(model)
-        await self._session.flush()
-        entity.db_id = merged.id
+        stmt = insert(NamedEntityModel).values(**kwargs)
+        excluded = stmt.excluded
+        stmt = stmt.on_conflict_do_update(
+            index_elements=[MEETING_ID, NORMALIZED_TEXT],
+            set_={
+                TEXT: excluded.text,
+                NORMALIZED_TEXT: excluded.normalized_text,
+                CATEGORY: excluded.category,
+                SEGMENT_IDS: excluded.segment_ids,
+                CONFIDENCE: excluded.confidence,
+                IS_PINNED: NamedEntityModel.is_pinned | excluded.is_pinned,
+                UPDATED_AT: func.now(),
+            },
+        ).returning(NamedEntityModel.id)
+        result = await self._session.execute(stmt)
+        row = result.first()
+        if row is not None:
+            entity.db_id = row[0]
         logger.info(
             "entity_saved",
             entity_id=str(entity.id),
@@ -67,8 +92,7 @@
     async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
         """Save multiple entities efficiently.
 
-        Uses individual merges to handle upsert semantics with the unique
-        constraint on (meeting_id, normalized_text).
+        Uses an upsert on (meeting_id, normalized_text) to avoid duplicate key errors.
 
         Args:
            entities: List of entities to save.
@@ -76,13 +100,38 @@
         Returns:
             Saved entities with db_ids populated.
         """
-        for entity in entities:
-            kwargs = NerConverter.to_orm_kwargs(entity)
-            model = NamedEntityModel(**kwargs)
-            merged = await self._session.merge(model)
-            entity.db_id = merged.id
+        if not entities:
+            return entities
 
-        await self._session.flush()
+        payload = [NerConverter.to_orm_kwargs(entity) for entity in entities]
+        stmt = insert(NamedEntityModel).values(payload)
+        excluded = stmt.excluded
+        stmt = stmt.on_conflict_do_update(
+            index_elements=[MEETING_ID, NORMALIZED_TEXT],
+            set_={
+                TEXT: excluded.text,
+                NORMALIZED_TEXT: excluded.normalized_text,
+                CATEGORY: excluded.category,
+                SEGMENT_IDS: excluded.segment_ids,
+                CONFIDENCE: excluded.confidence,
+                IS_PINNED: NamedEntityModel.is_pinned | excluded.is_pinned,
+                UPDATED_AT: func.now(),
+            },
+        ).returning(
+            NamedEntityModel.id,
+            NamedEntityModel.meeting_id,
+            NamedEntityModel.normalized_text,
+        )
+        result = await self._session.execute(stmt)
+        rows = result.all()
+        if rows:
+            ids_by_key = {
+                (str(row.meeting_id), row.normalized_text): row.id
+                for row in rows
+            }
+            for entity in entities:
+                key = (str(entity.meeting_id), entity.normalized_text)
+                entity.db_id = ids_by_key.get(key, entity.db_id)
 
         if entities:
             logger.info(
                 "entities_batch_saved",
diff --git a/src/noteflow/infrastructure/persistence/repositories/meeting_repo.py b/src/noteflow/infrastructure/persistence/repositories/meeting_repo.py
index fc3875a..97b7bdc 100644
--- a/src/noteflow/infrastructure/persistence/repositories/meeting_repo.py
+++ b/src/noteflow/infrastructure/persistence/repositories/meeting_repo.py
@@ -45,6 +45,7 @@ class SqlAlchemyMeetingRepository(BaseRepository):
             metadata_=meeting.metadata,
             wrapped_dek=meeting.wrapped_dek,
             asset_path=meeting.asset_path,
+            processing_status=OrmConverter.processing_status_to_payload(meeting.processing_status),
             version=meeting.version,
         )
         self._session.add(model)
@@ -96,6 +97,9 @@
         model.metadata_ = dict(meeting.metadata)
         model.wrapped_dek = meeting.wrapped_dek
         model.asset_path = meeting.asset_path
+        model.processing_status = OrmConverter.processing_status_to_payload(
+            meeting.processing_status
+        )
         model.version += 1
         meeting.version = model.version
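One PostgreSQL caveat for the save_batch path: a single INSERT ... ON CONFLICT DO UPDATE may not update the same row twice, so the batch must already be unique per (meeting_id, normalized_text) or the statement fails with a cardinality_violation error. A defensive pre-dedupe sketch — a hypothetical helper, not part of this patch; later entries win:

    from collections.abc import Sequence

    def dedupe_by_conflict_key(entities: Sequence["NamedEntity"]) -> list["NamedEntity"]:
        """Collapse duplicates on the upsert's conflict key; later entries win."""
        by_key = {(str(e.meeting_id), e.normalized_text): e for e in entities}
        return list(by_key.values())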