refactor: add API initialization memoization and replace initializeAPI with getAPI across hooks
- Added initializePromise singleton to prevent concurrent API initialization attempts and ensure single initialization
- Wrapped initializeAPI logic in an IIFE that caches the promise and clears it on error for retry capability
- Replaced all initializeAPI() calls with synchronous getAPI() in hooks (useAudioDevices, useDiarization, usePostProcessing, useRecordingAppPolicy)
- Updated corresponding test mocks from initializeAPI to getAPI
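A minimal sketch of the promise-memoization pattern this commit applies, for context. The names API, createAPI, and initializeOnce below are illustrative stand-ins, not the actual module exports:

    // TypeScript sketch (illustrative names, not the real module)
    type API = { connect(url?: string): Promise<void> };

    let initPromise: Promise<API> | null = null;

    export async function initializeOnce(createAPI: () => Promise<API>): Promise<API> {
      if (initPromise) {
        return initPromise; // concurrent callers await the same in-flight initialization
      }
      initPromise = createAPI(); // cache before awaiting so re-entrant calls see it
      try {
        return await initPromise;
      } catch (error) {
        initPromise = null; // drop the failed promise so the next call can retry
        throw error;
      }
    }

Clearing the cached promise only on failure is what makes retry possible while still guaranteeing at most one successful initialization.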
@@ -31,6 +31,7 @@ import { startReconnection } from './reconnection';
 import { initializeTauriAPI } from './tauri-adapter';
 
 const log = debug('NoteFlowAPI');
 
+let initializePromise: Promise<NoteFlowAPI> | null = null;
 // ============================================================================
 // API Initialization
@@ -46,65 +47,78 @@ const log = debug('NoteFlowAPI');
  * Sprint GAP-009: Event bridge starts before connection to capture early events.
  */
 export async function initializeAPI(): Promise<NoteFlowAPI> {
-  // Always try Tauri first - initializeTauriAPI tests the API and throws if unavailable
+  if (initializePromise) {
+    return initializePromise;
+  }
+
+  initializePromise = (async () => {
+    // Always try Tauri first - initializeTauriAPI tests the API and throws if unavailable
     try {
       const tauriAPI = await initializeTauriAPI();
       setAPIInstance(tauriAPI);
 
       try {
         const { invoke } = await import('@tauri-apps/api/core');
         window.__NOTEFLOW_TEST_INVOKE__ = invoke;
       } catch (error) {
         log('Test invoke binding unavailable (expected in non-Tauri contexts)', {
           error: error instanceof Error ? error.message : String(error),
         });
       }
 
       // Sprint GAP-009: Start event bridge before connection to capture early events
       // (e.g., connection errors, early warnings). Non-critical if it fails.
       await startTauriEventBridge().catch((error) => {
         addClientLog({
           level: 'warning',
           source: 'api',
           message: 'Event bridge initialization failed - continuing without early events',
           details: error instanceof Error ? error.message : String(error),
           metadata: { context: 'api_event_bridge_init' },
         });
       });
 
       // Attempt to connect to the gRPC server
       try {
         const preferredUrl = preferences.getServerUrl();
         await tauriAPI.connect(preferredUrl || undefined);
         setConnectionMode('connected');
         await preferences.initialize();
         startReconnection();
         // Sprint GAP-007: Log successful connection
         log('Adapter: Tauri | Mode: connected', { server: preferredUrl || 'default' });
         return tauriAPI;
       } catch (connectError) {
         // Connection failed - fall back to cached mode but keep Tauri adapter
         const message = extractErrorMessage(connectError, 'Connection failed');
         setConnectionMode('cached', message);
         await preferences.initialize();
         startReconnection();
         // Sprint GAP-007: Log cached mode fallback
         addClientLog({
           level: 'warning',
           source: 'api',
           message: 'Adapter fallback to cached mode',
           details: message,
           metadata: { context: 'api_cached_mode_fallback' },
         });
         return tauriAPI; // Keep Tauri adapter for reconnection attempts
       }
     } catch (_tauriError) {
       // Tauri unavailable - use mock API (we're in a browser)
       setConnectionMode('mock');
       setAPIInstance(mockAPI);
       // Sprint GAP-007: Log mock mode
       log('Adapter: Mock | Mode: mock | Environment: Browser');
       return mockAPI;
     }
+  })();
+
+  try {
+    return await initializePromise;
+  } catch (error) {
+    initializePromise = null;
+    throw error;
+  }
 }
@@ -136,6 +136,10 @@ export function shouldAutoStartProcessing(
 
   // If processing is already complete or in progress, don't restart
   const { summary, entities, diarization } = processingStatus;
+  const anyFailed =
+    summary.status === 'failed' ||
+    entities.status === 'failed' ||
+    diarization.status === 'failed';
   const allTerminal =
     ['completed', 'failed', 'skipped'].includes(summary.status) &&
     ['completed', 'failed', 'skipped'].includes(entities.status) &&
@@ -147,7 +151,7 @@ export function shouldAutoStartProcessing(
     diarization.status === 'running';
 
   // If all done or any running, don't auto-start
-  if (allTerminal || anyRunning) {
+  if (allTerminal || anyRunning || anyFailed) {
     return false;
   }
@@ -1,6 +1,6 @@
 import { act, renderHook } from '@testing-library/react';
 import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
-import { initializeAPI } from '@/api';
+import { getAPI } from '@/api';
 import { TauriCommands } from '@/api/constants';
 import { isTauriEnvironment } from '@/api/tauri-adapter';
 import { toast } from '@/hooks/use-toast';
@@ -9,7 +9,7 @@ import { useTauriEvent } from '@/lib/tauri-events';
 import { useAudioDevices } from './use-audio-devices';
 
 vi.mock('@/api', () => ({
-  initializeAPI: vi.fn(),
+  getAPI: vi.fn(),
 }));
 
 vi.mock('@/api/tauri-adapter', async () => {
@@ -172,7 +172,7 @@ describe('useAudioDevices', () => {
 
   it('loads devices in tauri mode and syncs selection', async () => {
     vi.mocked(isTauriEnvironment).mockReturnValue(true);
-    vi.mocked(initializeAPI).mockResolvedValue({
+    vi.mocked(getAPI).mockReturnValue({
       getPreferences: vi.fn().mockResolvedValue({
         audio_devices: { input_device_id: '', output_device_id: '' },
       }),
@@ -188,7 +188,7 @@ describe('useAudioDevices', () => {
        system_gain: 1,
      }),
      selectAudioDevice: vi.fn(),
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);
 
     const { result } = renderHook(() => useAudioDevices({ showToasts: false }));
 
@@ -218,7 +218,7 @@ describe('useAudioDevices', () => {
       audio_devices: { input_device_id: 'input:0:Mic', output_device_id: 'output:1:Speakers' },
     });
     const selectAudioDevice = vi.fn();
-    vi.mocked(initializeAPI).mockResolvedValue({
+    vi.mocked(getAPI).mockReturnValue({
       getPreferences: vi.fn().mockResolvedValue({
         audio_devices: { input_device_id: 'input:0:Mic', output_device_id: 'output:1:Speakers' },
       }),
@@ -234,7 +234,7 @@ describe('useAudioDevices', () => {
        system_gain: 1,
      }),
      selectAudioDevice,
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);
 
     const { result } = renderHook(() => useAudioDevices({ showToasts: false }));
 
@@ -387,7 +387,7 @@ describe('useAudioDevices', () => {
     });
 
     const selectAudioDevice = vi.fn();
-    vi.mocked(initializeAPI).mockResolvedValue({
+    vi.mocked(getAPI).mockReturnValue({
       getPreferences: vi.fn().mockImplementation(() =>
         Promise.resolve({ audio_devices: prefsState.audio_devices })
       ),
@@ -403,7 +403,7 @@ describe('useAudioDevices', () => {
        system_gain: 1,
      }),
      selectAudioDevice,
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);
 
     const first = renderHook(() => useAudioDevices({ showToasts: false }));
 
@@ -644,10 +644,10 @@ describe('useAudioDevices', () => {
   it('sets selected devices and syncs tauri selection', async () => {
     vi.mocked(isTauriEnvironment).mockReturnValue(true);
     const selectAudioDevice = vi.fn();
-    vi.mocked(initializeAPI).mockResolvedValue({
+    vi.mocked(getAPI).mockReturnValue({
       listAudioDevices: vi.fn().mockResolvedValue([]),
       selectAudioDevice,
-    } as unknown as Awaited<ReturnType<typeof initializeAPI>>);
+    } as unknown as ReturnType<typeof getAPI>);
 
     const { result } = renderHook(() => useAudioDevices({ showToasts: false }));
@@ -1,5 +1,5 @@
 import { useCallback, useEffect, useRef, useState } from 'react';
-import { initializeAPI } from '@/api';
+import { getAPI } from '@/api';
 import { clientLog } from '@/lib/client-log-events';
 import { addClientLog } from '@/lib/client-logs';
 import { isTauriEnvironment } from '@/api/tauri-adapter';
@@ -133,7 +133,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
 
     try {
       if (isTauriEnvironment()) {
-        const api = await initializeAPI();
+        const api = getAPI();
         let storedPrefs = preferences.get();
         try {
           const tauriPrefs = await api.getPreferences();
@@ -308,7 +308,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
       });
       if (isTauriEnvironment()) {
         try {
-          const api = await initializeAPI();
+          const api = getAPI();
           await api.selectAudioDevice(deviceId, true);
         } catch (err) {
           const errorMessage = err instanceof Error ? err.message : String(err);
@@ -340,7 +340,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
       });
       if (isTauriEnvironment()) {
         try {
-          const api = await initializeAPI();
+          const api = getAPI();
           await api.selectAudioDevice(deviceId, false);
         } catch (err) {
           const errorMessage = err instanceof Error ? err.message : String(err);
@@ -364,7 +364,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
       preferences.setSystemAudioDevice(deviceId);
       if (isTauriEnvironment()) {
         try {
-          const api = await initializeAPI();
+          const api = getAPI();
           await api.setSystemAudioDevice(deviceId);
         } catch (err) {
           const errorMessage = err instanceof Error ? err.message : String(err);
@@ -388,7 +388,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
       preferences.setDualCaptureEnabled(enabled);
       if (isTauriEnvironment()) {
         try {
-          const api = await initializeAPI();
+          const api = getAPI();
           await api.setDualCaptureEnabled(enabled);
           if (enabled && !selectedSystemDevice && loopbackDevices.length > 0) {
             const firstDevice = loopbackDevices[0].deviceId;
@@ -418,7 +418,7 @@ export function useAudioDevices(options: UseAudioDevicesOptions = {}): UseAudioD
       preferences.setAudioMixLevels(newMicGain, newSystemGain);
       if (isTauriEnvironment()) {
         try {
-          const api = await initializeAPI();
+          const api = getAPI();
           await api.setAudioMixLevels(newMicGain, newSystemGain);
         } catch (err) {
           if (showToasts) {
@@ -19,7 +19,7 @@ import { useDiarization } from './use-diarization';
 
 // Mock the API module
 vi.mock('@/api', () => ({
-  initializeAPI: vi.fn(),
+  getAPI: vi.fn(),
 }));
 
 // Mock the toast hook
@@ -47,9 +47,7 @@ describe('useDiarization', () => {
 
   beforeEach(() => {
     vi.useFakeTimers();
-    vi.mocked(api.initializeAPI).mockResolvedValue(
-      mockAPI as unknown as Awaited<ReturnType<typeof api.initializeAPI>>
-    );
+    vi.mocked(api.getAPI).mockReturnValue(mockAPI as unknown as ReturnType<typeof api.getAPI>);
     vi.clearAllMocks();
   });
@@ -6,7 +6,7 @@
  */
 
 import { useCallback, useEffect, useRef, useState } from 'react';
-import { initializeAPI } from '@/api';
+import { getAPI } from '@/api';
 import type { DiarizationJobStatus, JobStatus } from '@/api/types';
 import { toast } from '@/hooks/use-toast';
 import { PollingConfig } from '@/lib/config';
@@ -151,7 +151,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     }
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const status = await api.getDiarizationJobStatus(jobId);
 
       if (!isMountedRef.current) {
@@ -289,7 +289,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     retryCountRef.current = 0;
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const status = await api.refineSpeakers(meetingId, numSpeakers);
 
       if (!isMountedRef.current) {
@@ -345,7 +345,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     stopPolling();
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const result = await api.cancelDiarization(jobId);
 
       if (!isMountedRef.current) {
@@ -396,7 +396,7 @@ export function useDiarization(options: UseDiarizationOptions = {}): UseDiarizat
     }
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const activeJobs = await api.getActiveDiarizationJobs();
 
       if (!isMountedRef.current) {
@@ -23,20 +23,20 @@ import { toast } from '@/hooks/use-toast';
 import { usePostProcessing } from './use-post-processing';
 
 // Use vi.hoisted to define mocks that are available in vi.mock
-const { mockAPI, mockInitializeAPI } = vi.hoisted(() => {
+const { mockAPI, mockGetAPI } = vi.hoisted(() => {
   const mockAPI = {
     generateSummary: vi.fn(),
     extractEntities: vi.fn(),
     refineSpeakers: vi.fn(),
     getDiarizationJobStatus: vi.fn(),
   };
-  const mockInitializeAPI = vi.fn(() => Promise.resolve(mockAPI));
-  return { mockAPI, mockInitializeAPI };
+  const mockGetAPI = vi.fn(() => mockAPI);
+  return { mockAPI, mockGetAPI };
 });
 
 // Mock the API module
 vi.mock('@/api', () => ({
-  initializeAPI: mockInitializeAPI,
+  getAPI: mockGetAPI,
 }));
 
 // Mock the toast hook
@@ -11,9 +11,10 @@
  */
 
 import { useCallback, useEffect, useRef, useState } from 'react';
-import { initializeAPI } from '@/api';
+import { getAPI } from '@/api';
 import type { DiarizationJobStatus, MeetingState, ProcessingStatus } from '@/api/types';
 import { toast } from '@/hooks/use-toast';
+import { setEntitiesFromExtraction } from '@/lib/entity-store';
 import { errorMessageFrom, toastError } from '@/lib/error-reporting';
 import { usePostProcessingEvents } from './post-processing/events';
 import type {
@@ -40,7 +41,6 @@ export type {
   UsePostProcessingReturn,
 } from './post-processing/state';
 
-
 /**
  * Hook for orchestrating post-processing after recording stops
  */
@@ -116,7 +116,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
     }
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const status: DiarizationJobStatus = await api.getDiarizationJobStatus(jobId);
 
       if (!isMountedRef.current) {
@@ -201,7 +201,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
     });
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       await api.generateSummary(meetingId, false);
 
       if (!isMountedRef.current) {
@@ -253,8 +253,9 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
     });
 
     try {
-      const api = await initializeAPI();
-      await api.extractEntities(meetingId, false);
+      const api = getAPI();
+      const response = await api.extractEntities(meetingId, false);
+      setEntitiesFromExtraction(response.entities);
 
       if (!isMountedRef.current) {
         return;
@@ -305,7 +306,7 @@ export function usePostProcessing(options: UsePostProcessingOptions = {}): UsePo
     });
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const response = await api.refineSpeakers(meetingId, numSpeakers);
 
       if (!isMountedRef.current) {
@@ -1,5 +1,5 @@
 import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
-import { initializeAPI } from '@/api';
+import { getAPI } from '@/api';
 import { isTauriEnvironment } from '@/api/tauri-adapter';
 import type {
   AppMatcher,
@@ -216,7 +216,7 @@ export function useRecordingAppPolicy(): RecordingAppPolicyState {
     setIsLoadingApps(true);
 
     try {
-      const api = await initializeAPI();
+      const api = getAPI();
       const page = reset ? 0 : currentPageRef.current;
       const response: ListInstalledAppsResponse = await api.listInstalledApps({
         commonOnly,
@@ -19,7 +19,7 @@ import { debug } from '@/lib/debug';
 
 const log = debug('Preferences:Replace');
 
-import { clonePreferences } from './core';
+import { arePreferencesEqual, clonePreferences } from './core';
 import { defaultPreferences } from './constants';
 import { mergeIntegrationsWithDefaults, resetIntegrationsForServerSwitch } from './integrations';
 import {
@@ -78,10 +78,6 @@ export const preferences = {
   },
 
   replace(prefs: UserPreferences): void {
-    log('Replacing preferences', {
-      input_device_id: prefs.audio_devices?.input_device_id,
-      output_device_id: prefs.audio_devices?.output_device_id,
-    });
     const mergedAudioDevices = isRecord(prefs.audio_devices)
       ? { ...defaultPreferences.audio_devices, ...prefs.audio_devices }
       : { ...defaultPreferences.audio_devices };
@@ -92,7 +88,16 @@ export const preferences = {
       audio_devices: mergedAudioDevices,
       integrations: mergedIntegrations,
     };
-    savePreferences(applyLocalOverrides(merged), persistPreferencesToTauri);
+    const next = applyLocalOverrides(merged);
+    const current = loadPreferences();
+    if (arePreferencesEqual(current, next)) {
+      return;
+    }
+    log('Replacing preferences', {
+      input_device_id: next.audio_devices?.input_device_id,
+      output_device_id: next.audio_devices?.output_device_id,
+    });
+    savePreferences(next, persistPreferencesToTauri);
   },
 
   resetToDefaults(): void {
@@ -462,9 +467,7 @@
       setServerAddressOverride(hasAddress, trimmedHost, trimmedPort);
     });
   },
 
   // SYNC SCHEDULER
 
   isSyncSchedulerPaused(): boolean {
     return loadPreferences().sync_scheduler_paused;
   },
@@ -474,9 +477,7 @@
       prefs.sync_scheduler_paused = paused;
     });
   },
 
   // SYNC HISTORY
 
   getSyncHistory(): SyncHistoryEvent[] {
     return loadPreferences().sync_history || [];
   },
@@ -145,3 +145,10 @@ export function normalizePreferences(prefs: UserPreferences): UserPreferences {
     ai_config: normalizeAiConfig(prefs.ai_config),
   };
 }
+
+/**
+ * Compare preferences payloads after normalization.
+ */
+export function arePreferencesEqual(a: UserPreferences, b: UserPreferences): boolean {
+  return JSON.stringify(normalizePreferences(a)) === JSON.stringify(normalizePreferences(b));
+}
@@ -7,7 +7,7 @@ import type { Integration, UserPreferences } from '@/api/types';
 import { addClientLog } from '@/lib/client-logs';
 import { debug } from '@/lib/debug';
 
-import { isTauriRuntime } from './core';
+import { arePreferencesEqual, isTauriRuntime } from './core';
 import { loadPreferences } from './storage';
 import { emitValidationEvent } from './validation-events';
 
@@ -241,8 +241,10 @@ export async function hydratePreferencesFromTauri(
   });
 
   // Only call replacePrefs if we actually need to update something
-  // This avoids unnecessary writes
-  replacePrefs(mergedPrefs);
+  // This avoids unnecessary writes and persistence churn.
+  if (!arePreferencesEqual(localPrefs, mergedPrefs)) {
+    replacePrefs(mergedPrefs);
+  }
 
   hasHydratedFromTauri = true;
   notifyHydrationComplete();
@@ -2,7 +2,7 @@
  * Hook for meeting detail data and state.
  */
 
-import { useEffect, useMemo, useState } from 'react';
+import { useEffect, useMemo, useRef, useState } from 'react';
 
 import { getAPI } from '@/api';
 import type { Meeting } from '@/api/types';
@@ -25,19 +25,14 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) {
   const [selectedSegment, setSelectedSegment] = useState<number | null>(null);
   const [speakerNameMap, setSpeakerNameMap] = useState<Map<string, string>>(new Map());
   const { mode } = useConnectionState();
+  const autoStartedRef = useRef<Set<string>>(new Set());
+  const autoStartLogRef = useRef<Set<string>>(new Set());
 
-  // Entity extraction hook
-  const {
-    state: entityState,
-    extract: extractEntities,
-    isExtracting,
-  } = useEntityExtraction({
-    meetingId,
-    meetingTitle: meeting?.title,
-    meetingState: meeting?.state,
-    autoExtract: true,
-  });
-  const { entities } = entityState;
+  const failedStep = meeting?.processing_status
+    ? (['summary', 'entities', 'diarization'] as const).find(
+        (step) => meeting.processing_status?.[step]?.status === 'failed'
+      ) ?? null
+    : null;
 
   // Post-processing orchestration hook
   const {
@@ -54,6 +49,26 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) {
       }
     },
   });
+  const hasLocalPostProcessing = processingState.meetingId === meetingId;
+  const shouldAutoExtract =
+    !!meeting &&
+    !!meetingId &&
+    meeting.state === 'completed' &&
+    !hasLocalPostProcessing &&
+    !shouldAutoStart(meeting.state, meeting.processing_status);
+
+  // Entity extraction hook
+  const {
+    state: entityState,
+    extract: extractEntities,
+    isExtracting,
+  } = useEntityExtraction({
+    meetingId,
+    meetingTitle: meeting?.title,
+    meetingState: meeting?.state,
+    autoExtract: shouldAutoExtract,
+  });
+  const { entities } = entityState;
 
   // Summary metadata
   const summaryMeta = useMemo(() => {
@@ -128,9 +143,35 @@ export function useMeetingDetail({ meetingId }: UseMeetingDetailProps) {
 
   // Auto-start post-processing
   useEffect(() => {
-    if (meeting && meetingId && shouldAutoStart(meeting.state, meeting.processing_status)) {
-      void startProcessing(meetingId);
+    if (!meeting || !meetingId) {
+      return;
     }
+    if (failedStep && !autoStartLogRef.current.has(meetingId)) {
+      autoStartLogRef.current.add(meetingId);
+      addClientLog({
+        level: 'info',
+        source: 'app',
+        message: 'Post-processing auto-start skipped due to failed step',
+        metadata: {
+          meeting_id: meetingId,
+          failed_step: failedStep,
+        },
+      });
+    }
+    if (!shouldAutoStart(meeting.state, meeting.processing_status)) {
+      return;
+    }
+    if (autoStartedRef.current.has(meetingId)) {
+      return;
+    }
+    autoStartedRef.current.add(meetingId);
+    addClientLog({
+      level: 'info',
+      source: 'app',
+      message: 'Post-processing auto-start triggered',
+      metadata: { meeting_id: meetingId },
+    });
+    void startProcessing(meetingId);
   }, [meeting, meetingId, shouldAutoStart, startProcessing]);
 
   const handleGenerateSummary = async () => {
src/noteflow/application/services/ner/_dedupe.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""Helpers for normalizing extracted entities."""

from __future__ import annotations

from collections.abc import Sequence

from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.infrastructure.logging import get_logger

logger = get_logger(__name__)


def dedupe_entities(entities: Sequence[NamedEntity]) -> list[NamedEntity]:
    """Collapse duplicate entities by normalized text.

    Merges segment IDs and keeps the highest-confidence category.
    """
    deduped: dict[str, NamedEntity] = {}
    for entity in entities:
        key = entity.normalized_text or entity.text.lower().strip()
        if not key:
            continue
        existing = deduped.get(key)
        if existing is None:
            entity.normalized_text = key
            deduped[key] = entity
            continue

        existing.merge_segments(entity.segment_ids)
        if entity.confidence > existing.confidence:
            existing.confidence = entity.confidence
            existing.category = entity.category
        if len(entity.text) > len(existing.text):
            existing.text = entity.text
    if len(deduped) != len(entities):
        logger.info(
            "ner_entities_deduped",
            total=len(entities),
            deduped=len(deduped),
        )
    return list(deduped.values())
@@ -7,9 +7,11 @@ Orchestrates NER extraction, caching, and persistence following hexagonal architecture.
 from __future__ import annotations
 
 import asyncio
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
+from noteflow.application.services.ner._dedupe import dedupe_entities
 from noteflow.config.constants import ERROR_MSG_MEETING_PREFIX
 from noteflow.config.settings import get_feature_flags
 from noteflow.domain.entities.named_entity import NamedEntity
@@ -17,7 +19,7 @@ from noteflow.infrastructure.logging import get_logger, log_timing
 from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
+    from collections.abc import Callable
     from uuid import UUID
 
     from noteflow.domain.ports.ner import NerPort
@@ -211,6 +213,7 @@ class NerService:
         entities = await self._extraction_helper.extract(segments)
         for entity in entities:
             entity.meeting_id = meeting_id
+        entities = dedupe_entities(entities)
 
         await self._persist_entities(meeting_id, entities, force_refresh)
         log_memory_snapshot(
@@ -14,6 +14,13 @@ PROVIDER: Final[str] = "provider"
 UNKNOWN: Final[str] = "unknown"
 DEVICE: Final[Literal["device"]] = "device"
 SEGMENT_IDS: Final[Literal["segment_ids"]] = "segment_ids"
+MEETING_ID: Final[Literal["meeting_id"]] = "meeting_id"
+TEXT: Final[Literal["text"]] = "text"
+NORMALIZED_TEXT: Final[Literal["normalized_text"]] = "normalized_text"
+CATEGORY: Final[Literal["category"]] = "category"
+CONFIDENCE: Final[Literal["confidence"]] = "confidence"
+IS_PINNED: Final[Literal["is_pinned"]] = "is_pinned"
+UPDATED_AT: Final[Literal["updated_at"]] = "updated_at"
 PROJECT_ID: Final[str] = "project_id"
 PROJECT_IDS: Final[str] = "project_ids"
 CALENDAR: Final[str] = "calendar"
@@ -9,7 +9,7 @@ from uuid import UUID, uuid4
 
 from noteflow.domain.constants.fields import DATE as ENTITY_DATE
 from noteflow.domain.constants.fields import LOCATION as ENTITY_LOCATION
-from noteflow.domain.constants.fields import SEGMENT_IDS
+from noteflow.domain.constants.fields import CATEGORY, CONFIDENCE, SEGMENT_IDS, TEXT
 
 if TYPE_CHECKING:
     from noteflow.domain.value_objects import MeetingId
@@ -116,10 +116,10 @@ class NamedEntity:
             ValueError: If text is empty or confidence is out of range.
         """
         # Validate required text
-        text = kwargs["text"]
-        category = kwargs["category"]
+        text = kwargs[TEXT]
+        category = kwargs[CATEGORY]
         segment_ids: list[int] = kwargs[SEGMENT_IDS]
-        confidence = kwargs["confidence"]
+        confidence = kwargs[CONFIDENCE]
         meeting_id = kwargs.get("meeting_id")
 
         stripped_text = text.strip()
src/noteflow/grpc/mixins/_processing_status.py (new file, 166 lines)
@@ -0,0 +1,166 @@
"""Helpers for updating meeting post-processing status."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, cast
from uuid import UUID

from noteflow.domain.entities.processing import (
    ProcessingStatus,
    ProcessingStepState,
    ProcessingStepStatus,
)
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    from noteflow.domain.entities import Meeting

logger = get_logger(__name__)

ProcessingStepName = Literal["summary", "entities", "diarization"]


def _parse_meeting_id(meeting_id: str) -> MeetingId | None:
    try:
        return MeetingId(UUID(meeting_id))
    except ValueError:
        return None


@dataclass(frozen=True, slots=True)
class ProcessingStatusUpdate:
    step: ProcessingStepName
    status: ProcessingStepStatus
    error_message: str | None = None


def _build_step_state(
    current: ProcessingStepState,
    update: ProcessingStatusUpdate,
) -> ProcessingStepState:
    if update.status == ProcessingStepStatus.RUNNING:
        return ProcessingStepState.running()
    if update.status == ProcessingStepStatus.COMPLETED:
        return ProcessingStepState.completed(started_at=current.started_at)
    if update.status == ProcessingStepStatus.FAILED:
        return ProcessingStepState.failed(
            update.error_message or "processing failed",
            started_at=current.started_at,
        )
    if update.status == ProcessingStepStatus.SKIPPED:
        return ProcessingStepState.skipped()
    return ProcessingStepState.pending()


def _apply_step_update(
    current: ProcessingStatus,
    step: ProcessingStepName,
    new_state: ProcessingStepState,
) -> ProcessingStatus:
    if step == "summary":
        return ProcessingStatus(
            summary=new_state,
            entities=current.entities,
            diarization=current.diarization,
            queued_at=current.queued_at,
        )
    if step == "entities":
        return ProcessingStatus(
            summary=current.summary,
            entities=new_state,
            diarization=current.diarization,
            queued_at=current.queued_at,
        )
    return ProcessingStatus(
        summary=current.summary,
        entities=current.entities,
        diarization=new_state,
        queued_at=current.queued_at,
    )


async def _commit_processing_update(
    repo: UnitOfWork,
    meeting: Meeting,
    meeting_id: str,
    update: ProcessingStatusUpdate,
) -> bool:
    try:
        await repo.meetings.update(meeting)
        await repo.commit()
        return True
    except ValueError as exc:
        await repo.rollback()
        if "modified concurrently" in str(exc):
            return False
        logger.warning(
            "processing_status_update_failed",
            meeting_id=meeting_id,
            step=update.step,
            error=str(exc),
        )
        return True


async def _apply_processing_update(
    repo_provider: Callable[[], object],
    parsed_id: MeetingId,
    meeting_id: str,
    update: ProcessingStatusUpdate,
) -> bool:
    async with cast(UnitOfWork, repo_provider()) as repo:
        meeting = await repo.meetings.get(parsed_id)
        if meeting is None:
            logger.warning(
                "processing_status_meeting_missing",
                meeting_id=meeting_id,
                step=update.step,
            )
            return True

        current_status = meeting.processing_status or ProcessingStatus.create_pending()
        current_step = getattr(current_status, update.step)
        new_state = _build_step_state(current_step, update)
        meeting.processing_status = _apply_step_update(current_status, update.step, new_state)

        return await _commit_processing_update(repo, meeting, meeting_id, update)


async def update_processing_status(
    repo_provider: Callable[[], object],
    meeting_id: str,
    update: ProcessingStatusUpdate,
    *,
    max_attempts: int = 3,
) -> None:
    """Update a meeting's processing status with retry on version conflicts."""
    parsed_id = _parse_meeting_id(meeting_id)
    if parsed_id is None:
        logger.debug(
            "processing_status_meeting_id_invalid",
            meeting_id=meeting_id,
            step=update.step,
        )
        return

    for attempt in range(max_attempts):
        completed = await _apply_processing_update(
            repo_provider,
            parsed_id,
            meeting_id,
            update,
        )
        if completed:
            return
        if attempt >= max_attempts - 1:
            logger.warning(
                "processing_status_update_failed",
                meeting_id=meeting_id,
                step=update.step,
                error="retry_limit_exceeded",
            )
            return
@@ -14,7 +14,7 @@ from opentelemetry.trace import Span
 from ._types import DIARIZATION_TIMEOUT_SECONDS
 
 if TYPE_CHECKING:
-    from ._jobs import DiarizationJobContext
+    from ._job_validation import DiarizationJobContext
 
 
 def _set_diarization_span_attributes(span: Span, ctx: DiarizationJobContext) -> None:
src/noteflow/grpc/mixins/diarization/_job_validation.py (new file, 143 lines)
@@ -0,0 +1,143 @@
"""Diarization job validation and error response helpers."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
from uuid import UUID

import grpc

from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingState
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.persistence.repositories import DiarizationJob

from ...proto import noteflow_pb2
from .._types import GrpcStatusContext
from ..converters import parse_meeting_id
from ..errors._constants import INVALID_MEETING_ID_MESSAGE

if TYPE_CHECKING:
    from noteflow.domain.entities import Meeting

    from ..protocols import ServicerHost

logger = get_logger(__name__)


def job_status_name(status: int) -> str:
    name_fn = cast(Callable[[int], str], noteflow_pb2.JobStatus.Name)
    return name_fn(status)


@dataclass(frozen=True)
class GrpcErrorDetails:
    """gRPC error context and status for response helpers."""

    context: GrpcStatusContext
    grpc_code: grpc.StatusCode


@dataclass(frozen=True, slots=True)
class DiarizationJobContext:
    """Context for executing a diarization job.

    Groups job-related parameters to reduce function signature complexity.
    """

    host: ServicerHost
    job_id: str
    job: DiarizationJob
    meeting_id: str
    num_speakers: int | None


def create_diarization_error_response(
    error_message: str,
    status: noteflow_pb2.JobStatus | str = noteflow_pb2.JOB_STATUS_FAILED,
    *,
    error: GrpcErrorDetails | None = None,
    job_id: str = "",
) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
    """Create error response for RefineSpeakerDiarization.

    Consolidates the 6+ duplicated response construction patterns.
    """
    if error is not None:
        error.context.set_code(error.grpc_code)
        error.context.set_details(error_message)
    return noteflow_pb2.RefineSpeakerDiarizationResponse(
        segments_updated=0,
        speaker_ids=[],
        error_message=error_message,
        job_id=job_id,
        status=status,
    )


def validate_diarization_preconditions(
    servicer: ServicerHost,
    request: noteflow_pb2.RefineSpeakerDiarizationRequest,
    context: GrpcStatusContext,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
    """Validate preconditions before starting diarization job."""
    if not servicer.diarization_refinement_enabled:
        return create_diarization_error_response("Diarization refinement disabled on server")

    if servicer.diarization_engine is None:
        return create_diarization_error_response(
            "Diarization not enabled on server",
            error=GrpcErrorDetails(
                context=context,
                grpc_code=grpc.StatusCode.UNAVAILABLE,
            ),
        )

    try:
        UUID(request.meeting_id)
    except ValueError:
        return create_diarization_error_response(INVALID_MEETING_ID_MESSAGE)
    return None


async def load_meeting_for_diarization(
    repo: UnitOfWork,
    meeting_id: str,
) -> tuple[Meeting | None, noteflow_pb2.RefineSpeakerDiarizationResponse | None]:
    """Fetch meeting and validate state for diarization refinement."""
    meeting = await repo.meetings.get(parse_meeting_id(meeting_id))
    if meeting is None:
        return None, create_diarization_error_response("Meeting not found")

    valid_states = (MeetingState.STOPPED, MeetingState.COMPLETED, MeetingState.ERROR)
    if meeting.state not in valid_states:
        return None, create_diarization_error_response(
            f"Meeting must be stopped before refinement (state: {meeting.state.name.lower()})"
        )
    return meeting, None


async def check_active_diarization_job(
    repo: UnitOfWork,
    meeting_id: str,
    context: GrpcStatusContext,
) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
    """Return error response if a diarization job is already active."""
    if not repo.supports_diarization_jobs:
        return None

    active_job = await repo.diarization_jobs.get_active_for_meeting(meeting_id)
    if active_job is None:
        return None

    return create_diarization_error_response(
        f"Diarization already in progress (job: {active_job.job_id})",
        status=job_status_name(active_job.status),
        error=GrpcErrorDetails(
            context=context,
            grpc_code=grpc.StatusCode.ALREADY_EXISTS,
        ),
        job_id=active_job.job_id,
    )
@@ -3,33 +3,39 @@
 from __future__ import annotations
 
 import asyncio
-from collections.abc import Callable
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
-from uuid import UUID, uuid4
+from typing import TYPE_CHECKING
+from uuid import uuid4
 
 import grpc
 
-from noteflow.domain.ports.unit_of_work import UnitOfWork
+from noteflow.domain.entities.processing import ProcessingStepStatus
 from noteflow.domain.utils import utc_now
 from noteflow.domain.value_objects import MeetingState
 from noteflow.infrastructure.logging import get_logger, log_state_transition
 from noteflow.infrastructure.persistence.repositories import DiarizationJob
 
 from ...proto import noteflow_pb2
 from .._types import GrpcStatusContext
-from ..converters import parse_meeting_id
-from ..errors._constants import INVALID_MEETING_ID_MESSAGE
+from .._processing_status import ProcessingStatusUpdate, update_processing_status
 from ._execution import execute_diarization
+from ._job_validation import (
+    DiarizationJobContext,
+    GrpcErrorDetails,
+    check_active_diarization_job,
+    create_diarization_error_response,
+    job_status_name,
+    load_meeting_for_diarization,
+    validate_diarization_preconditions,
+)
 from ._status import JobStatusMixin
 
 if TYPE_CHECKING:
     from noteflow.domain.entities import Meeting
 
     from ..protocols import ServicerHost
 
 logger = get_logger(__name__)
 
+DIARIZATION_DB_REQUIRED = "Diarization requires database support"
 
 
 def _diarization_task_done_callback(
     task: asyncio.Task[None],
@@ -50,90 +56,6 @@ def _diarization_task_done_callback(
     )
 
 
-def _job_status_name(status: int) -> str:
-    name_fn = cast(Callable[[int], str], noteflow_pb2.JobStatus.Name)
-    return name_fn(status)
-
-
-@dataclass(frozen=True)
-class GrpcErrorDetails:
-    """gRPC error context and status for response helpers."""
-
-    context: GrpcStatusContext
-    grpc_code: grpc.StatusCode
-
-
-@dataclass(frozen=True, slots=True)
-class DiarizationJobContext:
-    """Context for executing a diarization job.
-
-    Groups job-related parameters to reduce function signature complexity.
-    """
-
-    host: ServicerHost
-    job_id: str
-    job: DiarizationJob
-    meeting_id: str
-    num_speakers: int | None
-
-
-def create_diarization_error_response(
-    error_message: str,
-    status: noteflow_pb2.JobStatus | str = noteflow_pb2.JOB_STATUS_FAILED,
-    *,
-    error: GrpcErrorDetails | None = None,
-    job_id: str = "",
-) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
-    """Create error response for RefineSpeakerDiarization.
-
-    Consolidates the 6+ duplicated response construction patterns.
-
-    Args:
-        error_message: Error message describing the failure.
-        status: Job status code (default: JOB_STATUS_FAILED).
-        error: Optional gRPC error details for status code.
-        job_id: Optional job ID to include in response.
-
-    Returns:
-        Populated RefineSpeakerDiarizationResponse with error state.
-    """
-    if error is not None:
-        error.context.set_code(error.grpc_code)
-        error.context.set_details(error_message)
-    return noteflow_pb2.RefineSpeakerDiarizationResponse(
-        segments_updated=0,
-        speaker_ids=[],
-        error_message=error_message,
-        job_id=job_id,
-        status=status,
-    )
-
-
-def _validate_diarization_preconditions(
-    servicer: ServicerHost,
-    request: noteflow_pb2.RefineSpeakerDiarizationRequest,
-    context: GrpcStatusContext,
-) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
-    """Validate preconditions before starting diarization job."""
-    if not servicer.diarization_refinement_enabled:
-        return create_diarization_error_response("Diarization refinement disabled on server")
-
-    if servicer.diarization_engine is None:
-        return create_diarization_error_response(
-            "Diarization not enabled on server",
-            error=GrpcErrorDetails(
-                context=context,
-                grpc_code=grpc.StatusCode.UNAVAILABLE,
-            ),
-        )
-
-    try:
-        UUID(request.meeting_id)
-    except ValueError:
-        return create_diarization_error_response(INVALID_MEETING_ID_MESSAGE)
-    return None
-
-
 async def _create_and_persist_job(
     job_id: str,
     meeting_id: str,
@@ -165,47 +87,6 @@ async def _create_and_persist_job(
     return True
 
 
-async def _load_meeting_for_diarization(
-    repo: UnitOfWork,
-    meeting_id: str,
-) -> tuple[Meeting | None, noteflow_pb2.RefineSpeakerDiarizationResponse | None]:
-    """Fetch meeting and validate state for diarization refinement."""
-    meeting = await repo.meetings.get(parse_meeting_id(meeting_id))
-    if meeting is None:
-        return None, create_diarization_error_response("Meeting not found")
-
-    valid_states = (MeetingState.STOPPED, MeetingState.COMPLETED, MeetingState.ERROR)
-    if meeting.state not in valid_states:
-        return None, create_diarization_error_response(
-            f"Meeting must be stopped before refinement (state: {meeting.state.name.lower()})"
-        )
-    return meeting, None
-
-
-async def _check_active_diarization_job(
-    repo: UnitOfWork,
-    meeting_id: str,
-    context: GrpcStatusContext,
-) -> noteflow_pb2.RefineSpeakerDiarizationResponse | None:
-    """Return error response if a diarization job is already active."""
-    if not repo.supports_diarization_jobs:
-        return None
-
-    active_job = await repo.diarization_jobs.get_active_for_meeting(meeting_id)
-    if active_job is None:
-        return None
-
-    return create_diarization_error_response(
-        f"Diarization already in progress (job: {active_job.job_id})",
-        status=_job_status_name(active_job.status),
-        error=GrpcErrorDetails(
-            context=context,
-            grpc_code=grpc.StatusCode.ALREADY_EXISTS,
-        ),
-        job_id=active_job.job_id,
-    )
-
-
 async def _init_job_for_running(
     host: ServicerHost,
     job_id: str,
@@ -233,8 +114,8 @@ async def _init_job_for_running(
         started_at=utc_now(),
     )
     await repo.commit()
-    transition_from = _job_status_name(old_status)
-    transition_to = _job_status_name(int(noteflow_pb2.JOB_STATUS_RUNNING))
+    transition_from = job_status_name(old_status)
+    transition_to = job_status_name(int(noteflow_pb2.JOB_STATUS_RUNNING))
     log_state_transition(
         "diarization_job",
         job_id,
@@ -271,6 +152,40 @@ def _schedule_diarization_task(
         num_speakers=num_speakers,
     )
 
 
+async def _prepare_diarization_job(
+    host: ServicerHost,
+    request: noteflow_pb2.RefineSpeakerDiarizationRequest,
+    context: GrpcStatusContext,
+) -> tuple[str, float | None] | noteflow_pb2.RefineSpeakerDiarizationResponse:
+    async with host.create_repository_provider() as repo:
+        meeting, error = await load_meeting_for_diarization(repo, request.meeting_id)
+        if error is not None:
+            return error
+
+        active_error = await check_active_diarization_job(repo, request.meeting_id, context)
+        if active_error is not None:
+            return active_error
+
+        job_id = str(uuid4())
+        persisted = await _create_and_persist_job(
+            job_id,
+            request.meeting_id,
+            meeting.duration_seconds if meeting else None,
+            repo,
+        )
+        if not persisted:
+            return create_diarization_error_response(
+                DIARIZATION_DB_REQUIRED,
+                error=GrpcErrorDetails(
+                    context=context,
+                    grpc_code=grpc.StatusCode.FAILED_PRECONDITION,
+                ),
+            )
+
+        return job_id, meeting.duration_seconds if meeting else None
+
+
 class JobsMixin(JobStatusMixin):
     """Mixin providing diarization job management."""
 
@@ -280,35 +195,42 @@ class JobsMixin(JobStatusMixin):
         context: GrpcStatusContext,
     ) -> noteflow_pb2.RefineSpeakerDiarizationResponse:
         """Start a new diarization refinement job."""
-        if error := _validate_diarization_preconditions(self, request, context):
+        if error := validate_diarization_preconditions(self, request, context):
+            await update_processing_status(
+                self.create_repository_provider,
+                request.meeting_id,
+                ProcessingStatusUpdate(
+                    step="diarization",
+                    status=ProcessingStepStatus.FAILED,
+                    error_message=error.error_message,
+                ),
+            )
             return error
 
-        async with self.create_repository_provider() as repo:
-            meeting, error = await _load_meeting_for_diarization(repo, request.meeting_id)
-            if error is not None:
-                return error
-
-            active_error = await _check_active_diarization_job(repo, request.meeting_id, context)
-            if active_error is not None:
-                return active_error
-
-            job_id = str(uuid4())
-            persisted = await _create_and_persist_job(
-                job_id,
-                request.meeting_id,
-                meeting.duration_seconds if meeting else None,
-                repo,
-            )
-            if not persisted:
-                return create_diarization_error_response(
-                    "Diarization requires database support",
-                    error=GrpcErrorDetails(
-                        context=context,
-                        grpc_code=grpc.StatusCode.FAILED_PRECONDITION,
-                    ),
-                )
+        job_result = await _prepare_diarization_job(self, request, context)
+        if isinstance(job_result, noteflow_pb2.RefineSpeakerDiarizationResponse):
+            if job_result.error_message == DIARIZATION_DB_REQUIRED:
+                await update_processing_status(
+                    self.create_repository_provider,
+                    request.meeting_id,
+                    ProcessingStatusUpdate(
+                        step="diarization",
+                        status=ProcessingStepStatus.FAILED,
+                        error_message=DIARIZATION_DB_REQUIRED,
+                    ),
+                )
+            return job_result
 
+        job_id, _duration = job_result
         num_speakers = request.num_speakers or None
+        await update_processing_status(
+            self.create_repository_provider,
+            request.meeting_id,
+            ProcessingStatusUpdate(
+                step="diarization",
+                status=ProcessingStepStatus.RUNNING,
+            ),
+        )
         _schedule_diarization_task(self, job_id, num_speakers, request.meeting_id)
 
         return noteflow_pb2.RefineSpeakerDiarizationResponse(
@@ -5,11 +5,13 @@ from __future__ import annotations
 from collections.abc import Callable
 from typing import TYPE_CHECKING, cast
 
+from noteflow.domain.entities.processing import ProcessingStepStatus
 from noteflow.infrastructure.logging import get_logger, log_state_transition
 from noteflow.infrastructure.persistence.repositories import DiarizationJob
 
 from ...proto import noteflow_pb2
 from ..errors import ERR_CANCELLED_BY_USER
+from .._processing_status import ProcessingStatusUpdate, update_processing_status
 from ._types import DIARIZATION_TIMEOUT_SECONDS
 
 if TYPE_CHECKING:
@@ -47,6 +49,15 @@ class JobStatusMixin:
             speaker_ids=speaker_ids,
         )
         await repo.commit()
+        if job is not None:
+            await update_processing_status(
+                self.create_repository_provider,
+                job.meeting_id,
+                ProcessingStatusUpdate(
+                    step="diarization",
+                    status=ProcessingStepStatus.COMPLETED,
+                ),
+            )
         log_state_transition(
             "diarization_job",
             job_id,
@@ -81,6 +92,16 @@ class JobStatusMixin:
                 error_message=error_msg,
             )
             await repo.commit()
+        if meeting_id is not None:
+            await update_processing_status(
+                self.create_repository_provider,
+                meeting_id,
+                ProcessingStatusUpdate(
+                    step="diarization",
+                    status=ProcessingStepStatus.FAILED,
+                    error_message=error_msg,
+                ),
+            )
         log_state_transition(
             "diarization_job",
             job_id,
@@ -109,6 +130,16 @@ class JobStatusMixin:
                 error_message=ERR_CANCELLED_BY_USER,
             )
             await repo.commit()
+        if meeting_id is not None:
+            await update_processing_status(
+                self.create_repository_provider,
+                meeting_id,
+                ProcessingStatusUpdate(
+                    step="diarization",
+                    status=ProcessingStepStatus.SKIPPED,
+                    error_message=ERR_CANCELLED_BY_USER,
+                ),
+            )
         log_state_transition(
             "diarization_job",
             job_id,
@@ -138,6 +169,16 @@ class JobStatusMixin:
                 error_message=str(exc),
             )
             await repo.commit()
+        if meeting_id is not None:
+            await update_processing_status(
+                self.create_repository_provider,
+                meeting_id,
+                ProcessingStatusUpdate(
+                    step="diarization",
+                    status=ProcessingStepStatus.FAILED,
+                    error_message=str(exc),
+                ),
+            )
         log_state_transition(
             "diarization_job",
             job_id,
@@ -4,11 +4,13 @@ from __future__ import annotations

from typing import TYPE_CHECKING, cast

from noteflow.domain.entities.processing import ProcessingStepStatus
from noteflow.infrastructure.logging import get_logger

from ..proto import noteflow_pb2
from ._types import GrpcContext
from ._model_status import log_model_status
from ._processing_status import ProcessingStatusUpdate, update_processing_status
from .converters import entity_to_proto, parse_meeting_id_or_abort
from .errors import (
    ENTITY_ENTITY,
@@ -25,7 +27,8 @@ from .protocols import EntitiesRepositoryProvider
if TYPE_CHECKING:
    from collections.abc import Callable

    from noteflow.application.services.ner import NerService
    from noteflow.application.services.ner import ExtractionResult, NerService
    from noteflow.domain.value_objects import MeetingId


logger = get_logger(__name__)
@@ -41,6 +44,77 @@ class EntitiesMixin:
    ner_service: NerService | None
    create_repository_provider: Callable[..., object]

    async def _mark_entities_step(
        self,
        meeting_id_str: str,
        status: ProcessingStepStatus,
        error_message: str | None = None,
    ) -> None:
        await update_processing_status(
            self.create_repository_provider,
            meeting_id_str,
            ProcessingStatusUpdate(
                step="entities",
                status=status,
                error_message=error_message,
            ),
        )

    async def _run_entity_extraction(
        self,
        meeting_id: MeetingId,
        meeting_id_str: str,
        force_refresh: bool,
        context: GrpcContext,
    ) -> ExtractionResult:
        ner_service = await require_ner_service(self.ner_service, context)
        try:
            return await ner_service.extract_entities(
                meeting_id=meeting_id,
                force_refresh=force_refresh,
            )
        except ValueError:
            await abort_not_found(context, ENTITY_MEETING, meeting_id_str)
            raise

    async def _handle_entities_failure(
        self,
        meeting_id_str: str,
        error: Exception,
        context: GrpcContext | None = None,
    ) -> None:
        await self._mark_entities_step(
            meeting_id_str,
            ProcessingStepStatus.FAILED,
            str(error),
        )
        if isinstance(error, RuntimeError) and context is not None:
            await abort_failed_precondition(context, str(error))

    async def _extract_entities_with_status(
        self,
        meeting_id: MeetingId,
        meeting_id_str: str,
        force_refresh: bool,
        context: GrpcContext,
    ) -> ExtractionResult:
        await self._mark_entities_step(meeting_id_str, ProcessingStepStatus.RUNNING)
        try:
            result = await self._run_entity_extraction(
                meeting_id,
                meeting_id_str,
                force_refresh,
                context,
            )
        except ValueError:
            raise
        except Exception as exc:
            await self._handle_entities_failure(meeting_id_str, exc, context)
            raise

        await self._mark_entities_step(meeting_id_str, ProcessingStepStatus.COMPLETED)
        return result

    async def ExtractEntities(
        self,
        request: noteflow_pb2.ExtractEntitiesRequest,
@@ -53,21 +127,12 @@
        """
        log_model_status(self, "ner_extract_start", meeting_id=request.meeting_id)
        meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
        ner_service = await require_ner_service(self.ner_service, context)

        try:
            result = await ner_service.extract_entities(
                meeting_id=meeting_id,
                force_refresh=request.force_refresh,
            )
        except ValueError:
            # Meeting not found
            await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
            raise  # Unreachable: abort raises
        except RuntimeError as e:
            # Feature disabled
            await abort_failed_precondition(context, str(e))
            raise  # Unreachable: abort raises
        result = await self._extract_entities_with_status(
            meeting_id,
            request.meeting_id,
            request.force_refresh,
            context,
        )

        # Convert to proto
        proto_entities = [entity_to_proto(entity) for entity in result.entities]
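`_extract_entities_with_status` wraps the existing abort behavior in a RUNNING to COMPLETED-or-FAILED lifecycle. A standalone, runnable miniature of that wrapper pattern; the names here are illustrative, not the project's API:

```python
# Minimal sketch of the status-wrapper pattern: mark RUNNING, run the work,
# then mark COMPLETED, or mark FAILED and re-raise.
import asyncio
from enum import Enum


class StepStatus(Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


status_log: list[tuple[StepStatus, str | None]] = []


async def mark(status: StepStatus, error: str | None = None) -> None:
    status_log.append((status, error))


async def with_status(work):
    await mark(StepStatus.RUNNING)
    try:
        result = await work()
    except Exception as exc:
        await mark(StepStatus.FAILED, str(exc))
        raise
    await mark(StepStatus.COMPLETED)
    return result


async def main() -> None:
    print(await with_status(lambda: asyncio.sleep(0, result=42)))  # 42
    print(status_log)  # [(RUNNING, None), (COMPLETED, None)]


asyncio.run(main())
```

The real mixin differs in one deliberate way: `ValueError` passes through without a FAILED mark, because `abort_not_found` has already terminated the RPC by that point.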
@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID

from noteflow.domain.entities.processing import ProcessingStatus, ProcessingStepState
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingId, MeetingState
from noteflow.infrastructure.logging import get_logger, log_state_transition
@@ -133,6 +134,7 @@ async def _complete_without_summary(
        "Post-processing: no segments, completing without summary",
        meeting_id=meeting_id,
    )
    _set_summary_processing_status(meeting, ProcessingStepState.skipped())
    _complete_meeting(meeting, meeting_id)
    await repo.meetings.update(meeting)
    await repo.commit()
@@ -140,6 +142,7 @@

async def _save_summary_and_complete(context: _SummaryCompletionContext) -> Summary:
    saved_summary = await context.repo.summaries.save(context.summary)
    _set_summary_processing_status(context.meeting, ProcessingStepState.completed())
    _complete_meeting(context.meeting, context.meeting_id)
    await context.repo.meetings.update(context.meeting)
    await context.repo.commit()
@@ -159,6 +162,16 @@ def _complete_meeting(meeting: Meeting, meeting_id: str) -> None:
    log_state_transition("meeting", meeting_id, previous_state, meeting.state)


def _set_summary_processing_status(meeting: Meeting, step_state: ProcessingStepState) -> None:
    current = meeting.processing_status or ProcessingStatus.create_pending()
    meeting.processing_status = ProcessingStatus(
        summary=step_state,
        entities=current.entities,
        diarization=current.diarization,
        queued_at=current.queued_at,
    )


async def _trigger_summary_webhook(
    webhook_service: WebhookService | None,
    meeting: Meeting,
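Note that `_set_summary_processing_status` rebuilds the whole `ProcessingStatus` rather than assigning one field, which suggests the status object is immutable, an assumption consistent with `_SummaryGenerationContext` being a frozen dataclass elsewhere in this commit. A runnable miniature of the same swap-in-a-new-instance idiom using stand-in types:

```python
# If the status type is a frozen dataclass, per-field mutation raises, so
# updates swap in a new instance. dataclasses.replace does that in one call.
from dataclasses import dataclass, replace
from datetime import datetime


@dataclass(frozen=True, slots=True)
class Status:
    summary: str = "pending"
    entities: str = "pending"
    diarization: str = "pending"
    queued_at: datetime | None = None


current = Status()
updated = replace(current, summary="skipped")  # new instance, other fields kept
print(updated)  # Status(summary='skipped', entities='pending', ...)
```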
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast
from uuid import UUID

from noteflow.config.constants import DEFAULT_MEETING_TITLE
from noteflow.domain.entities import Meeting
from noteflow.domain.entities import Meeting, ProcessingStatus
from noteflow.domain.entities.meeting import MeetingCreateParams
from noteflow.domain.identity import OperationContext
from noteflow.domain.value_objects import MeetingId, MeetingState
@@ -91,6 +91,8 @@ async def _stop_meeting_and_persist(context: _StopMeetingContext) -> Meeting:
        previous_state,
        context.context,
    )
    if context.meeting.processing_status is None:
        context.meeting.processing_status = ProcessingStatus.create_pending()
    await context.repo.meetings.update(context.meeting)

    if context.repo.supports_diarization_jobs:
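Stopping a meeting now seeds a pending status, so every downstream step update finds an object to modify. The factory itself is not shown in this diff; a minimal stand-in sketch of the shape it presumably returns, with simplified types:

```python
# Assumed shape of ProcessingStatus.create_pending: all three steps start
# PENDING. Stand-in types only; not the project's real classes.
from dataclasses import dataclass
from datetime import datetime


@dataclass(frozen=True)
class StepState:
    status: str = "pending"


@dataclass(frozen=True)
class ProcessingStatusSketch:
    summary: StepState
    entities: StepState
    diarization: StepState
    queued_at: datetime | None = None

    @classmethod
    def create_pending(cls) -> "ProcessingStatusSketch":
        return cls(StepState(), StepState(), StepState())


print(ProcessingStatusSketch.create_pending().summary.status)  # "pending"
```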
@@ -2,11 +2,13 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, cast

from sqlalchemy.exc import IntegrityError

from noteflow.domain.entities import Segment, Summary
from noteflow.domain.entities.processing import ProcessingStepStatus
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger
@@ -15,6 +17,7 @@ from ...proto import noteflow_pb2
from ...startup.startup import auto_enable_cloud_llm
from .._types import GrpcContext
from .._model_status import log_model_status
from .._processing_status import ProcessingStatusUpdate, update_processing_status
from ..converters import parse_meeting_id_or_abort, summary_to_proto
from ..errors import ENTITY_MEETING, abort_not_found
from ._summary_generation import generate_placeholder_summary, summarize_or_placeholder
@@ -35,6 +38,16 @@ if TYPE_CHECKING:
logger = get_logger(__name__)


@dataclass(frozen=True, slots=True)
class _SummaryGenerationContext:
    meeting_id: MeetingId
    meeting_id_str: str
    meeting: Meeting
    segments: list[Segment]
    style_prompt: str | None
    force_regenerate: bool


class SummarizationGenerationMixin:
    """Generate summaries and handle summary webhooks."""

@@ -59,12 +72,13 @@
            await self._prepare_summary_request(request, context)
        )
        if existing and not request.force_regenerate:
            await self._mark_summary_step(request.meeting_id, ProcessingStepStatus.COMPLETED)
            return summary_to_proto(existing)

        await self._ensure_cloud_provider()

        async with cast(UnitOfWork, self.create_repository_provider()) as repo:
            style_prompt = await resolve_template_prompt(
            style_prompt = await self._resolve_style_prompt(
                TemplateResolutionInputs(
                    request=request,
                    meeting=meeting,
@@ -76,23 +90,75 @@
                    summarization_service=self.summarization_service,
                )
            )
            summary = await summarize_or_placeholder(
                self.summarization_service,
                meeting_id,
                segments,
                style_prompt,
            )

            saved, trigger_webhook = await self._save_summary(
                meeting=meeting,
                summary=summary,
                force_regenerate=request.force_regenerate,
            saved, trigger_webhook = await self._generate_summary_with_status(
                _SummaryGenerationContext(
                    meeting_id=meeting_id,
                    meeting_id_str=request.meeting_id,
                    meeting=meeting,
                    segments=segments,
                    style_prompt=style_prompt,
                    force_regenerate=request.force_regenerate,
                )
            )

        if trigger_webhook:
            await self._trigger_summary_webhook(meeting, saved)
        return summary_to_proto(saved)

    async def _mark_summary_step(
        self,
        meeting_id_str: str,
        status: ProcessingStepStatus,
        error_message: str | None = None,
    ) -> None:
        await update_processing_status(
            self.create_repository_provider,
            meeting_id_str,
            ProcessingStatusUpdate(
                step="summary",
                status=status,
                error_message=error_message,
            ),
        )

    async def _resolve_style_prompt(self, inputs: TemplateResolutionInputs) -> str | None:
        return await resolve_template_prompt(inputs)

    async def _generate_summary_with_status(
        self,
        context: _SummaryGenerationContext,
    ) -> tuple[Summary, bool]:
        meeting_id = context.meeting_id
        meeting_id_str = context.meeting_id_str
        meeting = context.meeting
        segments = context.segments
        style_prompt = context.style_prompt
        force_regenerate = context.force_regenerate

        await self._mark_summary_step(meeting_id_str, ProcessingStepStatus.RUNNING)
        try:
            summary = await summarize_or_placeholder(
                self.summarization_service,
                meeting_id,
                segments,
                style_prompt,
            )

            saved, trigger_webhook = await self._save_summary(
                meeting=meeting,
                summary=summary,
                force_regenerate=force_regenerate,
            )
            await self._mark_summary_step(meeting_id_str, ProcessingStepStatus.COMPLETED)
            return saved, trigger_webhook
        except Exception as exc:
            await self._mark_summary_step(
                meeting_id_str,
                ProcessingStepStatus.FAILED,
                str(exc),
            )
            raise

    async def _prepare_summary_request(
        self,
        request: noteflow_pb2.GenerateSummaryRequest,
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.constants.fields import SEGMENT_IDS
|
||||
from noteflow.domain.constants.fields import (
|
||||
CATEGORY,
|
||||
CONFIDENCE,
|
||||
IS_PINNED,
|
||||
MEETING_ID,
|
||||
NORMALIZED_TEXT,
|
||||
SEGMENT_IDS,
|
||||
TEXT,
|
||||
)
|
||||
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
|
||||
from noteflow.domain.value_objects import MeetingId
|
||||
|
||||
@@ -60,11 +68,11 @@ class NerConverter:
|
||||
"""
|
||||
return {
|
||||
"id": entity.id,
|
||||
"meeting_id": entity.meeting_id,
|
||||
"text": entity.text,
|
||||
"normalized_text": entity.normalized_text,
|
||||
"category": entity.category.value,
|
||||
MEETING_ID: entity.meeting_id,
|
||||
TEXT: entity.text,
|
||||
NORMALIZED_TEXT: entity.normalized_text,
|
||||
CATEGORY: entity.category.value,
|
||||
SEGMENT_IDS: entity.segment_ids,
|
||||
"confidence": entity.confidence,
|
||||
"is_pinned": entity.is_pinned,
|
||||
CONFIDENCE: entity.confidence,
|
||||
IS_PINNED: entity.is_pinned,
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from noteflow.domain.constants.fields import END_TIME, START_TIME
|
||||
from noteflow.domain.entities import (
|
||||
@@ -16,13 +18,17 @@ from noteflow.domain.entities import (
|
||||
from noteflow.domain.entities import (
|
||||
WordTiming as DomainWordTiming,
|
||||
)
|
||||
from noteflow.domain.entities.processing import (
|
||||
ProcessingStatus,
|
||||
ProcessingStepState,
|
||||
ProcessingStepStatus,
|
||||
)
|
||||
from noteflow.domain.value_objects import (
|
||||
AnnotationId,
|
||||
AnnotationType,
|
||||
MeetingId,
|
||||
MeetingState,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.infrastructure.persistence.models import (
|
||||
ActionItemModel,
|
||||
@@ -34,10 +40,108 @@ if TYPE_CHECKING:
|
||||
WordTimingModel,
|
||||
)
|
||||
|
||||
|
||||
class OrmConverter:
|
||||
"""Convert between ORM models and domain entities."""
|
||||
|
||||
@staticmethod
|
||||
def _parse_datetime(value: object) -> datetime | None:
|
||||
"""Parse ISO datetime strings stored in JSON payloads."""
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_processing_step_status(value: object) -> ProcessingStepStatus:
|
||||
"""Parse a processing step status from JSON payload values."""
|
||||
if isinstance(value, ProcessingStepStatus):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return ProcessingStepStatus(value)
|
||||
except ValueError:
|
||||
return ProcessingStepStatus.PENDING
|
||||
return ProcessingStepStatus.PENDING
|
||||
|
||||
@staticmethod
|
||||
def _processing_step_state_from_payload(
|
||||
payload: Mapping[str, object] | None,
|
||||
) -> ProcessingStepState:
|
||||
"""Convert JSON payload to ProcessingStepState."""
|
||||
if payload is None:
|
||||
return ProcessingStepState.pending()
|
||||
status = OrmConverter._parse_processing_step_status(payload.get("status"))
|
||||
error_message_value = payload.get("error_message")
|
||||
error_message = error_message_value if isinstance(error_message_value, str) else ""
|
||||
return ProcessingStepState(
|
||||
status=status,
|
||||
error_message=error_message,
|
||||
started_at=OrmConverter._parse_datetime(payload.get("started_at")),
|
||||
completed_at=OrmConverter._parse_datetime(payload.get("completed_at")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_step_payload(value: object) -> Mapping[str, object] | None:
|
||||
if isinstance(value, Mapping):
|
||||
return cast(Mapping[str, object], value)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def processing_status_from_payload(
|
||||
payload: Mapping[str, object] | None,
|
||||
) -> ProcessingStatus | None:
|
||||
"""Convert JSON payload to ProcessingStatus."""
|
||||
if payload is None:
|
||||
return None
|
||||
summary = OrmConverter._processing_step_state_from_payload(
|
||||
OrmConverter._coerce_step_payload(payload.get("summary")),
|
||||
)
|
||||
entities = OrmConverter._processing_step_state_from_payload(
|
||||
OrmConverter._coerce_step_payload(payload.get("entities")),
|
||||
)
|
||||
diarization = OrmConverter._processing_step_state_from_payload(
|
||||
OrmConverter._coerce_step_payload(payload.get("diarization")),
|
||||
)
|
||||
return ProcessingStatus(
|
||||
summary=summary,
|
||||
entities=entities,
|
||||
diarization=diarization,
|
||||
queued_at=OrmConverter._parse_datetime(payload.get("queued_at")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _processing_step_state_to_payload(
|
||||
state: ProcessingStepState,
|
||||
) -> dict[str, object]:
|
||||
"""Convert ProcessingStepState to JSON payload."""
|
||||
payload: dict[str, object] = {
|
||||
"status": state.status.value,
|
||||
"error_message": state.error_message,
|
||||
}
|
||||
if state.started_at is not None:
|
||||
payload["started_at"] = state.started_at.isoformat()
|
||||
if state.completed_at is not None:
|
||||
payload["completed_at"] = state.completed_at.isoformat()
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def processing_status_to_payload(
|
||||
status: ProcessingStatus | None,
|
||||
) -> dict[str, object] | None:
|
||||
"""Convert ProcessingStatus to JSON payload."""
|
||||
if status is None:
|
||||
return None
|
||||
payload: dict[str, object] = {
|
||||
"summary": OrmConverter._processing_step_state_to_payload(status.summary),
|
||||
"entities": OrmConverter._processing_step_state_to_payload(status.entities),
|
||||
"diarization": OrmConverter._processing_step_state_to_payload(status.diarization),
|
||||
}
|
||||
if status.queued_at is not None:
|
||||
payload["queued_at"] = status.queued_at.isoformat()
|
||||
return payload
|
||||
|
||||
# --- WordTiming ---
|
||||
|
||||
@staticmethod
|
||||
@@ -113,6 +217,7 @@ class OrmConverter:
|
||||
wrapped_dek=model.wrapped_dek,
|
||||
asset_path=model.asset_path,
|
||||
version=model.version,
|
||||
processing_status=OrmConverter.processing_status_from_payload(model.processing_status),
|
||||
)
|
||||
|
||||
# --- Segment ---
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add_meeting_processing_status
|
||||
|
||||
Revision ID: t4u5v6w7x8y9
|
||||
Revises: s3t4u5v6w7x8
|
||||
Create Date: 2026-01-15 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "t4u5v6w7x8y9"
|
||||
down_revision: str | Sequence[str] | None = "s3t4u5v6w7x8"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add processing_status column to meetings."""
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE noteflow.meetings
|
||||
ADD COLUMN processing_status JSONB;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove processing_status column from meetings."""
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE noteflow.meetings
|
||||
DROP COLUMN processing_status;
|
||||
"""
|
||||
)
|
||||
@@ -17,7 +17,7 @@ from sqlalchemy import (
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .._base import DEFAULT_USER_ID, DEFAULT_WORKSPACE_ID, EMBEDDING_DIM, Base
|
||||
@@ -115,6 +115,10 @@ class MeetingModel(UuidPrimaryKeyMixin, CreatedAtMixin, MetadataMixin, Base):
|
||||
Text,
|
||||
nullable=True,
|
||||
)
|
||||
processing_status: Mapped[dict[str, object] | None] = mapped_column(
|
||||
JSONB,
|
||||
nullable=True,
|
||||
)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
# Soft delete support
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
|
||||
@@ -9,6 +9,8 @@ from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text, Unique
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from noteflow.domain.constants.fields import NORMALIZED_TEXT
|
||||
|
||||
from .._base import Base
|
||||
from .._mixins import CreatedAtMixin, UpdatedAtMixin, UuidPrimaryKeyMixin
|
||||
from .._strings import (
|
||||
@@ -31,7 +33,7 @@ class NamedEntityModel(UuidPrimaryKeyMixin, CreatedAtMixin, UpdatedAtMixin, Base
|
||||
__tablename__ = "named_entities"
|
||||
__table_args__: tuple[UniqueConstraint, dict[str, str]] = (
|
||||
UniqueConstraint(
|
||||
"meeting_id", "normalized_text", name="uq_named_entities_meeting_text"
|
||||
"meeting_id", NORMALIZED_TEXT, name="uq_named_entities_meeting_text"
|
||||
),
|
||||
{"schema": "noteflow"},
|
||||
)
|
||||
|
||||
@@ -6,8 +6,19 @@ from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from noteflow.domain.constants.fields import (
|
||||
CATEGORY,
|
||||
CONFIDENCE,
|
||||
IS_PINNED,
|
||||
MEETING_ID,
|
||||
NORMALIZED_TEXT,
|
||||
SEGMENT_IDS,
|
||||
TEXT,
|
||||
UPDATED_AT,
|
||||
)
|
||||
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
|
||||
from noteflow.infrastructure.converters.ner_converters import NerConverter
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
@@ -43,7 +54,7 @@ class SqlAlchemyEntityRepository(
|
||||
async def save(self, entity: NamedEntity) -> NamedEntity:
|
||||
"""Save or update a named entity.
|
||||
|
||||
Uses merge to handle both insert and update cases.
|
||||
Uses an upsert on (meeting_id, normalized_text) to avoid duplicate key errors.
|
||||
|
||||
Args:
|
||||
entity: The entity to save.
|
||||
@@ -52,10 +63,24 @@ class SqlAlchemyEntityRepository(
|
||||
Saved entity with db_id populated.
|
||||
"""
|
||||
kwargs = NerConverter.to_orm_kwargs(entity)
|
||||
model = NamedEntityModel(**kwargs)
|
||||
merged = await self._session.merge(model)
|
||||
await self._session.flush()
|
||||
entity.db_id = merged.id
|
||||
stmt = insert(NamedEntityModel).values(**kwargs)
|
||||
excluded = stmt.excluded
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[MEETING_ID, NORMALIZED_TEXT],
|
||||
set_={
|
||||
TEXT: excluded.text,
|
||||
NORMALIZED_TEXT: excluded.normalized_text,
|
||||
CATEGORY: excluded.category,
|
||||
SEGMENT_IDS: excluded.segment_ids,
|
||||
CONFIDENCE: excluded.confidence,
|
||||
IS_PINNED: NamedEntityModel.is_pinned | excluded.is_pinned,
|
||||
UPDATED_AT: func.now(),
|
||||
},
|
||||
).returning(NamedEntityModel.id)
|
||||
result = await self._session.execute(stmt)
|
||||
row = result.first()
|
||||
if row is not None:
|
||||
entity.db_id = row[0]
|
||||
logger.info(
|
||||
"entity_saved",
|
||||
entity_id=str(entity.id),
|
||||
@@ -67,8 +92,7 @@ class SqlAlchemyEntityRepository(
|
||||
async def save_batch(self, entities: Sequence[NamedEntity]) -> Sequence[NamedEntity]:
|
||||
"""Save multiple entities efficiently.
|
||||
|
||||
Uses individual merges to handle upsert semantics with the unique
|
||||
constraint on (meeting_id, normalized_text).
|
||||
Uses an upsert on (meeting_id, normalized_text) to avoid duplicate key errors.
|
||||
|
||||
Args:
|
||||
entities: List of entities to save.
|
||||
@@ -76,13 +100,38 @@ class SqlAlchemyEntityRepository(
|
||||
Returns:
|
||||
Saved entities with db_ids populated.
|
||||
"""
|
||||
for entity in entities:
|
||||
kwargs = NerConverter.to_orm_kwargs(entity)
|
||||
model = NamedEntityModel(**kwargs)
|
||||
merged = await self._session.merge(model)
|
||||
entity.db_id = merged.id
|
||||
if not entities:
|
||||
return entities
|
||||
|
||||
await self._session.flush()
|
||||
payload = [NerConverter.to_orm_kwargs(entity) for entity in entities]
|
||||
stmt = insert(NamedEntityModel).values(payload)
|
||||
excluded = stmt.excluded
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[MEETING_ID, NORMALIZED_TEXT],
|
||||
set_={
|
||||
TEXT: excluded.text,
|
||||
NORMALIZED_TEXT: excluded.normalized_text,
|
||||
CATEGORY: excluded.category,
|
||||
SEGMENT_IDS: excluded.segment_ids,
|
||||
CONFIDENCE: excluded.confidence,
|
||||
IS_PINNED: NamedEntityModel.is_pinned | excluded.is_pinned,
|
||||
UPDATED_AT: func.now(),
|
||||
},
|
||||
).returning(
|
||||
NamedEntityModel.id,
|
||||
NamedEntityModel.meeting_id,
|
||||
NamedEntityModel.normalized_text,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
rows = result.all()
|
||||
if rows:
|
||||
ids_by_key = {
|
||||
(str(row.meeting_id), row.normalized_text): row.id
|
||||
for row in rows
|
||||
}
|
||||
for entity in entities:
|
||||
key = (str(entity.meeting_id), entity.normalized_text)
|
||||
entity.db_id = ids_by_key.get(key, entity.db_id)
|
||||
if entities:
|
||||
logger.info(
|
||||
"entities_batch_saved",
|
||||
|
||||
@@ -45,6 +45,7 @@ class SqlAlchemyMeetingRepository(BaseRepository):
|
||||
metadata_=meeting.metadata,
|
||||
wrapped_dek=meeting.wrapped_dek,
|
||||
asset_path=meeting.asset_path,
|
||||
processing_status=OrmConverter.processing_status_to_payload(meeting.processing_status),
|
||||
version=meeting.version,
|
||||
)
|
||||
self._session.add(model)
|
||||
@@ -96,6 +97,9 @@ class SqlAlchemyMeetingRepository(BaseRepository):
|
||||
model.metadata_ = dict(meeting.metadata)
|
||||
model.wrapped_dek = meeting.wrapped_dek
|
||||
model.asset_path = meeting.asset_path
|
||||
model.processing_status = OrmConverter.processing_status_to_payload(
|
||||
meeting.processing_status
|
||||
)
|
||||
model.version += 1
|
||||
meeting.version = model.version
|