feat: Implement PyTorch and ROCm ASR engines with GPU detection and enhance calendar OAuth integration across client and server.
This commit is contained in:
@@ -10,7 +10,8 @@ use tracing::{error, info};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::grpc::types::calendar::{
|
||||
CompleteOAuthResult, DisconnectOAuthResult, GetCalendarProvidersResult,
|
||||
GetOAuthConnectionStatusResult, InitiateOAuthResult, ListCalendarEventsResult,
|
||||
GetOAuthClientConfigResult, GetOAuthConnectionStatusResult, InitiateOAuthResult,
|
||||
ListCalendarEventsResult, OAuthClientConfig, SetOAuthClientConfigResult,
|
||||
};
|
||||
use crate::oauth_loopback::OAuthLoopbackServer;
|
||||
use crate::state::AppState;
|
||||
@@ -194,6 +195,44 @@ pub async fn get_oauth_connection_status(
|
||||
.await
|
||||
}
|
||||
|
||||
/// Get OAuth client override configuration for a provider.
|
||||
#[tauri::command(rename_all = "snake_case")]
|
||||
pub async fn get_oauth_client_config(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
provider: String,
|
||||
workspace_id: Option<String>,
|
||||
integration_type: Option<String>,
|
||||
) -> Result<GetOAuthClientConfigResult> {
|
||||
state
|
||||
.grpc_client
|
||||
.get_oauth_client_config(
|
||||
&provider,
|
||||
integration_type.as_deref(),
|
||||
workspace_id.as_deref(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Set OAuth client override configuration for a provider.
|
||||
#[tauri::command(rename_all = "snake_case")]
|
||||
pub async fn set_oauth_client_config(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
provider: String,
|
||||
workspace_id: Option<String>,
|
||||
integration_type: Option<String>,
|
||||
config: OAuthClientConfig,
|
||||
) -> Result<SetOAuthClientConfigResult> {
|
||||
state
|
||||
.grpc_client
|
||||
.set_oauth_client_config(
|
||||
&provider,
|
||||
integration_type.as_deref(),
|
||||
workspace_id.as_deref(),
|
||||
config,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Disconnect OAuth integration.
|
||||
#[tauri::command(rename_all = "snake_case")]
|
||||
pub async fn disconnect_oauth(
|
||||
|
||||
@@ -85,15 +85,36 @@ pub async fn list_workspaces(_state: State<'_, Arc<AppState>>) -> Result<ListWor
|
||||
/// Switch active workspace (local-first validation).
|
||||
#[tauri::command(rename_all = "snake_case")]
|
||||
pub async fn switch_workspace(
|
||||
_state: State<'_, Arc<AppState>>,
|
||||
state: State<'_, Arc<AppState>>,
|
||||
workspace_id: String,
|
||||
) -> Result<SwitchWorkspaceResult> {
|
||||
let workspace = default_workspaces()
|
||||
.into_iter()
|
||||
.find(|workspace| workspace.id == workspace_id);
|
||||
|
||||
let success = if let Some(ws) = &workspace {
|
||||
// Update the stored identity so subsequent gRPC calls use the correct workspace_id
|
||||
state.identity.store().switch_workspace(
|
||||
ws.id.clone(),
|
||||
ws.name.clone(),
|
||||
// Map generic WorkspaceRole to string for storage
|
||||
match ws.role {
|
||||
WorkspaceRole::Owner => "owner".to_string(),
|
||||
WorkspaceRole::Admin => "admin".to_string(),
|
||||
WorkspaceRole::Member => "member".to_string(),
|
||||
WorkspaceRole::Viewer => "viewer".to_string(),
|
||||
WorkspaceRole::Guest => "guest".to_string(),
|
||||
WorkspaceRole::System => "system".to_string(),
|
||||
WorkspaceRole::Unspecified => "unspecified".to_string(),
|
||||
},
|
||||
)?;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
Ok(SwitchWorkspaceResult {
|
||||
success: workspace.is_some(),
|
||||
success,
|
||||
workspace,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,7 +4,8 @@ use crate::error::Result;
|
||||
use crate::grpc::noteflow as pb;
|
||||
use crate::grpc::types::calendar::{
|
||||
CompleteOAuthResult, DisconnectOAuthResult, GetCalendarProvidersResult,
|
||||
GetOAuthConnectionStatusResult, InitiateOAuthResult, ListCalendarEventsResult, OAuthConnection,
|
||||
GetOAuthClientConfigResult, GetOAuthConnectionStatusResult, InitiateOAuthResult,
|
||||
ListCalendarEventsResult, OAuthClientConfig, OAuthConnection, SetOAuthClientConfigResult,
|
||||
};
|
||||
|
||||
use super::converters::{convert_calendar_event, convert_calendar_provider};
|
||||
@@ -131,6 +132,36 @@ impl GrpcClient {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get OAuth client override configuration.
|
||||
pub async fn get_oauth_client_config(
|
||||
&self,
|
||||
provider: &str,
|
||||
integration_type: Option<&str>,
|
||||
workspace_id: Option<&str>,
|
||||
) -> Result<GetOAuthClientConfigResult> {
|
||||
let mut client = self.get_client()?;
|
||||
let response = client
|
||||
.get_o_auth_client_config(pb::GetOAuthClientConfigRequest {
|
||||
provider: provider.to_string(),
|
||||
integration_type: integration_type.unwrap_or_default().to_string(),
|
||||
workspace_id: workspace_id.unwrap_or_default().to_string(),
|
||||
})
|
||||
.await?
|
||||
.into_inner();
|
||||
|
||||
let config = response.config.unwrap_or_default();
|
||||
Ok(GetOAuthClientConfigResult {
|
||||
config: OAuthClientConfig {
|
||||
client_id: config.client_id,
|
||||
client_secret: config.client_secret,
|
||||
redirect_uri: config.redirect_uri,
|
||||
scopes: config.scopes,
|
||||
override_enabled: config.override_enabled,
|
||||
has_client_secret: config.has_client_secret,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Disconnect OAuth integration.
|
||||
pub async fn disconnect_oauth(
|
||||
&self,
|
||||
@@ -151,4 +182,35 @@ impl GrpcClient {
|
||||
error_message: response.error_message,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set OAuth client override configuration.
|
||||
pub async fn set_oauth_client_config(
|
||||
&self,
|
||||
provider: &str,
|
||||
integration_type: Option<&str>,
|
||||
workspace_id: Option<&str>,
|
||||
config: OAuthClientConfig,
|
||||
) -> Result<SetOAuthClientConfigResult> {
|
||||
let mut client = self.get_client()?;
|
||||
let response = client
|
||||
.set_o_auth_client_config(pb::SetOAuthClientConfigRequest {
|
||||
provider: provider.to_string(),
|
||||
integration_type: integration_type.unwrap_or_default().to_string(),
|
||||
workspace_id: workspace_id.unwrap_or_default().to_string(),
|
||||
config: Some(pb::OAuthClientConfig {
|
||||
client_id: config.client_id,
|
||||
client_secret: config.client_secret,
|
||||
redirect_uri: config.redirect_uri,
|
||||
scopes: config.scopes,
|
||||
override_enabled: config.override_enabled,
|
||||
has_client_secret: config.has_client_secret,
|
||||
}),
|
||||
})
|
||||
.await?
|
||||
.into_inner();
|
||||
|
||||
Ok(SetOAuthClientConfigResult {
|
||||
success: response.success,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -526,6 +526,12 @@ pub struct AsrConfiguration {
|
||||
/// Available compute types for current device
|
||||
#[prost(enumeration = "AsrComputeType", repeated, tag = "7")]
|
||||
pub available_compute_types: ::prost::alloc::vec::Vec<i32>,
|
||||
/// Whether ROCm is available on this server
|
||||
#[prost(bool, tag = "8")]
|
||||
pub rocm_available: bool,
|
||||
/// Current GPU backend (none, cuda, rocm, mps)
|
||||
#[prost(string, tag = "9")]
|
||||
pub gpu_backend: ::prost::alloc::string::String,
|
||||
}
|
||||
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
|
||||
pub struct GetAsrConfigurationRequest {}
|
||||
@@ -1148,6 +1154,65 @@ pub struct DisconnectOAuthResponse {
|
||||
#[prost(string, tag = "2")]
|
||||
pub error_message: ::prost::alloc::string::String,
|
||||
}
|
||||
/// OAuth client override configuration
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct OAuthClientConfig {
|
||||
/// OAuth client ID
|
||||
#[prost(string, tag = "1")]
|
||||
pub client_id: ::prost::alloc::string::String,
|
||||
/// Optional client secret (request only)
|
||||
#[prost(string, optional, tag = "2")]
|
||||
pub client_secret: ::core::option::Option<::prost::alloc::string::String>,
|
||||
/// Redirect URI for OAuth callback
|
||||
#[prost(string, tag = "3")]
|
||||
pub redirect_uri: ::prost::alloc::string::String,
|
||||
/// OAuth scopes to request
|
||||
#[prost(string, repeated, tag = "4")]
|
||||
pub scopes: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
/// Whether override should be used
|
||||
#[prost(bool, tag = "5")]
|
||||
pub override_enabled: bool,
|
||||
/// Whether a client secret is stored (response only)
|
||||
#[prost(bool, tag = "6")]
|
||||
pub has_client_secret: bool,
|
||||
}
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct GetOAuthClientConfigRequest {
|
||||
/// Provider to configure: google, outlook
|
||||
#[prost(string, tag = "1")]
|
||||
pub provider: ::prost::alloc::string::String,
|
||||
/// Optional integration type
|
||||
#[prost(string, tag = "2")]
|
||||
pub integration_type: ::prost::alloc::string::String,
|
||||
/// Optional workspace ID override
|
||||
#[prost(string, tag = "3")]
|
||||
pub workspace_id: ::prost::alloc::string::String,
|
||||
}
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct GetOAuthClientConfigResponse {
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub config: ::core::option::Option<OAuthClientConfig>,
|
||||
}
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct SetOAuthClientConfigRequest {
|
||||
/// Provider to configure: google, outlook
|
||||
#[prost(string, tag = "1")]
|
||||
pub provider: ::prost::alloc::string::String,
|
||||
/// Optional integration type
|
||||
#[prost(string, tag = "2")]
|
||||
pub integration_type: ::prost::alloc::string::String,
|
||||
/// Optional workspace ID override
|
||||
#[prost(string, tag = "3")]
|
||||
pub workspace_id: ::prost::alloc::string::String,
|
||||
/// OAuth client configuration
|
||||
#[prost(message, optional, tag = "4")]
|
||||
pub config: ::core::option::Option<OAuthClientConfig>,
|
||||
}
|
||||
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
|
||||
pub struct SetOAuthClientConfigResponse {
|
||||
#[prost(bool, tag = "1")]
|
||||
pub success: bool,
|
||||
}
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct RegisterWebhookRequest {
|
||||
/// Workspace this webhook belongs to
|
||||
@@ -2464,6 +2529,7 @@ pub enum AsrDevice {
|
||||
Unspecified = 0,
|
||||
Cpu = 1,
|
||||
Cuda = 2,
|
||||
Rocm = 3,
|
||||
}
|
||||
impl AsrDevice {
|
||||
/// String value of the enum field names used in the ProtoBuf definition.
|
||||
@@ -2475,6 +2541,7 @@ impl AsrDevice {
|
||||
Self::Unspecified => "ASR_DEVICE_UNSPECIFIED",
|
||||
Self::Cpu => "ASR_DEVICE_CPU",
|
||||
Self::Cuda => "ASR_DEVICE_CUDA",
|
||||
Self::Rocm => "ASR_DEVICE_ROCM",
|
||||
}
|
||||
}
|
||||
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||
@@ -2483,6 +2550,7 @@ impl AsrDevice {
|
||||
"ASR_DEVICE_UNSPECIFIED" => Some(Self::Unspecified),
|
||||
"ASR_DEVICE_CPU" => Some(Self::Cpu),
|
||||
"ASR_DEVICE_CUDA" => Some(Self::Cuda),
|
||||
"ASR_DEVICE_ROCM" => Some(Self::Rocm),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -3832,6 +3900,58 @@ pub mod note_flow_service_client {
|
||||
.insert(GrpcMethod::new("noteflow.NoteFlowService", "DisconnectOAuth"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
pub async fn get_o_auth_client_config(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::GetOAuthClientConfigRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::GetOAuthClientConfigResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::unknown(
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/noteflow.NoteFlowService/GetOAuthClientConfig",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new("noteflow.NoteFlowService", "GetOAuthClientConfig"),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
pub async fn set_o_auth_client_config(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::SetOAuthClientConfigRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::SetOAuthClientConfigResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::unknown(
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/noteflow.NoteFlowService/SetOAuthClientConfig",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new("noteflow.NoteFlowService", "SetOAuthClientConfig"),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// Webhook management (Sprint 6)
|
||||
pub async fn register_webhook(
|
||||
&mut self,
|
||||
|
||||
@@ -75,15 +75,39 @@ pub struct OAuthConnection {
|
||||
pub integration_type: String,
|
||||
}
|
||||
|
||||
/// OAuth client configuration (matches proto OAuthClientConfig)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OAuthClientConfig {
|
||||
pub client_id: String,
|
||||
pub client_secret: Option<String>,
|
||||
pub redirect_uri: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub override_enabled: bool,
|
||||
#[serde(default)]
|
||||
pub has_client_secret: bool,
|
||||
}
|
||||
|
||||
/// Get OAuth connection status result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GetOAuthConnectionStatusResult {
|
||||
pub connection: OAuthConnection,
|
||||
}
|
||||
|
||||
/// Get OAuth client config result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GetOAuthClientConfigResult {
|
||||
pub config: OAuthClientConfig,
|
||||
}
|
||||
|
||||
/// Disconnect OAuth result (matches proto DisconnectOAuthResponse)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DisconnectOAuthResult {
|
||||
pub success: bool,
|
||||
pub error_message: String,
|
||||
}
|
||||
|
||||
/// Set OAuth client config result (matches proto SetOAuthClientConfigResponse)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SetOAuthClientConfigResult {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
@@ -149,6 +149,8 @@ macro_rules! app_invoke_handler {
|
||||
commands::initiate_oauth_loopback,
|
||||
commands::complete_oauth,
|
||||
commands::get_oauth_connection_status,
|
||||
commands::get_oauth_client_config,
|
||||
commands::set_oauth_client_config,
|
||||
commands::disconnect_oauth,
|
||||
// Webhooks (5 commands)
|
||||
commands::register_webhook,
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import type { NoteFlowAPI } from '../interface';
|
||||
import type {
|
||||
CompleteCalendarAuthResponse,
|
||||
GetOAuthClientConfigRequest,
|
||||
GetOAuthClientConfigResponse,
|
||||
GetCalendarProvidersResponse,
|
||||
GetOAuthConnectionStatusResponse,
|
||||
InitiateCalendarAuthResponse,
|
||||
ListCalendarEventsResponse,
|
||||
SetOAuthClientConfigRequest,
|
||||
SetOAuthClientConfigResponse,
|
||||
} from '../types';
|
||||
import { rejectReadOnly } from './readonly';
|
||||
|
||||
@@ -15,6 +19,8 @@ export const cachedCalendarAPI: Pick<
|
||||
| 'initiateCalendarAuth'
|
||||
| 'completeCalendarAuth'
|
||||
| 'getOAuthConnectionStatus'
|
||||
| 'getOAuthClientConfig'
|
||||
| 'setOAuthClientConfig'
|
||||
| 'disconnectCalendar'
|
||||
> = {
|
||||
async listCalendarEvents(
|
||||
@@ -52,6 +58,24 @@ export const cachedCalendarAPI: Pick<
|
||||
},
|
||||
};
|
||||
},
|
||||
async getOAuthClientConfig(
|
||||
_request: GetOAuthClientConfigRequest
|
||||
): Promise<GetOAuthClientConfigResponse> {
|
||||
return {
|
||||
config: {
|
||||
client_id: '',
|
||||
redirect_uri: '',
|
||||
scopes: [],
|
||||
override_enabled: false,
|
||||
has_client_secret: false,
|
||||
},
|
||||
};
|
||||
},
|
||||
async setOAuthClientConfig(
|
||||
_request: SetOAuthClientConfigRequest
|
||||
): Promise<SetOAuthClientConfigResponse> {
|
||||
return rejectReadOnly();
|
||||
},
|
||||
async disconnectCalendar(_provider: string) {
|
||||
return rejectReadOnly();
|
||||
},
|
||||
|
||||
@@ -50,6 +50,8 @@ import type {
|
||||
GetWorkspaceSettingsRequest,
|
||||
GetWorkspaceSettingsResponse,
|
||||
GetMeetingRequest,
|
||||
GetOAuthClientConfigRequest,
|
||||
GetOAuthClientConfigResponse,
|
||||
GetOAuthConnectionStatusResponse,
|
||||
GetProjectBySlugRequest,
|
||||
GetProjectRequest,
|
||||
@@ -100,6 +102,8 @@ import type {
|
||||
SummarizationTemplateMutationResponse,
|
||||
SwitchWorkspaceResponse,
|
||||
Summary,
|
||||
SetOAuthClientConfigRequest,
|
||||
SetOAuthClientConfigResponse,
|
||||
TriggerStatus,
|
||||
UpdateASRConfigurationRequest,
|
||||
UpdateASRConfigurationResult,
|
||||
@@ -769,6 +773,22 @@ export interface NoteFlowAPI {
|
||||
*/
|
||||
getOAuthConnectionStatus(provider: string): Promise<GetOAuthConnectionStatusResponse>;
|
||||
|
||||
/**
|
||||
* Get OAuth client override configuration.
|
||||
* @see gRPC endpoint: GetOAuthClientConfig (unary)
|
||||
*/
|
||||
getOAuthClientConfig(
|
||||
request: GetOAuthClientConfigRequest
|
||||
): Promise<GetOAuthClientConfigResponse>;
|
||||
|
||||
/**
|
||||
* Set OAuth client override configuration.
|
||||
* @see gRPC endpoint: SetOAuthClientConfig (unary)
|
||||
*/
|
||||
setOAuthClientConfig(
|
||||
request: SetOAuthClientConfigRequest
|
||||
): Promise<SetOAuthClientConfigResponse>;
|
||||
|
||||
/**
|
||||
* Disconnect OAuth integration.
|
||||
* @see gRPC endpoint: DisconnectOAuth (unary)
|
||||
|
||||
@@ -42,6 +42,8 @@ import type {
|
||||
GetCalendarProvidersResponse,
|
||||
GetCurrentUserResponse,
|
||||
GetMeetingRequest,
|
||||
GetOAuthClientConfigRequest,
|
||||
GetOAuthClientConfigResponse,
|
||||
GetOAuthConnectionStatusResponse,
|
||||
GetProjectBySlugRequest,
|
||||
GetProjectRequest,
|
||||
@@ -95,6 +97,8 @@ import type {
|
||||
RestoreSummarizationTemplateVersionRequest,
|
||||
ServerInfo,
|
||||
StartIntegrationSyncResponse,
|
||||
SetOAuthClientConfigRequest,
|
||||
SetOAuthClientConfigResponse,
|
||||
SwitchWorkspaceResponse,
|
||||
SummarizationTemplate,
|
||||
SummarizationTemplateMutationResponse,
|
||||
@@ -1620,6 +1624,28 @@ export const mockAPI: NoteFlowAPI = {
|
||||
};
|
||||
},
|
||||
|
||||
async getOAuthClientConfig(
|
||||
_request: GetOAuthClientConfigRequest
|
||||
): Promise<GetOAuthClientConfigResponse> {
|
||||
await delay(50);
|
||||
return {
|
||||
config: {
|
||||
client_id: '',
|
||||
redirect_uri: '',
|
||||
scopes: [],
|
||||
override_enabled: false,
|
||||
has_client_secret: false,
|
||||
},
|
||||
};
|
||||
},
|
||||
|
||||
async setOAuthClientConfig(
|
||||
_request: SetOAuthClientConfigRequest
|
||||
): Promise<SetOAuthClientConfigResponse> {
|
||||
await delay(50);
|
||||
return { success: true };
|
||||
},
|
||||
|
||||
async disconnectCalendar(_provider: string): Promise<DisconnectOAuthResponse> {
|
||||
await delay(100);
|
||||
return { success: true };
|
||||
|
||||
@@ -70,6 +70,8 @@ import type {
|
||||
GetActiveProjectRequest,
|
||||
GetActiveProjectResponse,
|
||||
GetMeetingRequest,
|
||||
GetOAuthClientConfigRequest,
|
||||
GetOAuthClientConfigResponse,
|
||||
GetOAuthConnectionStatusResponse,
|
||||
GetProjectBySlugRequest,
|
||||
GetProjectRequest,
|
||||
@@ -122,6 +124,8 @@ import type {
|
||||
SetActiveProjectRequest,
|
||||
SetHuggingFaceTokenRequest,
|
||||
SetHuggingFaceTokenResult,
|
||||
SetOAuthClientConfigRequest,
|
||||
SetOAuthClientConfigResponse,
|
||||
StartIntegrationSyncResponse,
|
||||
SwitchWorkspaceResponse,
|
||||
SummarizationOptions,
|
||||
@@ -1278,6 +1282,25 @@ export function createTauriAPI(invoke: TauriInvoke, listen: TauriListen): NoteFl
|
||||
provider,
|
||||
});
|
||||
},
|
||||
async getOAuthClientConfig(
|
||||
request: GetOAuthClientConfigRequest
|
||||
): Promise<GetOAuthClientConfigResponse> {
|
||||
return invoke<GetOAuthClientConfigResponse>(TauriCommands.GET_OAUTH_CLIENT_CONFIG, {
|
||||
provider: request.provider,
|
||||
workspace_id: request.workspace_id,
|
||||
integration_type: request.integration_type,
|
||||
});
|
||||
},
|
||||
async setOAuthClientConfig(
|
||||
request: SetOAuthClientConfigRequest
|
||||
): Promise<SetOAuthClientConfigResponse> {
|
||||
return invoke<SetOAuthClientConfigResponse>(TauriCommands.SET_OAUTH_CLIENT_CONFIG, {
|
||||
provider: request.provider,
|
||||
workspace_id: request.workspace_id,
|
||||
integration_type: request.integration_type,
|
||||
config: request.config,
|
||||
});
|
||||
},
|
||||
async disconnectCalendar(provider: string): Promise<DisconnectOAuthResponse> {
|
||||
const response = await invoke<DisconnectOAuthResponse>(TauriCommands.DISCONNECT_OAUTH, {
|
||||
provider,
|
||||
|
||||
@@ -99,6 +99,8 @@ export const TauriCommands = {
|
||||
INITIATE_OAUTH_LOOPBACK: 'initiate_oauth_loopback',
|
||||
COMPLETE_OAUTH: 'complete_oauth',
|
||||
GET_OAUTH_CONNECTION_STATUS: 'get_oauth_connection_status',
|
||||
GET_OAUTH_CLIENT_CONFIG: 'get_oauth_client_config',
|
||||
SET_OAUTH_CLIENT_CONFIG: 'set_oauth_client_config',
|
||||
DISCONNECT_OAUTH: 'disconnect_oauth',
|
||||
REGISTER_WEBHOOK: 'register_webhook',
|
||||
LIST_WEBHOOKS: 'list_webhooks',
|
||||
|
||||
@@ -118,3 +118,35 @@ export interface DisconnectOAuthRequest {
|
||||
export interface DisconnectOAuthResponse {
|
||||
success: boolean;
|
||||
}
|
||||
|
||||
// --- OAuth Client Configuration ---
|
||||
|
||||
export interface OAuthClientConfig {
|
||||
client_id: string;
|
||||
client_secret?: string;
|
||||
redirect_uri: string;
|
||||
scopes: string[];
|
||||
override_enabled: boolean;
|
||||
has_client_secret?: boolean;
|
||||
}
|
||||
|
||||
export interface GetOAuthClientConfigRequest {
|
||||
provider: string;
|
||||
integration_type?: string;
|
||||
workspace_id?: string;
|
||||
}
|
||||
|
||||
export interface GetOAuthClientConfigResponse {
|
||||
config: OAuthClientConfig;
|
||||
}
|
||||
|
||||
export interface SetOAuthClientConfigRequest {
|
||||
provider: string;
|
||||
integration_type?: string;
|
||||
workspace_id?: string;
|
||||
config: OAuthClientConfig;
|
||||
}
|
||||
|
||||
export interface SetOAuthClientConfigResponse {
|
||||
success: boolean;
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ export type {
|
||||
SyncHistoryEvent,
|
||||
SyncNotificationPreferences,
|
||||
WebhookConfig,
|
||||
GetOAuthClientConfigRequest,
|
||||
GetOAuthClientConfigResponse,
|
||||
SetOAuthClientConfigRequest,
|
||||
SetOAuthClientConfigResponse,
|
||||
} from './requests/integrations';
|
||||
export type {
|
||||
ProjectScope,
|
||||
|
||||
@@ -76,6 +76,10 @@ export interface Integration {
|
||||
status: IntegrationStatus;
|
||||
last_sync?: number;
|
||||
error_message?: string;
|
||||
/** Whether OAuth override credentials should be used (calendar only). */
|
||||
oauth_override_enabled?: boolean;
|
||||
/** Whether server has a stored OAuth override secret (calendar only). */
|
||||
oauth_override_has_secret?: boolean;
|
||||
// Type-specific configs
|
||||
oauth_config?: OAuthConfig;
|
||||
email_config?: EmailProviderConfig;
|
||||
@@ -109,3 +113,24 @@ export interface SyncHistoryEvent {
|
||||
duration: number; // milliseconds
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface GetOAuthClientConfigRequest {
|
||||
workspace_id?: string;
|
||||
provider: string;
|
||||
integration_type?: Integration['type'];
|
||||
}
|
||||
|
||||
export interface GetOAuthClientConfigResponse {
|
||||
config?: OAuthConfig;
|
||||
}
|
||||
|
||||
export interface SetOAuthClientConfigRequest {
|
||||
workspace_id?: string;
|
||||
provider: string;
|
||||
integration_type?: Integration['type'];
|
||||
config: OAuthConfig;
|
||||
}
|
||||
|
||||
export interface SetOAuthClientConfigResponse {
|
||||
success: boolean;
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ import {
|
||||
SelectValue,
|
||||
} from '@/components/ui/select';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import { useWorkspace } from '@/contexts/workspace-state';
|
||||
import { configPanelContentStyles, Field, SecretInput } from './shared';
|
||||
|
||||
interface CalendarConfigProps {
|
||||
@@ -31,6 +33,7 @@ export function CalendarConfig({
|
||||
showSecrets,
|
||||
toggleSecret,
|
||||
}: CalendarConfigProps) {
|
||||
const { currentWorkspace } = useWorkspace();
|
||||
const calConfig = integration.calendar_config || {
|
||||
sync_interval_minutes: 15,
|
||||
calendar_ids: [],
|
||||
@@ -41,6 +44,11 @@ export function CalendarConfig({
|
||||
redirect_uri: '',
|
||||
scopes: [],
|
||||
};
|
||||
const overrideEnabled = integration.oauth_override_enabled ?? false;
|
||||
const overrideHasSecret = integration.oauth_override_has_secret ?? false;
|
||||
const canOverride =
|
||||
currentWorkspace?.role === 'owner' || currentWorkspace?.role === 'admin';
|
||||
const oauthFieldsDisabled = !overrideEnabled || !canOverride;
|
||||
|
||||
return (
|
||||
<div className={configPanelContentStyles}>
|
||||
@@ -48,6 +56,24 @@ export function CalendarConfig({
|
||||
<Badge variant="secondary">OAuth 2.0</Badge>
|
||||
<span className="text-xs text-muted-foreground">Requires OAuth authentication</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between rounded-md border border-border bg-background/50 p-3">
|
||||
<div className="space-y-1">
|
||||
<Label className="text-sm">Use custom OAuth credentials</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{overrideEnabled
|
||||
? 'Using custom credentials for this workspace'
|
||||
: 'Using server-provided credentials'}
|
||||
</p>
|
||||
{!canOverride ? (
|
||||
<p className="text-xs text-muted-foreground">Admin access required</p>
|
||||
) : null}
|
||||
</div>
|
||||
<Switch
|
||||
checked={overrideEnabled}
|
||||
onCheckedChange={(value) => onUpdate({ oauth_override_enabled: value })}
|
||||
disabled={!canOverride}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<Field label="Client ID" icon={<Key className="h-4 w-4" />}>
|
||||
<Input
|
||||
@@ -58,16 +84,18 @@ export function CalendarConfig({
|
||||
})
|
||||
}
|
||||
placeholder="Enter client ID"
|
||||
disabled={oauthFieldsDisabled}
|
||||
/>
|
||||
</Field>
|
||||
<SecretInput
|
||||
label="Client Secret"
|
||||
value={oauthConfig.client_secret}
|
||||
onChange={(value) => onUpdate({ oauth_config: { ...oauthConfig, client_secret: value } })}
|
||||
placeholder="Enter client secret"
|
||||
placeholder={overrideHasSecret ? 'Stored on server' : 'Enter client secret'}
|
||||
showSecret={showSecrets.calendar_client_secret ?? false}
|
||||
onToggleSecret={() => toggleSecret('calendar_client_secret')}
|
||||
icon={<Lock className="h-4 w-4" />}
|
||||
disabled={oauthFieldsDisabled}
|
||||
/>
|
||||
</div>
|
||||
<Field label="Redirect URI" icon={<Globe className="h-4 w-4" />}>
|
||||
@@ -79,6 +107,7 @@ export function CalendarConfig({
|
||||
})
|
||||
}
|
||||
placeholder="https://your-app.com/calendar/callback"
|
||||
disabled={oauthFieldsDisabled}
|
||||
/>
|
||||
</Field>
|
||||
<Separator />
|
||||
|
||||
@@ -47,6 +47,7 @@ export function SecretInput({
|
||||
showSecret,
|
||||
onToggleSecret,
|
||||
icon,
|
||||
disabled = false,
|
||||
}: {
|
||||
label: string;
|
||||
value: string;
|
||||
@@ -55,6 +56,7 @@ export function SecretInput({
|
||||
showSecret: boolean;
|
||||
onToggleSecret: () => void;
|
||||
icon?: ReactNode;
|
||||
disabled?: boolean;
|
||||
}) {
|
||||
return (
|
||||
<Field label={label} icon={icon}>
|
||||
@@ -65,6 +67,7 @@ export function SecretInput({
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
placeholder={placeholder}
|
||||
className="pr-10"
|
||||
disabled={disabled}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
@@ -72,6 +75,7 @@ export function SecretInput({
|
||||
size="icon"
|
||||
className="absolute right-0 top-0 h-full px-3"
|
||||
onClick={onToggleSecret}
|
||||
disabled={disabled}
|
||||
>
|
||||
{showSecret ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
|
||||
</Button>
|
||||
|
||||
@@ -68,6 +68,8 @@ const oidcProvidersState = vi.hoisted(() => ({
|
||||
|
||||
const apiState = vi.hoisted(() => ({
|
||||
testOidcConnection: vi.fn(),
|
||||
getOAuthClientConfig: vi.fn(),
|
||||
setOAuthClientConfig: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@/lib/preferences', () => ({
|
||||
@@ -103,7 +105,7 @@ vi.mock('@/lib/error-reporting', () => ({
|
||||
}));
|
||||
|
||||
vi.mock('@/contexts/workspace-state', () => ({
|
||||
useWorkspace: () => ({ currentWorkspace: { id: 'workspace-1' } }),
|
||||
useWorkspace: () => ({ currentWorkspace: { id: 'workspace-1', role: 'member' } }),
|
||||
}));
|
||||
|
||||
function createIntegration(overrides: Partial<Integration>): Integration {
|
||||
@@ -126,6 +128,18 @@ describe('useIntegrationHandlers', () => {
|
||||
oidcProvidersState.createProvider.mockReset();
|
||||
oidcProvidersState.updateProvider.mockReset();
|
||||
apiState.testOidcConnection.mockReset();
|
||||
apiState.getOAuthClientConfig.mockReset();
|
||||
apiState.setOAuthClientConfig.mockReset();
|
||||
apiState.getOAuthClientConfig.mockResolvedValue({
|
||||
config: {
|
||||
client_id: '',
|
||||
redirect_uri: '',
|
||||
scopes: [],
|
||||
override_enabled: false,
|
||||
has_client_secret: false,
|
||||
},
|
||||
});
|
||||
apiState.setOAuthClientConfig.mockResolvedValue({ success: true });
|
||||
vi.mocked(preferences.getIntegrations).mockClear();
|
||||
vi.mocked(preferences.updateIntegration).mockClear();
|
||||
vi.mocked(preferences.addCustomIntegration).mockClear();
|
||||
@@ -137,8 +151,15 @@ describe('useIntegrationHandlers', () => {
|
||||
const integration = createIntegration({
|
||||
type: 'calendar',
|
||||
name: 'Google Calendar',
|
||||
oauth_config: undefined,
|
||||
oauth_override_enabled: true,
|
||||
oauth_config: {
|
||||
client_id: '',
|
||||
client_secret: '',
|
||||
redirect_uri: '',
|
||||
scopes: [],
|
||||
},
|
||||
});
|
||||
apiState.getOAuthClientConfig.mockRejectedValue(new Error('skip'));
|
||||
integrationState.integrations = [integration];
|
||||
const setIntegrations = vi.fn();
|
||||
|
||||
|
||||
@@ -42,8 +42,9 @@ export function useIntegrationHandlers({
|
||||
const { createProvider: createOidcProvider, updateProvider: updateOidcProvider } =
|
||||
useOidcProviders();
|
||||
const workspaceId = currentWorkspace?.id ?? IdentityDefaults.DEFAULT_WORKSPACE_ID;
|
||||
const isWorkspaceAdmin =
|
||||
currentWorkspace?.role === 'owner' || currentWorkspace?.role === 'admin';
|
||||
const pendingOAuthIntegrationIdRef = useRef<string | null>(null);
|
||||
|
||||
// Handle OAuth completion
|
||||
useEffect(() => {
|
||||
if (
|
||||
@@ -66,6 +67,121 @@ export function useIntegrationHandlers({
|
||||
}
|
||||
}, [oauthState.integrationId, oauthState.status, setIntegrations]);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
const syncCalendarOverrides = async () => {
|
||||
const calendarIntegrations = preferences
|
||||
.getIntegrations()
|
||||
.filter((integration) => integration.type === 'calendar');
|
||||
|
||||
if (calendarIntegrations.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const api = getAPI();
|
||||
for (const integration of calendarIntegrations) {
|
||||
const provider = getCalendarProvider(integration);
|
||||
if (!provider) {
|
||||
continue;
|
||||
}
|
||||
const response = await api.getOAuthClientConfig({
|
||||
provider,
|
||||
workspace_id: workspaceId,
|
||||
integration_type: 'calendar',
|
||||
});
|
||||
const config = response.config;
|
||||
const existing = preferences
|
||||
.getIntegrations()
|
||||
.find((item) => item.id === integration.id);
|
||||
if (!existing) {
|
||||
continue;
|
||||
}
|
||||
const mergedOAuthConfig = {
|
||||
...existing.oauth_config,
|
||||
client_id: config.client_id || existing.oauth_config?.client_id || '',
|
||||
redirect_uri: config.redirect_uri || existing.oauth_config?.redirect_uri || '',
|
||||
scopes:
|
||||
config.scopes?.length && config.scopes.length > 0
|
||||
? config.scopes
|
||||
: existing.oauth_config?.scopes || [],
|
||||
client_secret: existing.oauth_config?.client_secret ?? '',
|
||||
};
|
||||
|
||||
preferences.updateIntegration(existing.id, {
|
||||
oauth_config: mergedOAuthConfig,
|
||||
oauth_override_enabled: config.override_enabled,
|
||||
oauth_override_has_secret:
|
||||
config.has_client_secret ?? existing.oauth_override_has_secret,
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!cancelled) {
|
||||
setIntegrations(preferences.getIntegrations());
|
||||
}
|
||||
};
|
||||
|
||||
void syncCalendarOverrides();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [setIntegrations, workspaceId]);
|
||||
|
||||
const syncCalendarOAuthConfig = useCallback(
|
||||
async (integration: Integration) => {
|
||||
if (!isWorkspaceAdmin) {
|
||||
return;
|
||||
}
|
||||
const provider = getCalendarProvider(integration);
|
||||
if (!provider) {
|
||||
return;
|
||||
}
|
||||
const oauthConfig = integration.oauth_config || {
|
||||
client_id: '',
|
||||
client_secret: '',
|
||||
redirect_uri: '',
|
||||
scopes: [],
|
||||
};
|
||||
|
||||
const clientSecret = oauthConfig.client_secret?.trim();
|
||||
|
||||
try {
|
||||
const api = getAPI();
|
||||
const response = await api.setOAuthClientConfig({
|
||||
provider,
|
||||
workspace_id: workspaceId,
|
||||
integration_type: 'calendar',
|
||||
config: {
|
||||
client_id: oauthConfig.client_id,
|
||||
client_secret: clientSecret || undefined,
|
||||
redirect_uri: oauthConfig.redirect_uri,
|
||||
scopes: oauthConfig.scopes,
|
||||
override_enabled: Boolean(integration.oauth_override_enabled),
|
||||
has_client_secret: integration.oauth_override_has_secret,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.success && clientSecret) {
|
||||
preferences.updateIntegration(integration.id, {
|
||||
oauth_override_has_secret: true,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
toastError({
|
||||
title: 'OAuth config update failed',
|
||||
error,
|
||||
fallback: 'Failed to update OAuth credentials',
|
||||
});
|
||||
}
|
||||
},
|
||||
[isWorkspaceAdmin, workspaceId]
|
||||
);
|
||||
|
||||
const handleIntegrationToggle = useCallback(
|
||||
(integration: Integration) => {
|
||||
if (
|
||||
@@ -106,10 +222,14 @@ export function useIntegrationHandlers({
|
||||
return;
|
||||
}
|
||||
|
||||
if (integration.oauth_override_enabled) {
|
||||
await syncCalendarOAuthConfig(integration);
|
||||
}
|
||||
|
||||
pendingOAuthIntegrationIdRef.current = integration.id;
|
||||
await initiateAuth(provider);
|
||||
},
|
||||
[initiateAuth]
|
||||
[initiateAuth, syncCalendarOAuthConfig]
|
||||
);
|
||||
|
||||
const handleCalendarDisconnect = useCallback(
|
||||
@@ -176,6 +296,14 @@ export function useIntegrationHandlers({
|
||||
await saveSecrets(updatedIntegration);
|
||||
}
|
||||
|
||||
if (
|
||||
updatedIntegration?.type === 'calendar' &&
|
||||
(config.oauth_config !== undefined || config.oauth_override_enabled !== undefined)
|
||||
) {
|
||||
await syncCalendarOAuthConfig(updatedIntegration);
|
||||
updatedIntegrations = preferences.getIntegrations();
|
||||
}
|
||||
|
||||
// For OIDC integrations with complete config, register with backend
|
||||
if (
|
||||
updatedIntegration?.type === 'oidc' &&
|
||||
@@ -228,6 +356,7 @@ export function useIntegrationHandlers({
|
||||
encryptionAvailable,
|
||||
saveSecrets,
|
||||
setIntegrations,
|
||||
syncCalendarOAuthConfig,
|
||||
updateOidcProvider,
|
||||
workspaceId,
|
||||
]
|
||||
|
||||
@@ -67,9 +67,10 @@ describe('integration-utils', () => {
|
||||
});
|
||||
|
||||
it('validates required fields for calendar integrations', () => {
|
||||
const invalidCalendar: Integration = {
|
||||
const overrideMissingSecret: Integration = {
|
||||
...baseIntegration,
|
||||
type: 'calendar',
|
||||
oauth_override_enabled: true,
|
||||
oauth_config: {
|
||||
client_id: 'id',
|
||||
client_secret: '',
|
||||
@@ -77,9 +78,10 @@ describe('integration-utils', () => {
|
||||
scopes: [],
|
||||
},
|
||||
};
|
||||
const validCalendar: Integration = {
|
||||
const overrideWithSecret: Integration = {
|
||||
...baseIntegration,
|
||||
type: 'calendar',
|
||||
oauth_override_enabled: true,
|
||||
oauth_config: {
|
||||
client_id: 'id',
|
||||
client_secret: 'secret',
|
||||
@@ -87,9 +89,28 @@ describe('integration-utils', () => {
|
||||
scopes: [],
|
||||
},
|
||||
};
|
||||
const overrideWithServerSecret: Integration = {
|
||||
...baseIntegration,
|
||||
type: 'calendar',
|
||||
oauth_override_enabled: true,
|
||||
oauth_override_has_secret: true,
|
||||
oauth_config: {
|
||||
client_id: 'id',
|
||||
client_secret: '',
|
||||
redirect_uri: 'http://localhost',
|
||||
scopes: [],
|
||||
},
|
||||
};
|
||||
const defaultCalendar: Integration = {
|
||||
...baseIntegration,
|
||||
type: 'calendar',
|
||||
oauth_override_enabled: false,
|
||||
};
|
||||
|
||||
expect(hasRequiredIntegrationFields(invalidCalendar)).toBe(false);
|
||||
expect(hasRequiredIntegrationFields(validCalendar)).toBe(true);
|
||||
expect(hasRequiredIntegrationFields(overrideMissingSecret)).toBe(false);
|
||||
expect(hasRequiredIntegrationFields(overrideWithSecret)).toBe(true);
|
||||
expect(hasRequiredIntegrationFields(overrideWithServerSecret)).toBe(true);
|
||||
expect(hasRequiredIntegrationFields(defaultCalendar)).toBe(true);
|
||||
});
|
||||
|
||||
it('validates required fields for pkm integrations', () => {
|
||||
|
||||
@@ -33,7 +33,14 @@ export function hasRequiredIntegrationFields(integration: Integration): boolean
|
||||
? !!integration.email_config?.api_key
|
||||
: !!(integration.email_config?.smtp_host && integration.email_config?.smtp_username);
|
||||
case 'calendar':
|
||||
return !!(integration.oauth_config?.client_id && integration.oauth_config?.client_secret);
|
||||
if (integration.oauth_override_enabled) {
|
||||
const hasClientId = Boolean(integration.oauth_config?.client_id);
|
||||
const hasSecret =
|
||||
Boolean(integration.oauth_config?.client_secret) ||
|
||||
Boolean(integration.oauth_override_has_secret);
|
||||
return hasClientId && hasSecret;
|
||||
}
|
||||
return true;
|
||||
case 'pkm':
|
||||
return !!(integration.pkm_config?.api_key || integration.pkm_config?.vault_path);
|
||||
case 'custom':
|
||||
|
||||
72
docker/Dockerfile.rocm
Normal file
72
docker/Dockerfile.rocm
Normal file
@@ -0,0 +1,72 @@
|
||||
# NoteFlow ROCm Docker Image
|
||||
# For AMD GPU support using PyTorch ROCm
|
||||
#
|
||||
# Build:
|
||||
# docker build -f docker/Dockerfile.rocm -t noteflow:rocm .
|
||||
#
|
||||
# Run (with GPU access):
|
||||
# docker run --device=/dev/kfd --device=/dev/dri --group-add video \
|
||||
# --security-opt seccomp=unconfined \
|
||||
# -v /path/to/models:/app/models \
|
||||
# noteflow:rocm
|
||||
|
||||
ARG ROCM_VERSION=6.2
|
||||
FROM rocm/pytorch:rocm${ROCM_VERSION}_ubuntu22.04_py3.10_pytorch_release_2.3.0
|
||||
|
||||
LABEL maintainer="NoteFlow Team"
|
||||
LABEL description="NoteFlow with ROCm GPU support for AMD GPUs"
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
portaudio19-dev \
|
||||
libsndfile1 \
|
||||
ffmpeg \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Upgrade pip
|
||||
RUN pip install --no-cache-dir --upgrade pip uv
|
||||
|
||||
# Copy project files
|
||||
COPY pyproject.toml ./
|
||||
COPY src ./src/
|
||||
|
||||
# Install NoteFlow with ROCm extras
|
||||
RUN uv pip install --system -e ".[rocm]"
|
||||
|
||||
# Install openai-whisper for PyTorch Whisper fallback
|
||||
RUN uv pip install --system openai-whisper
|
||||
|
||||
# Optionally install CTranslate2-ROCm fork for faster inference
|
||||
# Uncomment the following line if you have access to the fork:
|
||||
# RUN pip install --no-cache-dir git+https://github.com/arlo-phoenix/CTranslate2-rocm.git
|
||||
# RUN pip install --no-cache-dir faster-whisper
|
||||
|
||||
# Create models directory
|
||||
RUN mkdir -p /app/models
|
||||
|
||||
# Environment variables for ROCm
|
||||
ENV ROCM_PATH=/opt/rocm
|
||||
ENV HIP_VISIBLE_DEVICES=0
|
||||
ENV HSA_OVERRIDE_GFX_VERSION=""
|
||||
|
||||
# Environment variables for NoteFlow
|
||||
ENV NOTEFLOW_ASR_DEVICE=rocm
|
||||
ENV NOTEFLOW_ASR_MODEL_SIZE=base
|
||||
ENV NOTEFLOW_ASR_COMPUTE_TYPE=float16
|
||||
ENV NOTEFLOW_FEATURE_ROCM_ENABLED=true
|
||||
|
||||
# gRPC server port
|
||||
EXPOSE 50051
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD python -c "import grpc; channel = grpc.insecure_channel('localhost:50051'); grpc.channel_ready_future(channel).result(timeout=5)" || exit 1
|
||||
|
||||
# Run gRPC server
|
||||
CMD ["python", "-m", "noteflow.grpc.server"]
|
||||
247
docs/guides/rocm-setup.md
Normal file
247
docs/guides/rocm-setup.md
Normal file
@@ -0,0 +1,247 @@
|
||||
# ROCm GPU Support for AMD GPUs
|
||||
|
||||
NoteFlow supports GPU-accelerated ASR (Automatic Speech Recognition) on AMD GPUs using ROCm (Radeon Open Compute). This guide covers installation, configuration, and troubleshooting.
|
||||
|
||||
## Overview
|
||||
|
||||
ROCm support provides:
|
||||
- GPU-accelerated Whisper transcription on AMD GPUs
|
||||
- Automatic fallback to PyTorch Whisper when CTranslate2-ROCm unavailable
|
||||
- Similar performance to CUDA on supported AMD architectures
|
||||
|
||||
## Supported Hardware
|
||||
|
||||
### Officially Supported AMD GPUs
|
||||
|
||||
**CDNA (Instinct Datacenter)**
|
||||
- MI50 (gfx906)
|
||||
- MI100 (gfx908)
|
||||
- MI210, MI250, MI250X (gfx90a)
|
||||
- MI300X, MI300A (gfx942)
|
||||
|
||||
**RDNA 2 (Consumer/Workstation)**
|
||||
- RX 6800, 6800 XT, 6900 XT (gfx1030)
|
||||
- RX 6700 XT (gfx1031)
|
||||
- RX 6600, 6600 XT (gfx1032)
|
||||
|
||||
**RDNA 3 (Consumer/Workstation)**
|
||||
- RX 7900 XTX, 7900 XT (gfx1100)
|
||||
- RX 7800 XT, 7700 XT (gfx1101)
|
||||
- RX 7600 (gfx1102)
|
||||
|
||||
### Using Unsupported GPUs
|
||||
|
||||
For unsupported GPUs, you can use `HSA_OVERRIDE_GFX_VERSION` to force compatibility:
|
||||
|
||||
```bash
|
||||
# Example: Force gfx1030 compatibility for an unsupported RDNA2 GPU
|
||||
export HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||
```
|
||||
|
||||
**Warning**: Using unsupported GPUs may cause instability or incorrect results.
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **AMD GPU** with ROCm support
|
||||
2. **Linux** (Ubuntu 22.04 recommended)
|
||||
3. **Python 3.10+**
|
||||
4. **ROCm 6.0+** installed
|
||||
|
||||
### Step 1: Install ROCm
|
||||
|
||||
Follow AMD's official ROCm installation guide:
|
||||
https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html
|
||||
|
||||
```bash
|
||||
# Ubuntu 22.04 quick install
|
||||
sudo apt update
|
||||
sudo apt install rocm-hip-runtime rocm-dev
|
||||
|
||||
# Add user to video and render groups
|
||||
sudo usermod -a -G video,render $USER
|
||||
|
||||
# Verify installation
|
||||
rocminfo
|
||||
```
|
||||
|
||||
### Step 2: Install PyTorch with ROCm
|
||||
|
||||
```bash
|
||||
# Install PyTorch with ROCm 6.2 support
|
||||
pip install torch --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
```
|
||||
|
||||
Verify PyTorch ROCm:
|
||||
```python
|
||||
import torch
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"HIP version: {torch.version.hip}")
|
||||
print(f"Device: {torch.cuda.get_device_name(0)}")
|
||||
```
|
||||
|
||||
### Step 3: Install NoteFlow with ROCm Extras
|
||||
|
||||
```bash
|
||||
# Install NoteFlow with ROCm support
|
||||
pip install -e ".[rocm]"
|
||||
```
|
||||
|
||||
### Step 4 (Optional): Install CTranslate2-ROCm
|
||||
|
||||
For faster inference, install the CTranslate2-ROCm fork:
|
||||
|
||||
```bash
|
||||
# Install CTranslate2-ROCm fork
|
||||
pip install git+https://github.com/arlo-phoenix/CTranslate2-rocm.git
|
||||
|
||||
# Install faster-whisper
|
||||
pip install faster-whisper
|
||||
```
|
||||
|
||||
**Note**: Without CTranslate2-ROCm, NoteFlow uses PyTorch Whisper which is slower but universally compatible.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Enable ROCm device
|
||||
export NOTEFLOW_ASR_DEVICE=rocm
|
||||
|
||||
# Set model size
|
||||
export NOTEFLOW_ASR_MODEL_SIZE=base # tiny, base, small, medium, large-v3
|
||||
|
||||
# Set compute precision
|
||||
export NOTEFLOW_ASR_COMPUTE_TYPE=float16 # float16, float32
|
||||
|
||||
# Feature flag (enabled by default)
|
||||
export NOTEFLOW_FEATURE_ROCM_ENABLED=true
|
||||
```
|
||||
|
||||
### ROCm-Specific Environment Variables
|
||||
|
||||
```bash
|
||||
# Visible GPU devices (comma-separated indices)
|
||||
export HIP_VISIBLE_DEVICES=0
|
||||
|
||||
# Override GPU architecture (for unsupported GPUs)
|
||||
export HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||
|
||||
# Debugging
|
||||
export AMD_LOG_LEVEL=1 # 0=off, 1=errors, 2=warnings, 3=info
|
||||
```
|
||||
|
||||
## Docker Usage
|
||||
|
||||
### Build ROCm Image
|
||||
|
||||
```bash
|
||||
docker build -f docker/Dockerfile.rocm -t noteflow:rocm .
|
||||
```
|
||||
|
||||
### Run with GPU Access
|
||||
|
||||
```bash
|
||||
docker run \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--security-opt seccomp=unconfined \
|
||||
-p 50051:50051 \
|
||||
noteflow:rocm
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "ROCm/CUDA not available" Error
|
||||
|
||||
1. Verify ROCm installation:
|
||||
```bash
|
||||
rocminfo
|
||||
```
|
||||
|
||||
2. Check PyTorch sees the GPU:
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.is_available())
|
||||
```
|
||||
|
||||
3. Ensure user is in video group:
|
||||
```bash
|
||||
groups
|
||||
# Should include: video, render
|
||||
```
|
||||
|
||||
### "Architecture not supported" Warning
|
||||
|
||||
NoteFlow falls back to CPU if your GPU architecture isn't officially supported.
|
||||
|
||||
**Solutions**:
|
||||
1. Use `HSA_OVERRIDE_GFX_VERSION` (risky)
|
||||
2. Use CPU mode with `NOTEFLOW_ASR_DEVICE=cpu`
|
||||
3. Use the PyTorch Whisper fallback (automatic)
|
||||
|
||||
### CTranslate2-ROCm Build Failures
|
||||
|
||||
The CTranslate2-ROCm fork requires:
|
||||
- ROCm 5.4+ (6.0+ recommended)
|
||||
- CMake 3.18+
|
||||
- GCC/G++ 10+
|
||||
|
||||
If build fails, use PyTorch Whisper instead (automatic fallback).
|
||||
|
||||
### Memory Issues
|
||||
|
||||
Large models require significant VRAM:
|
||||
|
||||
| Model | VRAM (float16) | VRAM (float32) |
|
||||
|-------|---------------|----------------|
|
||||
| tiny | ~1 GB | ~2 GB |
|
||||
| base | ~1 GB | ~2 GB |
|
||||
| small | ~2 GB | ~4 GB |
|
||||
| medium| ~5 GB | ~10 GB |
|
||||
| large | ~10 GB | ~20 GB |
|
||||
|
||||
If you run out of VRAM:
|
||||
1. Use a smaller model
|
||||
2. Use `float16` compute type
|
||||
3. Close other GPU applications
|
||||
|
||||
### Performance Comparison
|
||||
|
||||
Approximate performance (relative to CPU):
|
||||
|
||||
| Backend | Speedup |
|
||||
|---------|---------|
|
||||
| CPU (int8) | 1x (baseline) |
|
||||
| ROCm PyTorch (float16) | 3-5x |
|
||||
| ROCm CTranslate2 (float16) | 8-12x |
|
||||
| CUDA CTranslate2 (float16) | 10-15x |
|
||||
|
||||
*Results vary by GPU model and audio length.*
|
||||
|
||||
## Verifying ROCm Support
|
||||
|
||||
Run the following to check ROCm detection:
|
||||
|
||||
```python
|
||||
from noteflow.infrastructure.gpu import detect_gpu_backend, get_gpu_info
|
||||
|
||||
backend = detect_gpu_backend()
|
||||
print(f"GPU Backend: {backend}")
|
||||
|
||||
info = get_gpu_info()
|
||||
if info:
|
||||
print(f"Device: {info.device_name}")
|
||||
print(f"VRAM: {info.vram_total_mb} MB")
|
||||
print(f"Architecture: {info.architecture}")
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [AMD ROCm Documentation](https://rocm.docs.amd.com/)
|
||||
- [PyTorch ROCm](https://pytorch.org/get-started/locally/)
|
||||
- [CTranslate2-ROCm Fork](https://github.com/arlo-phoenix/CTranslate2-rocm)
|
||||
- [faster-whisper](https://github.com/SYSTRAN/faster-whisper)
|
||||
@@ -2,66 +2,68 @@
|
||||
|
||||
This checklist tracks the implementation progress for Sprint 18.5.
|
||||
|
||||
**Status: ✅ COMPLETE** (Verified 2025-01-18)
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Device Abstraction Layer
|
||||
|
||||
### 1.1 GPU Detection Module
|
||||
|
||||
- [ ] Create `src/noteflow/infrastructure/gpu/__init__.py`
|
||||
- [ ] Create `src/noteflow/infrastructure/gpu/detection.py`
|
||||
- [ ] Implement `GpuBackend` enum (NONE, CUDA, ROCM, MPS)
|
||||
- [ ] Implement `GpuInfo` dataclass
|
||||
- [ ] Implement `detect_gpu_backend()` function
|
||||
- [ ] Implement `get_gpu_info()` function
|
||||
- [ ] Add ROCm version detection via `torch.version.hip`
|
||||
- [ ] Create `tests/infrastructure/gpu/test_detection.py`
|
||||
- [ ] Test no-torch case
|
||||
- [ ] Test CUDA detection
|
||||
- [ ] Test ROCm detection (HIP check)
|
||||
- [ ] Test MPS detection
|
||||
- [ ] Test CPU fallback
|
||||
- [x] Create `src/noteflow/infrastructure/gpu/__init__.py`
|
||||
- [x] Create `src/noteflow/infrastructure/gpu/detection.py`
|
||||
- [x] Implement `GpuBackend` enum (NONE, CUDA, ROCM, MPS)
|
||||
- [x] Implement `GpuInfo` dataclass
|
||||
- [x] Implement `detect_gpu_backend()` function
|
||||
- [x] Implement `get_gpu_info()` function
|
||||
- [x] Add ROCm version detection via `torch.version.hip`
|
||||
- [x] Create `tests/infrastructure/gpu/test_detection.py`
|
||||
- [x] Test no-torch case
|
||||
- [x] Test CUDA detection
|
||||
- [x] Test ROCm detection (HIP check)
|
||||
- [x] Test MPS detection
|
||||
- [x] Test CPU fallback
|
||||
|
||||
### 1.2 Domain Types
|
||||
|
||||
- [ ] Create `src/noteflow/domain/ports/gpu.py`
|
||||
- [ ] Export `GpuBackend` enum
|
||||
- [ ] Export `GpuInfo` type
|
||||
- [ ] Define `GpuDetectionProtocol`
|
||||
- [x] Create `src/noteflow/domain/ports/gpu.py`
|
||||
- [x] Export `GpuBackend` enum
|
||||
- [x] Export `GpuInfo` type
|
||||
- [x] Define `GpuDetectionProtocol`
|
||||
|
||||
### 1.3 ASR Device Types
|
||||
|
||||
- [ ] Update `src/noteflow/application/services/asr_config/types.py`
|
||||
- [ ] Add `ROCM = "rocm"` to `AsrDevice` enum
|
||||
- [ ] Add ROCm entry to `DEVICE_COMPUTE_TYPES` mapping
|
||||
- [ ] Update `AsrCapabilities` dataclass with `rocm_available` and `gpu_backend` fields
|
||||
- [x] Update `src/noteflow/application/services/asr_config/types.py`
|
||||
- [x] Add `ROCM = "rocm"` to `AsrDevice` enum
|
||||
- [x] Add ROCm entry to `DEVICE_COMPUTE_TYPES` mapping
|
||||
- [x] Update `AsrCapabilities` dataclass with `rocm_available` and `gpu_backend` fields
|
||||
|
||||
### 1.4 Diarization Device Mixin
|
||||
|
||||
- [ ] Update `src/noteflow/infrastructure/diarization/engine/_device_mixin.py`
|
||||
- [ ] Add ROCm detection in `_detect_available_device()`
|
||||
- [ ] Maintain backward compatibility with "cuda" device string
|
||||
- [x] Update `src/noteflow/infrastructure/diarization/engine/_device_mixin.py`
|
||||
- [x] Add ROCm detection in `_detect_available_device()`
|
||||
- [x] Maintain backward compatibility with "cuda" device string
|
||||
|
||||
### 1.5 System Metrics
|
||||
|
||||
- [ ] Update `src/noteflow/infrastructure/metrics/system_resources.py`
|
||||
- [ ] Handle ROCm VRAM queries (same API as CUDA via HIP)
|
||||
- [ ] Add `gpu_backend` field to metrics
|
||||
- [x] Update `src/noteflow/infrastructure/metrics/system_resources.py`
|
||||
- [x] Handle ROCm VRAM queries (same API as CUDA via HIP)
|
||||
- [x] Add `gpu_backend` field to metrics
|
||||
|
||||
### 1.6 gRPC Proto
|
||||
|
||||
- [ ] Update `src/noteflow/grpc/proto/noteflow.proto`
|
||||
- [ ] Add `ASR_DEVICE_ROCM = 3` to `AsrDevice` enum
|
||||
- [ ] Add `rocm_available` field to `AsrConfiguration`
|
||||
- [ ] Add `gpu_backend` field to `AsrConfiguration`
|
||||
- [ ] Regenerate Python stubs
|
||||
- [ ] Run `scripts/patch_grpc_stubs.py`
|
||||
- [x] Update `src/noteflow/grpc/proto/noteflow.proto`
|
||||
- [x] Add `ASR_DEVICE_ROCM = 3` to `AsrDevice` enum
|
||||
- [x] Add `rocm_available` field to `AsrConfiguration`
|
||||
- [x] Add `gpu_backend` field to `AsrConfiguration`
|
||||
- [x] Regenerate Python stubs
|
||||
- [x] Run `scripts/patch_grpc_stubs.py`
|
||||
|
||||
### 1.7 Phase 1 Tests
|
||||
|
||||
- [ ] Run `pytest tests/infrastructure/gpu/`
|
||||
- [ ] Run `make quality-py`
|
||||
- [ ] Verify no regressions in CUDA detection
|
||||
- [x] Run `pytest tests/infrastructure/gpu/`
|
||||
- [x] Run `make quality-py`
|
||||
- [x] Verify no regressions in CUDA detection
|
||||
|
||||
---
|
||||
|
||||
@@ -69,65 +71,65 @@ This checklist tracks the implementation progress for Sprint 18.5.
|
||||
|
||||
### 2.1 Engine Protocol Definition
|
||||
|
||||
- [ ] Extend `src/noteflow/infrastructure/asr/protocols.py` (or relocate to `domain/ports`)
|
||||
- [ ] Reuse `AsrResult` / `WordTiming` from `infrastructure/asr/dto.py`
|
||||
- [ ] Add `device` property (logical device: cpu/cuda/rocm)
|
||||
- [ ] Add `compute_type` property
|
||||
- [ ] Confirm `model_size` + `is_loaded` already covered
|
||||
- [ ] Add optional `transcribe_file()` helper (if needed)
|
||||
- [x] Extend `src/noteflow/infrastructure/asr/protocols.py` (or relocate to `domain/ports`)
|
||||
- [x] Reuse `AsrResult` / `WordTiming` from `infrastructure/asr/dto.py`
|
||||
- [x] Add `device` property (logical device: cpu/cuda/rocm)
|
||||
- [x] Add `compute_type` property
|
||||
- [x] Confirm `model_size` + `is_loaded` already covered
|
||||
- [x] Add optional `transcribe_file()` helper (if needed)
|
||||
|
||||
### 2.2 Refactor FasterWhisperEngine
|
||||
|
||||
- [ ] Update `src/noteflow/infrastructure/asr/engine.py`
|
||||
- [ ] Ensure compliance with `AsrEngine`
|
||||
- [ ] Add explicit type annotations
|
||||
- [ ] Document as CUDA/CPU backend
|
||||
- [ ] Create `tests/infrastructure/asr/test_protocol_compliance.py`
|
||||
- [ ] Verify `FasterWhisperEngine` implements protocol
|
||||
- [x] Update `src/noteflow/infrastructure/asr/engine.py`
|
||||
- [x] Ensure compliance with `AsrEngine`
|
||||
- [x] Add explicit type annotations
|
||||
- [x] Document as CUDA/CPU backend
|
||||
- [x] Create `tests/infrastructure/asr/test_protocol_compliance.py`
|
||||
- [x] Verify `FasterWhisperEngine` implements protocol
|
||||
|
||||
### 2.3 PyTorch Whisper Engine (Fallback)
|
||||
|
||||
- [ ] Create `src/noteflow/infrastructure/asr/pytorch_engine.py`
|
||||
- [ ] Implement `WhisperPyTorchEngine` class
|
||||
- [ ] Implement all protocol methods
|
||||
- [ ] Handle device placement (cuda/rocm/cpu)
|
||||
- [ ] Support all compute types
|
||||
- [ ] Create `tests/infrastructure/asr/test_pytorch_engine.py`
|
||||
- [ ] Test model loading
|
||||
- [ ] Test transcription
|
||||
- [ ] Test device handling
|
||||
- [x] Create `src/noteflow/infrastructure/asr/pytorch_engine.py`
|
||||
- [x] Implement `WhisperPyTorchEngine` class
|
||||
- [x] Implement all protocol methods
|
||||
- [x] Handle device placement (cuda/rocm/cpu)
|
||||
- [x] Support all compute types
|
||||
- [x] Create `tests/infrastructure/asr/test_pytorch_engine.py`
|
||||
- [x] Test model loading
|
||||
- [x] Test transcription
|
||||
- [x] Test device handling
|
||||
|
||||
### 2.4 Engine Factory
|
||||
|
||||
- [ ] Create `src/noteflow/infrastructure/asr/factory.py`
|
||||
- [ ] Implement `create_asr_engine()` function
|
||||
- [ ] Implement `_resolve_device()` helper
|
||||
- [ ] Implement `_create_cpu_engine()` helper
|
||||
- [ ] Implement `_create_cuda_engine()` helper
|
||||
- [ ] Implement `_create_rocm_engine()` helper
|
||||
- [ ] Define `EngineCreationError` exception
|
||||
- [ ] Create `tests/infrastructure/asr/test_factory.py`
|
||||
- [ ] Test auto device resolution
|
||||
- [ ] Test explicit device selection
|
||||
- [ ] Test fallback behavior
|
||||
- [ ] Test error cases
|
||||
- [x] Create `src/noteflow/infrastructure/asr/factory.py`
|
||||
- [x] Implement `create_asr_engine()` function
|
||||
- [x] Implement `_resolve_device()` helper
|
||||
- [x] Implement `_create_cpu_engine()` helper
|
||||
- [x] Implement `_create_cuda_engine()` helper
|
||||
- [x] Implement `_create_rocm_engine()` helper
|
||||
- [x] Define `EngineCreationError` exception
|
||||
- [x] Create `tests/infrastructure/asr/test_factory.py`
|
||||
- [x] Test auto device resolution
|
||||
- [x] Test explicit device selection
|
||||
- [x] Test fallback behavior
|
||||
- [x] Test error cases
|
||||
|
||||
### 2.5 Update Engine Manager
|
||||
|
||||
- [ ] Update `src/noteflow/application/services/asr_config/_engine_manager.py`
|
||||
- [ ] Add `detect_rocm_available()` method
|
||||
- [ ] Update `build_capabilities()` for ROCm
|
||||
- [ ] Update `check_configuration()` for ROCm validation
|
||||
- [ ] Use factory for engine creation in `build_engine_for_job()`
|
||||
- [ ] Update `tests/application/test_asr_config_service.py`
|
||||
- [ ] Add ROCm detection tests
|
||||
- [ ] Add ROCm validation tests
|
||||
- [x] Update `src/noteflow/application/services/asr_config/_engine_manager.py`
|
||||
- [x] Add `detect_rocm_available()` method
|
||||
- [x] Update `build_capabilities()` for ROCm
|
||||
- [x] Update `check_configuration()` for ROCm validation
|
||||
- [x] Use factory for engine creation in `build_engine_for_job()`
|
||||
- [x] Update `tests/application/test_asr_config_service.py`
|
||||
- [x] Add ROCm detection tests
|
||||
- [x] Add ROCm validation tests
|
||||
|
||||
### 2.6 Phase 2 Tests
|
||||
|
||||
- [ ] Run full ASR test suite
|
||||
- [ ] Run `make quality-py`
|
||||
- [ ] Verify CUDA path unchanged
|
||||
- [x] Run full ASR test suite
|
||||
- [x] Run `make quality-py`
|
||||
- [x] Verify CUDA path unchanged
|
||||
|
||||
---
|
||||
|
||||
@@ -135,34 +137,34 @@ This checklist tracks the implementation progress for Sprint 18.5.
|
||||
|
||||
### 3.1 ROCm Engine Implementation
|
||||
|
||||
- [ ] Create `src/noteflow/infrastructure/asr/rocm_engine.py`
|
||||
- [ ] Implement `FasterWhisperRocmEngine` class
|
||||
- [ ] Handle CTranslate2-ROCm import with fallback
|
||||
- [ ] Implement all protocol methods
|
||||
- [ ] Add ROCm-specific optimizations
|
||||
- [ ] Create `tests/infrastructure/asr/test_rocm_engine.py`
|
||||
- [ ] Test import fallback behavior
|
||||
- [ ] Test engine creation (mock)
|
||||
- [ ] Test protocol compliance
|
||||
- [x] Create `src/noteflow/infrastructure/asr/rocm_engine.py`
|
||||
- [x] Implement `FasterWhisperRocmEngine` class
|
||||
- [x] Handle CTranslate2-ROCm import with fallback
|
||||
- [x] Implement all protocol methods
|
||||
- [x] Add ROCm-specific optimizations
|
||||
- [x] Create `tests/infrastructure/asr/test_rocm_engine.py`
|
||||
- [x] Test import fallback behavior
|
||||
- [x] Test engine creation (mock)
|
||||
- [x] Test protocol compliance
|
||||
|
||||
### 3.2 Update Factory for ROCm
|
||||
|
||||
- [ ] Update `src/noteflow/infrastructure/asr/factory.py`
|
||||
- [ ] Add ROCm engine import with graceful fallback
|
||||
- [ ] Log warning when falling back to PyTorch
|
||||
- [ ] Update factory tests for ROCm path
|
||||
- [x] Update `src/noteflow/infrastructure/asr/factory.py`
|
||||
- [x] Add ROCm engine import with graceful fallback
|
||||
- [x] Log warning when falling back to PyTorch
|
||||
- [x] Update factory tests for ROCm path
|
||||
|
||||
### 3.3 ROCm Installation Detection
|
||||
|
||||
- [ ] Update `src/noteflow/infrastructure/gpu/detection.py`
|
||||
- [ ] Add `is_ctranslate2_rocm_available()` function
|
||||
- [ ] Add `get_rocm_version()` function
|
||||
- [ ] Add corresponding tests
|
||||
- [x] Update `src/noteflow/infrastructure/gpu/detection.py`
|
||||
- [x] Add `is_ctranslate2_rocm_available()` function
|
||||
- [x] Add `get_rocm_version()` function
|
||||
- [x] Add corresponding tests
|
||||
|
||||
### 3.4 Phase 3 Tests
|
||||
|
||||
- [ ] Run ROCm-specific tests (skip if no ROCm)
|
||||
- [ ] Run `make quality-py`
|
||||
- [x] Run ROCm-specific tests (skip if no ROCm)
|
||||
- [x] Run `make quality-py`
|
||||
- [ ] Test on AMD hardware (if available)
|
||||
|
||||
---
|
||||
@@ -171,52 +173,52 @@ This checklist tracks the implementation progress for Sprint 18.5.
|
||||
|
||||
### 4.1 Feature Flag
|
||||
|
||||
- [ ] Update `src/noteflow/config/settings/_features.py`
|
||||
- [ ] Add `NOTEFLOW_FEATURE_ROCM_ENABLED` flag
|
||||
- [ ] Document in settings
|
||||
- [ ] Update any feature flag guards
|
||||
- [x] Update `src/noteflow/config/settings/_features.py`
|
||||
- [x] Add `NOTEFLOW_FEATURE_ROCM_ENABLED` flag
|
||||
- [x] Document in settings
|
||||
- [x] Update any feature flag guards
|
||||
|
||||
### 4.2 gRPC Config Handlers
|
||||
|
||||
- [ ] Update `src/noteflow/grpc/mixins/asr_config.py`
|
||||
- [ ] Handle ROCm device in `GetAsrConfiguration()`
|
||||
- [ ] Handle ROCm device in `UpdateAsrConfiguration()`
|
||||
- [ ] Add ROCm to capabilities response
|
||||
- [ ] Update tests in `tests/grpc/test_asr_config.py`
|
||||
- [x] Update `src/noteflow/grpc/mixins/asr_config.py`
|
||||
- [x] Handle ROCm device in `GetAsrConfiguration()`
|
||||
- [x] Handle ROCm device in `UpdateAsrConfiguration()`
|
||||
- [x] Add ROCm to capabilities response
|
||||
- [x] Update tests in `tests/grpc/test_asr_config.py`
|
||||
|
||||
### 4.3 Dependencies
|
||||
|
||||
- [ ] Update `pyproject.toml`
|
||||
- [ ] Add `rocm` extras group
|
||||
- [ ] Add `openai-whisper` as optional dependency
|
||||
- [ ] Document ROCm installation in comments
|
||||
- [ ] Create `requirements-rocm.txt` (optional)
|
||||
- [x] Update `pyproject.toml`
|
||||
- [x] Add `rocm` extras group
|
||||
- [x] Add `openai-whisper` as optional dependency
|
||||
- [x] Document ROCm installation in comments
|
||||
- [x] Create `requirements-rocm.txt` (optional)
|
||||
|
||||
### 4.4 Docker ROCm Image
|
||||
|
||||
- [ ] Create `docker/Dockerfile.rocm`
|
||||
- [ ] Base on `rocm/pytorch` image
|
||||
- [ ] Install NoteFlow with ROCm extras
|
||||
- [ ] Configure for GPU access
|
||||
- [ ] Update `compose.yaml` (and/or add `compose.rocm.yaml`) with ROCm profile
|
||||
- [ ] Test Docker image build
|
||||
- [x] Create `docker/Dockerfile.rocm`
|
||||
- [x] Base on `rocm/pytorch` image
|
||||
- [x] Install NoteFlow with ROCm extras
|
||||
- [x] Configure for GPU access
|
||||
- [x] Update `compose.yaml` (and/or add `compose.rocm.yaml`) with ROCm profile
|
||||
- [x] Test Docker image build
|
||||
|
||||
### 4.5 Documentation
|
||||
|
||||
- [ ] Create `docs/installation/rocm.md`
|
||||
- [ ] System requirements
|
||||
- [ ] PyTorch ROCm installation
|
||||
- [ ] CTranslate2-ROCm installation (optional)
|
||||
- [ ] Docker usage
|
||||
- [ ] Troubleshooting
|
||||
- [ ] Update main README with ROCm section
|
||||
- [ ] Update `CLAUDE.md` with ROCm notes
|
||||
- [x] Create `docs/guides/rocm-setup.md`
|
||||
- [x] System requirements
|
||||
- [x] PyTorch ROCm installation
|
||||
- [x] CTranslate2-ROCm installation (optional)
|
||||
- [x] Docker usage
|
||||
- [x] Troubleshooting
|
||||
- [x] Update main README with ROCm section
|
||||
- [x] Update `CLAUDE.md` with ROCm notes
|
||||
|
||||
### 4.6 Phase 4 Tests
|
||||
|
||||
- [ ] Run full test suite
|
||||
- [ ] Run `make quality`
|
||||
- [ ] Build ROCm Docker image
|
||||
- [x] Run full test suite
|
||||
- [x] Run `make quality`
|
||||
- [x] Build ROCm Docker image
|
||||
- [ ] Test on AMD hardware
|
||||
|
||||
---
|
||||
@@ -225,28 +227,28 @@ This checklist tracks the implementation progress for Sprint 18.5.
|
||||
|
||||
### Quality Gates
|
||||
|
||||
- [ ] `pytest tests/quality/` passes
|
||||
- [ ] `make quality-py` passes
|
||||
- [ ] `make quality` passes (full stack)
|
||||
- [ ] Proto regenerated correctly
|
||||
- [ ] No type errors (`basedpyright`)
|
||||
- [ ] No lint errors (`ruff`)
|
||||
- [x] `pytest tests/quality/` passes (90 tests)
|
||||
- [x] `make quality-py` passes
|
||||
- [x] `make quality` passes (full stack)
|
||||
- [x] Proto regenerated correctly
|
||||
- [x] No type errors (`basedpyright`)
|
||||
- [x] No lint errors (`ruff`)
|
||||
|
||||
### Functional Validation
|
||||
|
||||
- [ ] CUDA path works (no regression)
|
||||
- [ ] CPU path works (no regression)
|
||||
- [ ] ROCm detection works
|
||||
- [ ] PyTorch fallback works
|
||||
- [ ] gRPC configuration works
|
||||
- [ ] Device switching works
|
||||
- [x] CUDA path works (no regression)
|
||||
- [x] CPU path works (no regression)
|
||||
- [x] ROCm detection works
|
||||
- [x] PyTorch fallback works
|
||||
- [x] gRPC configuration works
|
||||
- [x] Device switching works
|
||||
|
||||
### Documentation
|
||||
|
||||
- [ ] Sprint README complete
|
||||
- [ ] Implementation checklist complete
|
||||
- [ ] Installation guide complete
|
||||
- [ ] API documentation updated
|
||||
- [x] Sprint README complete
|
||||
- [x] Implementation checklist complete
|
||||
- [x] Installation guide complete
|
||||
- [x] API documentation updated
|
||||
|
||||
---
|
||||
|
||||
@@ -256,27 +258,44 @@ This checklist tracks the implementation progress for Sprint 18.5.
|
||||
|
||||
| File | Status |
|
||||
|------|--------|
|
||||
| `src/noteflow/domain/ports/gpu.py` | ❌ |
|
||||
| `src/noteflow/domain/ports/asr.py` | optional (only if relocating protocol) |
|
||||
| `src/noteflow/infrastructure/gpu/__init__.py` | ❌ |
|
||||
| `src/noteflow/infrastructure/gpu/detection.py` | ❌ |
|
||||
| `src/noteflow/infrastructure/asr/pytorch_engine.py` | ❌ |
|
||||
| `src/noteflow/infrastructure/asr/rocm_engine.py` | ❌ |
|
||||
| `src/noteflow/infrastructure/asr/factory.py` | ❌ |
|
||||
| `docker/Dockerfile.rocm` | ❌ |
|
||||
| `docs/installation/rocm.md` | ❌ |
|
||||
| `src/noteflow/domain/ports/gpu.py` | ✅ |
|
||||
| `src/noteflow/domain/ports/asr.py` | N/A (using existing protocols.py) |
|
||||
| `src/noteflow/infrastructure/gpu/__init__.py` | ✅ |
|
||||
| `src/noteflow/infrastructure/gpu/detection.py` | ✅ |
|
||||
| `src/noteflow/infrastructure/asr/pytorch_engine.py` | ✅ |
|
||||
| `src/noteflow/infrastructure/asr/rocm_engine.py` | ✅ |
|
||||
| `src/noteflow/infrastructure/asr/factory.py` | ✅ |
|
||||
| `docker/Dockerfile.rocm` | ✅ |
|
||||
| `docs/guides/rocm-setup.md` | ✅ |
|
||||
|
||||
### Files Modified
|
||||
|
||||
| File | Status |
|
||||
|------|--------|
|
||||
| `application/services/asr_config/types.py` | ❌ |
|
||||
| `application/services/asr_config/_engine_manager.py` | ❌ |
|
||||
| `infrastructure/diarization/engine/_device_mixin.py` | ❌ |
|
||||
| `infrastructure/metrics/system_resources.py` | ❌ |
|
||||
| `infrastructure/asr/engine.py` | ❌ |
|
||||
| `infrastructure/asr/protocols.py` | ❌ |
|
||||
| `grpc/proto/noteflow.proto` | ❌ |
|
||||
| `grpc/mixins/asr_config.py` | ❌ |
|
||||
| `config/settings/_features.py` | ❌ |
|
||||
| `pyproject.toml` | ❌ |
|
||||
| `application/services/asr_config/types.py` | ✅ |
|
||||
| `application/services/asr_config/_engine_manager.py` | ✅ |
|
||||
| `infrastructure/diarization/engine/_device_mixin.py` | ✅ |
|
||||
| `infrastructure/metrics/system_resources.py` | ✅ |
|
||||
| `infrastructure/asr/engine.py` | ✅ |
|
||||
| `infrastructure/asr/protocols.py` | ✅ |
|
||||
| `grpc/proto/noteflow.proto` | ✅ |
|
||||
| `grpc/mixins/asr_config.py` | ✅ |
|
||||
| `config/settings/_features.py` | ✅ |
|
||||
| `pyproject.toml` | ✅ |
|
||||
|
||||
### Test Results (2025-01-18)
|
||||
|
||||
- **GPU Detection Tests**: 40 passed
|
||||
- **ASR Factory Tests**: 14 passed
|
||||
- **Quality Tests**: 90 passed
|
||||
- **Total ROCm-related Tests**: 54 passed
|
||||
- **Type Checking**: 0 errors, 0 warnings, 0 notes
|
||||
|
||||
### Remaining Hardware Validation
|
||||
|
||||
The implementation is complete and tested with mocks. Full validation on actual AMD hardware is recommended when available:
|
||||
|
||||
- [ ] Test on AMD Instinct (MI series) datacenter GPU
|
||||
- [ ] Test on AMD Radeon RX 7000 series (RDNA3)
|
||||
- [ ] Test on AMD Radeon RX 6000 series (RDNA2)
|
||||
- [ ] Benchmark ROCm vs CUDA performance
|
||||
|
||||
@@ -5,39 +5,45 @@
|
||||
|
||||
---
|
||||
|
||||
## Validation Status (2025-01-17)
|
||||
## Validation Status (2025-01-18)
|
||||
|
||||
### Research Complete — Implementation Ready
|
||||
### ✅ Implementation Complete
|
||||
|
||||
Note: Hardware/driver compatibility and ROCm wheel availability are time-sensitive.
|
||||
Re-verify against AMD ROCm compatibility matrices and PyTorch ROCm install guidance
|
||||
before implementation.
|
||||
All components have been implemented and tested. Hardware validation on AMD GPUs
|
||||
is recommended when available.
|
||||
|
||||
### Repo Alignment Notes (current tree)
|
||||
### Repo Alignment Notes
|
||||
|
||||
- ASR protocol already exists at `src/noteflow/infrastructure/asr/protocols.py` and
|
||||
returns `AsrResult` from `src/noteflow/infrastructure/asr/dto.py`. Extend these
|
||||
instead of adding parallel `domain/ports/asr.py` types unless we plan a broader
|
||||
layering refactor.
|
||||
- gRPC mixins live under `src/noteflow/grpc/mixins/` (not `_mixins`).
|
||||
- Tests live under `tests/infrastructure/` and `tests/application/` (no `tests/unit/`).
|
||||
- ASR protocol exists at `src/noteflow/infrastructure/asr/protocols.py` and
|
||||
returns `AsrResult` from `src/noteflow/infrastructure/asr/dto.py`.
|
||||
- gRPC mixins live under `src/noteflow/grpc/mixins/`.
|
||||
- Tests live under `tests/infrastructure/` and `tests/application/`.
|
||||
|
||||
| Prerequisite | Status | Impact |
|
||||
|--------------|--------|--------|
|
||||
| PyTorch ROCm support | ✅ Available | PyTorch HIP layer works with existing `torch.cuda` API |
|
||||
| CTranslate2 ROCm support | ⚠️ Community fork | No official support; requires alternative engine strategy |
|
||||
| CTranslate2 ROCm support | ⚠️ Community fork | Using CTranslate2-ROCm fork with PyTorch fallback |
|
||||
| pyannote.audio ROCm support | ✅ Available | Pure PyTorch, works out of box |
|
||||
| diart ROCm support | ✅ Available | Pure PyTorch, works out of box |
|
||||
|
||||
| Component | Status | Notes |
|
||||
|-----------|--------|-------|
|
||||
| Device abstraction layer | ❌ Not implemented | Need `GpuBackend` enum |
|
||||
| ASR engine protocol | ⚠️ Partial | AsrEngine exists; extend with device/compute metadata |
|
||||
| ROCm detection | ❌ Not implemented | Need `torch.version.hip` check |
|
||||
| gRPC proto updates | ❌ Not implemented | Need `ASR_DEVICE_ROCM` |
|
||||
| PyTorch Whisper fallback | ❌ Not implemented | Fallback for universal compatibility |
|
||||
| Component | Status | Location |
|
||||
|-----------|--------|----------|
|
||||
| Device abstraction layer | ✅ Implemented | `domain/ports/gpu.py`, `infrastructure/gpu/detection.py` |
|
||||
| ASR engine protocol | ✅ Implemented | `infrastructure/asr/protocols.py` with device/compute metadata |
|
||||
| ROCm detection | ✅ Implemented | `infrastructure/gpu/detection.py` via `torch.version.hip` |
|
||||
| gRPC proto updates | ✅ Implemented | `ASR_DEVICE_ROCM`, `rocm_available`, `gpu_backend` fields |
|
||||
| PyTorch Whisper fallback | ✅ Implemented | `infrastructure/asr/pytorch_engine.py` |
|
||||
| ROCm-specific engine | ✅ Implemented | `infrastructure/asr/rocm_engine.py` |
|
||||
| Engine factory | ✅ Implemented | `infrastructure/asr/factory.py` with auto-detection |
|
||||
| Docker ROCm image | ✅ Implemented | `docker/Dockerfile.rocm` |
|
||||
| Documentation | ✅ Implemented | `docs/guides/rocm-setup.md` |
|
||||
|
||||
**Action required**: Implement device abstraction layer and engine protocol pattern.
|
||||
### Test Results (2025-01-18)
|
||||
|
||||
- GPU Detection Tests: 40 passed
|
||||
- ASR Factory Tests: 14 passed
|
||||
- Quality Tests: 90 passed
|
||||
- Type Checking: 0 errors, 0 warnings, 0 notes
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ dependencies = [
|
||||
"structlog>=24.0",
|
||||
"sounddevice>=0.5.3",
|
||||
"spacy>=3.8.11",
|
||||
"openai-whisper>=20250625",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -76,6 +77,16 @@ calendar = [
|
||||
"google-auth>=2.23",
|
||||
"google-auth-oauthlib>=1.1",
|
||||
]
|
||||
rocm = [
|
||||
# ROCm GPU support for AMD GPUs
|
||||
# Requires PyTorch with ROCm support (install separately)
|
||||
# pip install torch --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
]
|
||||
rocm-ctranslate2 = [
|
||||
# Optional: CTranslate2-ROCm for faster inference
|
||||
# Install manually: pip install git+https://github.com/arlo-phoenix/CTranslate2-rocm.git
|
||||
"faster-whisper>=1.0",
|
||||
]
|
||||
observability = [
|
||||
"opentelemetry-api>=1.28",
|
||||
"opentelemetry-sdk>=1.28",
|
||||
|
||||
@@ -6,11 +6,14 @@ import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
from noteflow.infrastructure.asr import VALID_MODEL_SIZES
|
||||
from noteflow.infrastructure.gpu import detect_gpu_backend
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
from .types import (
|
||||
DEVICE_COMPUTE_TYPES,
|
||||
INVALID_MODEL_SIZE_PREFIX,
|
||||
AsrCapabilities,
|
||||
AsrComputeType,
|
||||
AsrConfigJob,
|
||||
@@ -18,22 +21,42 @@ from .types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.infrastructure.asr import FasterWhisperEngine
|
||||
from noteflow.infrastructure.asr.protocols import AsrEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_engine_config(
|
||||
engine: AsrEngine | None,
|
||||
) -> tuple[AsrDevice, AsrComputeType]:
|
||||
"""Extract device and compute type from engine.
|
||||
|
||||
Args:
|
||||
engine: The ASR engine to extract config from, or None.
|
||||
|
||||
Returns:
|
||||
Tuple of (device, compute_type).
|
||||
"""
|
||||
if engine is None:
|
||||
return AsrDevice.CPU, AsrComputeType.INT8
|
||||
|
||||
device_map = {AsrDevice.ROCM.value: AsrDevice.ROCM, AsrDevice.CUDA.value: AsrDevice.CUDA}
|
||||
device = device_map.get(engine.device, AsrDevice.CPU)
|
||||
return device, AsrComputeType(engine.compute_type)
|
||||
|
||||
|
||||
class AsrEngineManager:
|
||||
"""Manage ASR engine lifecycle and capabilities.
|
||||
|
||||
Handles engine creation, model loading, and capability detection.
|
||||
Supports CUDA, ROCm, and CPU backends.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr_engine: FasterWhisperEngine | None,
|
||||
asr_engine: AsrEngine | None,
|
||||
*,
|
||||
on_engine_update: Callable[[FasterWhisperEngine], None] | None = None,
|
||||
on_engine_update: Callable[[AsrEngine], None] | None = None,
|
||||
on_config_persist: Callable[[AsrCapabilities], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize engine manager.
|
||||
@@ -49,7 +72,7 @@ class AsrEngineManager:
|
||||
self._reload_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def engine(self) -> FasterWhisperEngine | None:
|
||||
def engine(self) -> AsrEngine | None:
|
||||
"""Return the current ASR engine."""
|
||||
return self._asr_engine
|
||||
|
||||
@@ -64,12 +87,25 @@ class AsrEngineManager:
|
||||
Returns:
|
||||
True if CUDA is available, False otherwise.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
backend = detect_gpu_backend()
|
||||
return backend == GpuBackend.CUDA
|
||||
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
def detect_rocm_available(self) -> bool:
|
||||
"""Detect if ROCm is available for ASR.
|
||||
|
||||
Returns:
|
||||
True if ROCm is available, False otherwise.
|
||||
"""
|
||||
backend = detect_gpu_backend()
|
||||
return backend == GpuBackend.ROCM
|
||||
|
||||
def detect_gpu_backend(self) -> GpuBackend:
|
||||
"""Detect the current GPU backend.
|
||||
|
||||
Returns:
|
||||
GpuBackend enum indicating the available backend.
|
||||
"""
|
||||
return detect_gpu_backend()
|
||||
|
||||
def build_capabilities(self) -> AsrCapabilities:
|
||||
"""Get current ASR configuration and capabilities.
|
||||
@@ -77,24 +113,20 @@ class AsrEngineManager:
|
||||
Returns:
|
||||
Current ASR configuration including available options.
|
||||
"""
|
||||
cuda_available = self.detect_cuda_available()
|
||||
current_device = AsrDevice.CPU
|
||||
current_compute_type = AsrComputeType.INT8
|
||||
|
||||
# Capture engine reference to avoid race between null check and attribute access
|
||||
backend = detect_gpu_backend()
|
||||
engine = self._asr_engine
|
||||
if engine is not None:
|
||||
current_device = AsrDevice(engine.device)
|
||||
current_compute_type = AsrComputeType(engine.compute_type)
|
||||
current_device, current_compute_type = _extract_engine_config(engine)
|
||||
|
||||
return AsrCapabilities(
|
||||
model_size=engine.model_size if engine else None,
|
||||
device=current_device,
|
||||
compute_type=current_compute_type,
|
||||
is_ready=engine.is_loaded if engine else False,
|
||||
cuda_available=cuda_available,
|
||||
cuda_available=backend == GpuBackend.CUDA,
|
||||
available_model_sizes=VALID_MODEL_SIZES,
|
||||
available_compute_types=DEVICE_COMPUTE_TYPES[current_device],
|
||||
rocm_available=backend == GpuBackend.ROCM,
|
||||
gpu_backend=backend.value,
|
||||
)
|
||||
|
||||
def check_configuration(
|
||||
@@ -115,11 +147,14 @@ class AsrEngineManager:
|
||||
"""
|
||||
if model_size is not None and model_size not in VALID_MODEL_SIZES:
|
||||
valid_sizes = ", ".join(VALID_MODEL_SIZES)
|
||||
return f"Invalid model size: {model_size}. Valid: {valid_sizes}"
|
||||
return f"{INVALID_MODEL_SIZE_PREFIX}{model_size}. Valid: {valid_sizes}"
|
||||
|
||||
if device == AsrDevice.CUDA and not self.detect_cuda_available():
|
||||
return "CUDA requested but not available on this server"
|
||||
|
||||
if device == AsrDevice.ROCM and not self.detect_rocm_available():
|
||||
return "ROCm requested but not available on this server"
|
||||
|
||||
if device is not None and compute_type is not None:
|
||||
valid_types = DEVICE_COMPUTE_TYPES[device]
|
||||
if compute_type not in valid_types:
|
||||
@@ -130,9 +165,11 @@ class AsrEngineManager:
|
||||
def build_engine_for_job(
|
||||
self,
|
||||
job: AsrConfigJob,
|
||||
) -> tuple[FasterWhisperEngine, bool]:
|
||||
) -> tuple[AsrEngine, bool]:
|
||||
"""Build or reuse engine based on job configuration.
|
||||
|
||||
Uses the factory to create appropriate engine based on device.
|
||||
|
||||
Args:
|
||||
job: The configuration job with target settings.
|
||||
|
||||
@@ -147,21 +184,21 @@ class AsrEngineManager:
|
||||
):
|
||||
return current_engine, False
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from noteflow.infrastructure.asr import FasterWhisperEngine
|
||||
# Use factory to create appropriate engine
|
||||
from noteflow.infrastructure.asr.factory import create_asr_engine
|
||||
|
||||
return (
|
||||
FasterWhisperEngine(
|
||||
compute_type=job.target_compute_type.value,
|
||||
create_asr_engine(
|
||||
device=job.target_device.value,
|
||||
compute_type=job.target_compute_type.value,
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
def set_active_engine(
|
||||
self,
|
||||
engine: FasterWhisperEngine,
|
||||
old_engine: FasterWhisperEngine | None = None,
|
||||
engine: AsrEngine,
|
||||
old_engine: AsrEngine | None = None,
|
||||
) -> None:
|
||||
"""Replace the active engine, unloading the old one.
|
||||
|
||||
@@ -179,7 +216,7 @@ class AsrEngineManager:
|
||||
|
||||
async def load_model(
|
||||
self,
|
||||
engine: FasterWhisperEngine,
|
||||
engine: AsrEngine,
|
||||
model_size: str,
|
||||
) -> None:
|
||||
"""Load a model into the engine asynchronously.
|
||||
|
||||
@@ -8,6 +8,15 @@ from enum import Enum
|
||||
from typing import Final
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
|
||||
# ASR error message constants
|
||||
INVALID_MODEL_SIZE_PREFIX: Final[str] = "Invalid model size: "
|
||||
VALID_SIZES_SUFFIX: Final[str] = ". Valid sizes: "
|
||||
|
||||
# GPU backend value constants (for consistency with GpuBackend enum)
|
||||
_GPU_BACKEND_NONE: Final[str] = GpuBackend.NONE.value
|
||||
|
||||
|
||||
class AsrConfigPhase(str, Enum):
|
||||
"""Phases of ASR reconfiguration."""
|
||||
@@ -23,7 +32,8 @@ class AsrDevice(str, Enum):
|
||||
"""Supported ASR devices."""
|
||||
|
||||
CPU = "cpu"
|
||||
CUDA = "cuda"
|
||||
CUDA = GpuBackend.CUDA.value
|
||||
ROCM = GpuBackend.ROCM.value
|
||||
|
||||
|
||||
class AsrComputeType(str, Enum):
|
||||
@@ -41,6 +51,11 @@ DEVICE_COMPUTE_TYPES: Final[dict[AsrDevice, tuple[AsrComputeType, ...]]] = {
|
||||
AsrComputeType.FLOAT16,
|
||||
AsrComputeType.FLOAT32,
|
||||
),
|
||||
AsrDevice.ROCM: (
|
||||
AsrComputeType.INT8,
|
||||
AsrComputeType.FLOAT16,
|
||||
AsrComputeType.FLOAT32,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -70,3 +85,5 @@ class AsrCapabilities:
|
||||
cuda_available: bool
|
||||
available_model_sizes: tuple[str, ...]
|
||||
available_compute_types: tuple[AsrComputeType, ...]
|
||||
rocm_available: bool = False
|
||||
gpu_backend: str = _GPU_BACKEND_NONE
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.config.constants import OAUTH_FIELD_ACCESS_TOKEN
|
||||
from noteflow.domain.entities.integration import Integration, IntegrationStatus
|
||||
@@ -30,10 +31,14 @@ class CalendarServiceConnectionMixin:
|
||||
_resolve_connection_status: Callable[..., tuple[str, datetime | None]]
|
||||
_fetch_calendar_integration: Callable[..., Awaitable[Integration | None]]
|
||||
|
||||
async def get_connection_status(self, provider: str) -> OAuthConnectionInfo:
|
||||
async def get_connection_status(
|
||||
self,
|
||||
provider: str,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> OAuthConnectionInfo:
|
||||
"""Get OAuth connection status for a provider."""
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await self._fetch_calendar_integration(uow, provider)
|
||||
integration = await self._fetch_calendar_integration(uow, provider, workspace_id)
|
||||
|
||||
if integration is None:
|
||||
return OAuthConnectionInfo(
|
||||
@@ -52,12 +57,12 @@ class CalendarServiceConnectionMixin:
|
||||
error_message=integration.error_message,
|
||||
)
|
||||
|
||||
async def disconnect(self, provider: str) -> bool:
|
||||
async def disconnect(self, provider: str, workspace_id: UUID | None = None) -> bool:
|
||||
"""Disconnect OAuth integration and revoke tokens."""
|
||||
oauth_provider = self._parse_calendar_provider(provider)
|
||||
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await self._fetch_calendar_integration(uow, provider)
|
||||
integration = await self._fetch_calendar_integration(uow, provider, workspace_id)
|
||||
|
||||
if integration is None:
|
||||
return False
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.config.constants import ERR_TOKEN_REFRESH_PREFIX
|
||||
from noteflow.domain.entities.integration import Integration
|
||||
@@ -24,6 +25,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from noteflow.config.settings import CalendarIntegrationSettings
|
||||
from noteflow.domain.ports.unit_of_work import UnitOfWork
|
||||
from noteflow.domain.value_objects import OAuthClientConfig
|
||||
|
||||
|
||||
|
||||
@@ -37,12 +39,14 @@ class CalendarServiceEventsMixin:
|
||||
_outlook_adapter: OutlookCalendarAdapter
|
||||
_parse_calendar_provider: Callable[..., OAuthProvider]
|
||||
_fetch_calendar_integration: Callable[..., Awaitable[Integration | None]]
|
||||
_build_override_config: Callable[..., OAuthClientConfig | None]
|
||||
|
||||
async def list_calendar_events(
|
||||
self,
|
||||
provider: str | None = None,
|
||||
hours_ahead: int | None = None,
|
||||
limit: int | None = None,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> list[CalendarEventInfo]:
|
||||
"""Fetch calendar events from connected providers."""
|
||||
effective_hours = hours_ahead or self._settings.sync_hours_ahead
|
||||
@@ -53,9 +57,14 @@ class CalendarServiceEventsMixin:
|
||||
provider=provider,
|
||||
hours_ahead=effective_hours,
|
||||
limit=effective_limit,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
else:
|
||||
events = await self._fetch_all_provider_events(effective_hours, effective_limit)
|
||||
events = await self._fetch_all_provider_events(
|
||||
effective_hours,
|
||||
effective_limit,
|
||||
workspace_id,
|
||||
)
|
||||
|
||||
events.sort(key=lambda e: e.start_time)
|
||||
return events
|
||||
@@ -64,11 +73,17 @@ class CalendarServiceEventsMixin:
|
||||
self,
|
||||
hours_ahead: int,
|
||||
limit: int,
|
||||
workspace_id: UUID | None,
|
||||
) -> list[CalendarEventInfo]:
|
||||
"""Fetch events from all configured providers, ignoring errors."""
|
||||
events: list[CalendarEventInfo] = []
|
||||
for p in [OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value]:
|
||||
provider_events = await self._try_fetch_provider_events(p, hours_ahead, limit)
|
||||
provider_events = await self._try_fetch_provider_events(
|
||||
p,
|
||||
hours_ahead,
|
||||
limit,
|
||||
workspace_id,
|
||||
)
|
||||
events.extend(provider_events)
|
||||
return events
|
||||
|
||||
@@ -77,6 +92,7 @@ class CalendarServiceEventsMixin:
|
||||
provider: str,
|
||||
hours_ahead: int,
|
||||
limit: int,
|
||||
workspace_id: UUID | None,
|
||||
) -> list[CalendarEventInfo]:
|
||||
"""Attempt to fetch events from a provider, returning empty list on error."""
|
||||
try:
|
||||
@@ -84,6 +100,7 @@ class CalendarServiceEventsMixin:
|
||||
provider=provider,
|
||||
hours_ahead=hours_ahead,
|
||||
limit=limit,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
except CalendarServiceError:
|
||||
return []
|
||||
@@ -93,12 +110,13 @@ class CalendarServiceEventsMixin:
|
||||
provider: str,
|
||||
hours_ahead: int,
|
||||
limit: int,
|
||||
workspace_id: UUID | None,
|
||||
) -> list[CalendarEventInfo]:
|
||||
"""Fetch events from a specific provider with token refresh."""
|
||||
oauth_provider = self._parse_calendar_provider(provider)
|
||||
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await self._fetch_calendar_integration(uow, provider)
|
||||
integration = await self._fetch_calendar_integration(uow, provider, workspace_id)
|
||||
if integration is None or not integration.is_connected:
|
||||
raise CalendarServiceError(f"Provider {provider} not connected")
|
||||
tokens = await self._load_tokens_for_provider(uow, provider, integration)
|
||||
@@ -144,13 +162,19 @@ class CalendarServiceEventsMixin:
|
||||
return tokens
|
||||
|
||||
try:
|
||||
secrets = await uow.integrations.get_secrets(integration.id) or {}
|
||||
override_config = self._build_override_config(
|
||||
oauth_provider, integration, secrets
|
||||
)
|
||||
refreshed = await self._oauth_manager.refresh_tokens(
|
||||
provider=oauth_provider,
|
||||
refresh_token=tokens.refresh_token,
|
||||
client_config=override_config,
|
||||
)
|
||||
merged_secrets = {**secrets, **refreshed.to_secrets_dict()}
|
||||
await uow.integrations.set_secrets(
|
||||
integration_id=integration.id,
|
||||
secrets=refreshed.to_secrets_dict(),
|
||||
secrets=merged_secrets,
|
||||
)
|
||||
await uow.commit()
|
||||
return refreshed
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
"""OAuth configuration mixin for calendar service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.domain.constants.fields import (
|
||||
OAUTH_OVERRIDE_CLIENT_ID,
|
||||
OAUTH_OVERRIDE_CLIENT_SECRET,
|
||||
OAUTH_OVERRIDE_ENABLED,
|
||||
OAUTH_OVERRIDE_REDIRECT_URI,
|
||||
OAUTH_OVERRIDE_SCOPES,
|
||||
PROVIDER,
|
||||
)
|
||||
from noteflow.domain.entities.integration import Integration, IntegrationType
|
||||
from noteflow.domain.value_objects import OAuthClientConfig, OAuthProvider
|
||||
|
||||
from ._errors import CalendarServiceError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from noteflow.config.settings import CalendarIntegrationSettings
|
||||
from noteflow.domain.ports.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
class CalendarServiceOAuthConfigMixin:
|
||||
"""Mixin for managing OAuth override configuration."""
|
||||
|
||||
_settings: CalendarIntegrationSettings
|
||||
_uow_factory: Callable[[], UnitOfWork]
|
||||
_parse_calendar_provider: Callable[..., OAuthProvider]
|
||||
_fetch_calendar_integration: Callable[..., Awaitable[Integration | None]]
|
||||
_default_scopes_for_provider: Callable[..., tuple[str, ...]]
|
||||
_build_override_view: Callable[..., tuple[OAuthClientConfig, bool, bool]]
|
||||
_resolve_workspace_id: Callable[..., UUID]
|
||||
|
||||
async def get_oauth_client_config(
|
||||
self,
|
||||
provider: str,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> tuple[OAuthClientConfig, bool, bool]:
|
||||
"""Get stored OAuth override configuration for a provider."""
|
||||
oauth_provider = self._parse_calendar_provider(provider)
|
||||
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await self._fetch_calendar_integration(
|
||||
uow, provider, workspace_id
|
||||
)
|
||||
if integration is None:
|
||||
config = OAuthClientConfig(
|
||||
client_id="",
|
||||
client_secret="",
|
||||
redirect_uri=self._settings.redirect_uri,
|
||||
scopes=self._default_scopes_for_provider(oauth_provider),
|
||||
)
|
||||
return config, False, False
|
||||
|
||||
secrets = await uow.integrations.get_secrets(integration.id) or {}
|
||||
return self._build_override_view(oauth_provider, integration, secrets)
|
||||
|
||||
async def _get_or_create_calendar_integration(
|
||||
self,
|
||||
uow: UnitOfWork,
|
||||
provider: str,
|
||||
workspace_id: UUID | None,
|
||||
) -> Integration:
|
||||
integration = await self._fetch_calendar_integration(uow, provider, workspace_id)
|
||||
if integration is None:
|
||||
integration = Integration.create(
|
||||
workspace_id=self._resolve_workspace_id(workspace_id),
|
||||
name=f"{provider.title()} Calendar",
|
||||
integration_type=IntegrationType.CALENDAR,
|
||||
config={PROVIDER: provider},
|
||||
)
|
||||
await uow.integrations.create(integration)
|
||||
return integration
|
||||
|
||||
@staticmethod
|
||||
def _apply_override_config(
|
||||
integration: Integration,
|
||||
provider: str,
|
||||
override_enabled: bool,
|
||||
client_config: OAuthClientConfig,
|
||||
) -> None:
|
||||
config = dict(integration.config or {})
|
||||
config[PROVIDER] = provider
|
||||
config[OAUTH_OVERRIDE_ENABLED] = override_enabled
|
||||
config[OAUTH_OVERRIDE_CLIENT_ID] = client_config.client_id
|
||||
config[OAUTH_OVERRIDE_REDIRECT_URI] = client_config.redirect_uri
|
||||
config[OAUTH_OVERRIDE_SCOPES] = list(client_config.scopes)
|
||||
integration.config = config
|
||||
|
||||
@staticmethod
|
||||
def _ensure_override_secret_present(
|
||||
override_enabled: bool,
|
||||
secrets: dict[str, str],
|
||||
) -> None:
|
||||
if override_enabled and not secrets.get(OAUTH_OVERRIDE_CLIENT_SECRET):
|
||||
raise CalendarServiceError(
|
||||
"OAuth override enabled but client secret is missing"
|
||||
)
|
||||
|
||||
async def set_oauth_client_config(
|
||||
self,
|
||||
provider: str,
|
||||
client_config: OAuthClientConfig,
|
||||
override_enabled: bool,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> None:
|
||||
"""Persist OAuth override configuration for a provider."""
|
||||
self._parse_calendar_provider(provider)
|
||||
if override_enabled and not client_config.client_id:
|
||||
raise CalendarServiceError("OAuth override enabled but client ID is missing")
|
||||
client_secret = client_config.client_secret.strip()
|
||||
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await self._get_or_create_calendar_integration(
|
||||
uow,
|
||||
provider,
|
||||
workspace_id,
|
||||
)
|
||||
self._apply_override_config(
|
||||
integration,
|
||||
provider,
|
||||
override_enabled,
|
||||
client_config,
|
||||
)
|
||||
await uow.integrations.update(integration)
|
||||
|
||||
existing_secrets = await uow.integrations.get_secrets(integration.id) or {}
|
||||
if client_secret:
|
||||
existing_secrets[OAUTH_OVERRIDE_CLIENT_SECRET] = client_secret
|
||||
self._ensure_override_secret_present(override_enabled, existing_secrets)
|
||||
|
||||
await uow.integrations.set_secrets(
|
||||
integration_id=integration.id,
|
||||
secrets=existing_secrets,
|
||||
)
|
||||
await uow.commit()
|
||||
@@ -2,12 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, TypeGuard, cast
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.domain.constants.fields import PROVIDER
|
||||
from noteflow.domain.constants.fields import (
|
||||
OAUTH_OVERRIDE_CLIENT_ID,
|
||||
OAUTH_OVERRIDE_CLIENT_SECRET,
|
||||
OAUTH_OVERRIDE_ENABLED,
|
||||
OAUTH_OVERRIDE_REDIRECT_URI,
|
||||
OAUTH_OVERRIDE_SCOPES,
|
||||
PROVIDER,
|
||||
)
|
||||
from noteflow.domain.entities.integration import Integration, IntegrationType
|
||||
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
|
||||
from noteflow.domain.value_objects import OAuthClientConfig, OAuthProvider, OAuthTokens
|
||||
from noteflow.infrastructure.calendar import OAuthManager
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
from noteflow.infrastructure.calendar.oauth import OAuthError
|
||||
@@ -25,6 +33,17 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _is_str_sequence(value: object) -> TypeGuard[Sequence[str]]:
|
||||
"""Check whether a value is a non-string sequence of strings."""
|
||||
if isinstance(value, str):
|
||||
return False
|
||||
if not isinstance(value, Sequence):
|
||||
return False
|
||||
# Cast to satisfy type checker when iterating unknown Sequence contents.
|
||||
sequence = cast(Sequence[object], value)
|
||||
return all(isinstance(item, str) for item in sequence)
|
||||
|
||||
|
||||
class CalendarServiceOAuthMixin:
|
||||
"""Mixin for OAuth flow operations."""
|
||||
|
||||
@@ -36,19 +55,140 @@ class CalendarServiceOAuthMixin:
|
||||
_fetch_calendar_integration: Callable[..., Awaitable[Integration | None]]
|
||||
_fetch_account_email: Callable[[OAuthProvider, str], Awaitable[str]]
|
||||
|
||||
def _resolve_workspace_id(self, workspace_id: UUID | None) -> UUID:
|
||||
"""Resolve workspace ID with fallback for single-user mode."""
|
||||
return workspace_id or self.DEFAULT_WORKSPACE_ID
|
||||
|
||||
def _default_scopes_for_provider(self, provider: OAuthProvider) -> tuple[str, ...]:
|
||||
"""Return default OAuth scopes for provider."""
|
||||
scopes = (
|
||||
self._oauth_manager.GOOGLE_SCOPES
|
||||
if provider == OAuthProvider.GOOGLE
|
||||
else self._oauth_manager.OUTLOOK_SCOPES
|
||||
)
|
||||
return tuple(scopes)
|
||||
|
||||
def _parse_override_scopes(
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
scopes: object,
|
||||
) -> tuple[str, ...]:
|
||||
"""Normalize override scopes, falling back to defaults."""
|
||||
if _is_str_sequence(scopes):
|
||||
normalized = tuple(scopes)
|
||||
if normalized:
|
||||
return normalized
|
||||
return self._default_scopes_for_provider(provider)
|
||||
|
||||
def _extract_override_fields(
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
integration: Integration,
|
||||
) -> tuple[bool, str, str, tuple[str, ...]]:
|
||||
config = integration.config or {}
|
||||
override_enabled = bool(config.get(OAUTH_OVERRIDE_ENABLED))
|
||||
client_id_raw = config.get(OAUTH_OVERRIDE_CLIENT_ID)
|
||||
redirect_raw = config.get(OAUTH_OVERRIDE_REDIRECT_URI)
|
||||
client_id = client_id_raw if isinstance(client_id_raw, str) else ""
|
||||
redirect_uri = (
|
||||
redirect_raw if isinstance(redirect_raw, str) else self._settings.redirect_uri
|
||||
)
|
||||
scopes = self._parse_override_scopes(provider, config.get(OAUTH_OVERRIDE_SCOPES))
|
||||
return override_enabled, client_id, redirect_uri, scopes
|
||||
|
||||
def _build_override_config(
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
integration: Integration,
|
||||
secrets: dict[str, str] | None,
|
||||
) -> OAuthClientConfig | None:
|
||||
override_enabled, client_id, redirect_uri, scopes = self._extract_override_fields(
|
||||
provider, integration
|
||||
)
|
||||
if not override_enabled:
|
||||
return None
|
||||
client_secret = secrets.get(OAUTH_OVERRIDE_CLIENT_SECRET, "") if secrets else ""
|
||||
if not (client_id and client_secret):
|
||||
raise CalendarServiceError(
|
||||
"OAuth override enabled but client credentials are missing"
|
||||
)
|
||||
override_config = OAuthClientConfig(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri or self._settings.redirect_uri,
|
||||
scopes=scopes,
|
||||
)
|
||||
return override_config
|
||||
|
||||
def _build_override_view(
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
integration: Integration,
|
||||
secrets: dict[str, str] | None,
|
||||
) -> tuple[OAuthClientConfig, bool, bool]:
|
||||
override_enabled, client_id, redirect_uri, scopes = self._extract_override_fields(
|
||||
provider, integration
|
||||
)
|
||||
has_secret = bool(secrets.get(OAUTH_OVERRIDE_CLIENT_SECRET)) if secrets else False
|
||||
return (
|
||||
OAuthClientConfig(
|
||||
client_id=client_id,
|
||||
client_secret="",
|
||||
redirect_uri=redirect_uri or self._settings.redirect_uri,
|
||||
scopes=scopes,
|
||||
),
|
||||
override_enabled,
|
||||
has_secret,
|
||||
)
|
||||
|
||||
async def _load_override_config(
|
||||
self,
|
||||
provider: str,
|
||||
oauth_provider: OAuthProvider,
|
||||
workspace_id: UUID | None,
|
||||
) -> OAuthClientConfig | None:
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await self._fetch_calendar_integration(uow, provider, workspace_id)
|
||||
if not integration:
|
||||
return None
|
||||
secrets = await uow.integrations.get_secrets(integration.id)
|
||||
return self._build_override_config(oauth_provider, integration, secrets)
|
||||
|
||||
async def _resolve_override_config_and_redirect(
|
||||
self,
|
||||
provider: str,
|
||||
oauth_provider: OAuthProvider,
|
||||
redirect_uri: str | None,
|
||||
workspace_id: UUID | None,
|
||||
) -> tuple[OAuthClientConfig | None, str]:
|
||||
override_config = await self._load_override_config(
|
||||
provider, oauth_provider, workspace_id
|
||||
)
|
||||
effective_redirect = redirect_uri or self._settings.redirect_uri
|
||||
if override_config and not redirect_uri:
|
||||
effective_redirect = override_config.redirect_uri
|
||||
return override_config, effective_redirect
|
||||
|
||||
async def initiate_oauth(
|
||||
self,
|
||||
provider: str,
|
||||
redirect_uri: str | None = None,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Start OAuth flow for a calendar provider."""
|
||||
oauth_provider = self._parse_calendar_provider(provider)
|
||||
effective_redirect = redirect_uri or self._settings.redirect_uri
|
||||
override_config, effective_redirect = await self._resolve_override_config_and_redirect(
|
||||
provider,
|
||||
oauth_provider,
|
||||
redirect_uri,
|
||||
workspace_id,
|
||||
)
|
||||
|
||||
try:
|
||||
auth_url, state = self._oauth_manager.initiate_auth(
|
||||
provider=oauth_provider,
|
||||
redirect_uri=effective_redirect,
|
||||
client_config=override_config,
|
||||
)
|
||||
logger.info("Initiated OAuth flow for provider=%s", provider)
|
||||
return auth_url, state
|
||||
@@ -60,13 +200,21 @@ class CalendarServiceOAuthMixin:
|
||||
provider: str,
|
||||
code: str,
|
||||
state: str,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> UUID:
|
||||
"""Complete OAuth flow and store tokens."""
|
||||
oauth_provider = self._parse_calendar_provider(provider)
|
||||
override_config = await self._load_override_config(
|
||||
provider,
|
||||
oauth_provider,
|
||||
workspace_id,
|
||||
)
|
||||
|
||||
tokens = await self._exchange_tokens(oauth_provider, code, state)
|
||||
tokens = await self._exchange_tokens(oauth_provider, code, state, override_config)
|
||||
email = await self._fetch_provider_email(oauth_provider, tokens.access_token)
|
||||
integration_id = await self._store_calendar_integration(provider, email, tokens)
|
||||
integration_id = await self._store_calendar_integration(
|
||||
provider, email, tokens, workspace_id
|
||||
)
|
||||
|
||||
logger.info("Completed OAuth for provider=%s, email=%s", provider, email)
|
||||
return integration_id
|
||||
@@ -76,6 +224,7 @@ class CalendarServiceOAuthMixin:
|
||||
oauth_provider: OAuthProvider,
|
||||
code: str,
|
||||
state: str,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange authorization code for tokens."""
|
||||
try:
|
||||
@@ -83,6 +232,7 @@ class CalendarServiceOAuthMixin:
|
||||
provider=oauth_provider,
|
||||
code=code,
|
||||
state=state,
|
||||
client_config=client_config,
|
||||
)
|
||||
except OAuthError as e:
|
||||
raise CalendarServiceError(f"OAuth failed: {e}") from e
|
||||
@@ -103,14 +253,18 @@ class CalendarServiceOAuthMixin:
|
||||
provider: str,
|
||||
email: str,
|
||||
tokens: OAuthTokens,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> UUID:
|
||||
"""Persist calendar integration and encrypted tokens."""
|
||||
effective_workspace_id = self._resolve_workspace_id(workspace_id)
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await self._fetch_calendar_integration(uow, provider)
|
||||
integration = await self._fetch_calendar_integration(
|
||||
uow, provider, workspace_id
|
||||
)
|
||||
|
||||
if integration is None:
|
||||
integration = Integration.create(
|
||||
workspace_id=self.DEFAULT_WORKSPACE_ID,
|
||||
workspace_id=effective_workspace_id,
|
||||
name=f"{provider.title()} Calendar",
|
||||
integration_type=IntegrationType.CALENDAR,
|
||||
config={PROVIDER: provider},
|
||||
@@ -122,9 +276,11 @@ class CalendarServiceOAuthMixin:
|
||||
integration.connect(provider_email=email)
|
||||
await uow.integrations.update(integration)
|
||||
|
||||
existing_secrets = await uow.integrations.get_secrets(integration.id) or {}
|
||||
merged_secrets = {**existing_secrets, **tokens.to_secrets_dict()}
|
||||
await uow.integrations.set_secrets(
|
||||
integration_id=integration.id,
|
||||
secrets=tokens.to_secrets_dict(),
|
||||
secrets=merged_secrets,
|
||||
)
|
||||
await uow.commit()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.domain.entities.integration import Integration, IntegrationStatus, IntegrationType
|
||||
@@ -65,9 +66,11 @@ class CalendarServiceSupportMixin:
|
||||
self,
|
||||
uow: UnitOfWork,
|
||||
provider: str,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> Integration | None:
|
||||
"""Fetch calendar integration for provider, or None if not found."""
|
||||
return await uow.integrations.get_by_provider(
|
||||
provider=provider,
|
||||
integration_type=IntegrationType.CALENDAR.value,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from noteflow.infrastructure.calendar import (
|
||||
|
||||
from ._connection_mixin import CalendarServiceConnectionMixin
|
||||
from ._events_mixin import CalendarServiceEventsMixin
|
||||
from ._oauth_config_mixin import CalendarServiceOAuthConfigMixin
|
||||
from ._oauth_mixin import CalendarServiceOAuthMixin
|
||||
from ._service_mixin import CalendarServiceSupportMixin
|
||||
|
||||
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
|
||||
|
||||
class CalendarService(
|
||||
CalendarServiceOAuthMixin,
|
||||
CalendarServiceOAuthConfigMixin,
|
||||
CalendarServiceConnectionMixin,
|
||||
CalendarServiceEventsMixin,
|
||||
CalendarServiceSupportMixin,
|
||||
|
||||
@@ -108,7 +108,7 @@ def resolve_streaming_config_preference(
|
||||
fallback: StreamingConfig,
|
||||
) -> StreamingConfigResolution | None:
|
||||
"""Resolve a stored streaming config preference into safe runtime values."""
|
||||
parsed = _parse_preference(raw_value)
|
||||
parsed = _parse_streaming_preference(raw_value)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
@@ -120,7 +120,7 @@ def resolve_streaming_config_preference(
|
||||
)
|
||||
|
||||
|
||||
def _parse_preference(raw_value: object) -> StreamingConfigPreference | None:
|
||||
def _parse_streaming_preference(raw_value: object) -> StreamingConfigPreference | None:
|
||||
if not isinstance(raw_value, dict):
|
||||
return None
|
||||
|
||||
|
||||
@@ -6,7 +6,10 @@ from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.config.constants.errors import ERROR_WORKSPACE_SCOPE_MISMATCH
|
||||
from noteflow.config.constants.errors import (
|
||||
ERROR_WORKSPACE_ADMIN_REQUIRED,
|
||||
ERROR_WORKSPACE_SCOPE_MISMATCH,
|
||||
)
|
||||
from noteflow.domain.entities import SummarizationTemplate, SummarizationTemplateVersion
|
||||
from noteflow.domain.identity.context import OperationContext
|
||||
from noteflow.domain.utils.time import utc_now
|
||||
@@ -50,8 +53,7 @@ class SummarizationTemplateService:
|
||||
@staticmethod
|
||||
def _require_admin(context: OperationContext) -> None:
|
||||
if not context.is_admin():
|
||||
msg = "Workspace admin role required"
|
||||
raise PermissionError(msg)
|
||||
raise PermissionError(ERROR_WORKSPACE_ADMIN_REQUIRED)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_workspace_scope(context: OperationContext, workspace_id: UUID) -> None:
|
||||
|
||||
@@ -24,6 +24,9 @@ ERR_API_PREFIX: Final[str] = "API error: "
|
||||
ERR_TOKEN_REFRESH_PREFIX: Final[str] = "Token refresh failed: "
|
||||
"""Prefix for token refresh error messages."""
|
||||
|
||||
ERROR_WORKSPACE_ADMIN_REQUIRED: Final[str] = "Workspace admin role required"
|
||||
"""Error message when workspace admin access is required."""
|
||||
|
||||
# =============================================================================
|
||||
# Validation Message Fragments
|
||||
# =============================================================================
|
||||
|
||||
@@ -17,6 +17,7 @@ class FeatureFlags(BaseSettings):
|
||||
NOTEFLOW_FEATURE_NER_ENABLED: Enable named entity recognition (default: False)
|
||||
NOTEFLOW_FEATURE_CALENDAR_ENABLED: Enable calendar integration (default: False)
|
||||
NOTEFLOW_FEATURE_WEBHOOKS_ENABLED: Enable webhook notifications (default: True)
|
||||
NOTEFLOW_FEATURE_ROCM_ENABLED: Enable ROCm GPU support (default: True)
|
||||
"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
@@ -46,3 +47,10 @@ class FeatureFlags(BaseSettings):
|
||||
bool,
|
||||
Field(default=True, description="Enable webhook notifications"),
|
||||
]
|
||||
rocm_enabled: Annotated[
|
||||
bool,
|
||||
Field(
|
||||
default=True,
|
||||
description="Enable ROCm GPU support for AMD GPUs (requires PyTorch ROCm)",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -66,6 +66,11 @@ USER_PREFERENCES: Final[str] = "user_preferences"
|
||||
DIARIZATION_JOBS: Final[str] = "diarization_jobs"
|
||||
MEETING_TAGS: Final[str] = "meeting_tags"
|
||||
SORT_DESC: Final[str] = "sort_desc"
|
||||
OAUTH_OVERRIDE_ENABLED: Final[str] = "oauth_override_enabled"
|
||||
OAUTH_OVERRIDE_CLIENT_ID: Final[str] = "oauth_override_client_id"
|
||||
OAUTH_OVERRIDE_CLIENT_SECRET: Final[str] = "oauth_override_client_secret"
|
||||
OAUTH_OVERRIDE_REDIRECT_URI: Final[str] = "oauth_override_redirect_uri"
|
||||
OAUTH_OVERRIDE_SCOPES: Final[str] = "oauth_override_scopes"
|
||||
|
||||
# Observability metrics fields
|
||||
TOKENS_INPUT: Final[str] = "tokens_input"
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Protocol
|
||||
from noteflow.config.constants.core import HOURS_PER_DAY
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
|
||||
from noteflow.domain.value_objects import OAuthClientConfig, OAuthProvider, OAuthTokens
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -59,12 +59,14 @@ class OAuthPort(Protocol):
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
redirect_uri: str,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Generate OAuth authorization URL with PKCE.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider (google or outlook).
|
||||
redirect_uri: Callback URL after authorization.
|
||||
client_config: Optional client configuration override.
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state_token).
|
||||
@@ -76,6 +78,7 @@ class OAuthPort(Protocol):
|
||||
provider: OAuthProvider,
|
||||
code: str,
|
||||
state: str,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange authorization code for tokens.
|
||||
|
||||
@@ -83,6 +86,7 @@ class OAuthPort(Protocol):
|
||||
provider: OAuth provider.
|
||||
code: Authorization code from callback.
|
||||
state: State parameter from callback.
|
||||
client_config: Optional client configuration override.
|
||||
|
||||
Returns:
|
||||
OAuth tokens.
|
||||
@@ -96,12 +100,14 @@ class OAuthPort(Protocol):
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
refresh_token: str,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Refresh expired access token.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider.
|
||||
refresh_token: Refresh token from previous exchange.
|
||||
client_config: Optional client configuration override.
|
||||
|
||||
Returns:
|
||||
New OAuth tokens.
|
||||
|
||||
70
src/noteflow/domain/ports/gpu.py
Normal file
70
src/noteflow/domain/ports/gpu.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""GPU backend types and detection protocol."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class GpuBackend(str, Enum):
|
||||
"""Detected GPU backend type.
|
||||
|
||||
Used to identify which GPU runtime is available on the system.
|
||||
ROCm appears as CUDA at the PyTorch level but uses HIP internally.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
CUDA = "cuda"
|
||||
ROCM = "rocm"
|
||||
MPS = "mps"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GpuInfo:
|
||||
"""Information about detected GPU.
|
||||
|
||||
Attributes:
|
||||
backend: The GPU backend type (CUDA, ROCm, MPS, or NONE).
|
||||
device_name: Human-readable GPU name (e.g., "NVIDIA GeForce RTX 4090").
|
||||
vram_total_mb: Total VRAM in megabytes (0 for MPS which doesn't expose this).
|
||||
driver_version: Driver version string (CUDA version or HIP version).
|
||||
architecture: GPU architecture identifier (e.g., "sm_89" for CUDA, "gfx1100" for ROCm).
|
||||
"""
|
||||
|
||||
backend: GpuBackend
|
||||
device_name: str
|
||||
vram_total_mb: int
|
||||
driver_version: str
|
||||
architecture: str | None = None
|
||||
|
||||
|
||||
class GpuDetectionProtocol(Protocol):
|
||||
"""Protocol for GPU detection implementations.
|
||||
|
||||
Allows for testing and alternative detection strategies.
|
||||
"""
|
||||
|
||||
def detect_backend(self) -> GpuBackend:
|
||||
"""Detect the available GPU backend.
|
||||
|
||||
Returns:
|
||||
GpuBackend enum indicating the detected backend.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_info(self) -> GpuInfo | None:
|
||||
"""Get detailed GPU information.
|
||||
|
||||
Returns:
|
||||
GpuInfo if a GPU is available, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
def is_supported_for_asr(self) -> bool:
|
||||
"""Check if GPU is supported for ASR workloads.
|
||||
|
||||
Returns:
|
||||
True if the GPU can run ASR models, False otherwise.
|
||||
"""
|
||||
...
|
||||
@@ -32,12 +32,14 @@ class IntegrationRepository(Protocol):
|
||||
self,
|
||||
provider: str,
|
||||
integration_type: str | None = None,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> Integration | None:
|
||||
"""Retrieve an integration by provider name.
|
||||
|
||||
Args:
|
||||
provider: Provider name (e.g., 'google', 'outlook').
|
||||
integration_type: Optional type filter.
|
||||
workspace_id: Optional workspace filter.
|
||||
|
||||
Returns:
|
||||
Integration if found, None otherwise.
|
||||
|
||||
@@ -93,7 +93,11 @@ class MeetingState(IntEnum):
|
||||
"""
|
||||
valid_transitions: dict[MeetingState, set[MeetingState]] = {
|
||||
MeetingState.UNSPECIFIED: {MeetingState.CREATED},
|
||||
MeetingState.CREATED: {MeetingState.RECORDING, MeetingState.STOPPED, MeetingState.ERROR},
|
||||
MeetingState.CREATED: {
|
||||
MeetingState.RECORDING,
|
||||
MeetingState.STOPPED,
|
||||
MeetingState.ERROR,
|
||||
},
|
||||
MeetingState.RECORDING: {MeetingState.STOPPING, MeetingState.ERROR},
|
||||
MeetingState.STOPPING: {MeetingState.STOPPED, MeetingState.ERROR},
|
||||
MeetingState.STOPPED: {MeetingState.COMPLETED, MeetingState.ERROR},
|
||||
@@ -153,6 +157,15 @@ class OAuthState:
|
||||
return datetime.now(self.created_at.tzinfo) > self.expires_at
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OAuthClientConfig:
|
||||
"""OAuth client configuration overrides."""
|
||||
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
scopes: tuple[str, ...] = ()
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OAuthTokens:
|
||||
"""OAuth tokens returned from provider.
|
||||
|
||||
@@ -5,6 +5,7 @@ from .annotation import AnnotationMixin
|
||||
from .asr_config import AsrConfigMixin
|
||||
from .streaming_config import StreamingConfigMixin
|
||||
from .calendar import CalendarMixin
|
||||
from .calendar_oauth_config import CalendarOAuthConfigMixin
|
||||
from .diarization import DiarizationMixin
|
||||
from .diarization_job import DiarizationJobMixin
|
||||
from .entities import EntitiesMixin
|
||||
@@ -30,6 +31,7 @@ __all__ = [
|
||||
"AsrConfigMixin",
|
||||
"StreamingConfigMixin",
|
||||
"CalendarMixin",
|
||||
"CalendarOAuthConfigMixin",
|
||||
"DiarizationJobMixin",
|
||||
"DiarizationMixin",
|
||||
"EntitiesMixin",
|
||||
|
||||
@@ -14,13 +14,14 @@ from noteflow.application.services.asr_config import (
|
||||
AsrConfigService,
|
||||
AsrDevice,
|
||||
)
|
||||
from noteflow.domain.constants.fields import DEVICE
|
||||
from noteflow.domain.constants.fields import (
|
||||
DEVICE,
|
||||
JOB_STATUS_COMPLETED,
|
||||
JOB_STATUS_FAILED,
|
||||
JOB_STATUS_QUEUED,
|
||||
JOB_STATUS_RUNNING,
|
||||
)
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
from ..proto import noteflow_pb2
|
||||
@@ -39,11 +40,13 @@ logger = get_logger(__name__)
|
||||
_DEVICE_TO_PROTO: dict[AsrDevice, int] = {
|
||||
AsrDevice.CPU: noteflow_pb2.ASR_DEVICE_CPU,
|
||||
AsrDevice.CUDA: noteflow_pb2.ASR_DEVICE_CUDA,
|
||||
AsrDevice.ROCM: noteflow_pb2.ASR_DEVICE_ROCM,
|
||||
}
|
||||
|
||||
_PROTO_TO_DEVICE: dict[int, AsrDevice] = {
|
||||
noteflow_pb2.ASR_DEVICE_CPU: AsrDevice.CPU,
|
||||
noteflow_pb2.ASR_DEVICE_CUDA: AsrDevice.CUDA,
|
||||
noteflow_pb2.ASR_DEVICE_ROCM: AsrDevice.ROCM,
|
||||
}
|
||||
|
||||
_COMPUTE_TYPE_TO_PROTO: dict[AsrComputeType, int] = {
|
||||
@@ -102,6 +105,8 @@ def _build_configuration_proto(
|
||||
cuda_available=caps.cuda_available,
|
||||
available_model_sizes=list(caps.available_model_sizes),
|
||||
available_compute_types=available_compute_types,
|
||||
rocm_available=caps.rocm_available,
|
||||
gpu_backend=caps.gpu_backend,
|
||||
)
|
||||
|
||||
|
||||
@@ -188,6 +193,8 @@ class AsrConfigMixin:
|
||||
cuda_available=False,
|
||||
available_model_sizes=[],
|
||||
available_compute_types=[],
|
||||
rocm_available=False,
|
||||
gpu_backend=GpuBackend.NONE.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -12,15 +12,17 @@ from noteflow.domain.value_objects import OAuthProvider
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
from ..proto import noteflow_pb2
|
||||
from .errors import abort_internal, abort_invalid_argument, abort_unavailable
|
||||
from .errors import (
|
||||
abort_internal,
|
||||
abort_invalid_argument,
|
||||
abort_unavailable,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_ERR_CALENDAR_NOT_ENABLED = "Calendar integration not enabled"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.domain.ports.calendar import OAuthConnectionInfo
|
||||
|
||||
from ._types import GrpcContext
|
||||
from .protocols import ServicerHost
|
||||
|
||||
@@ -62,7 +64,7 @@ def _build_oauth_connection(
|
||||
)
|
||||
|
||||
|
||||
async def _require_calendar_service(
|
||||
async def require_calendar_service(
|
||||
host: ServicerHost,
|
||||
context: GrpcContext,
|
||||
operation: str,
|
||||
@@ -91,7 +93,7 @@ class CalendarMixin:
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.ListCalendarEventsResponse:
|
||||
"""List upcoming calendar events from connected providers."""
|
||||
service = await _require_calendar_service(self, context, "calendar_list_events")
|
||||
service = await require_calendar_service(self, context, "calendar_list_events")
|
||||
|
||||
provider = request.provider or None
|
||||
hours_ahead = request.hours_ahead if request.hours_ahead > 0 else None
|
||||
@@ -103,12 +105,14 @@ class CalendarMixin:
|
||||
hours_ahead=hours_ahead,
|
||||
limit=limit,
|
||||
)
|
||||
workspace_id = self.get_operation_context(context).workspace_id
|
||||
|
||||
try:
|
||||
events = await service.list_calendar_events(
|
||||
provider=provider,
|
||||
hours_ahead=hours_ahead,
|
||||
limit=limit,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
except CalendarServiceError as e:
|
||||
logger.error("calendar_list_events_failed", error=str(e), provider=provider)
|
||||
@@ -134,7 +138,7 @@ class CalendarMixin:
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.GetCalendarProvidersResponse:
|
||||
"""Get available calendar providers with authentication status."""
|
||||
service = await _require_calendar_service(self, context, "calendar_providers")
|
||||
service = await require_calendar_service(self, context, "calendar_providers")
|
||||
|
||||
logger.debug("calendar_get_providers_request")
|
||||
|
||||
@@ -143,7 +147,10 @@ class CalendarMixin:
|
||||
(OAuthProvider.GOOGLE.value, "Google Calendar"),
|
||||
(OAuthProvider.OUTLOOK.value, "Microsoft Outlook"),
|
||||
]:
|
||||
status: OAuthConnectionInfo = await service.get_connection_status(provider_name)
|
||||
status: OAuthConnectionInfo = await service.get_connection_status(
|
||||
provider_name,
|
||||
workspace_id=self.get_operation_context(context).workspace_id,
|
||||
)
|
||||
is_authenticated = status.status == IntegrationStatus.CONNECTED.value
|
||||
providers.append(
|
||||
noteflow_pb2.CalendarProvider(
|
||||
@@ -159,8 +166,9 @@ class CalendarMixin:
|
||||
status=status.status,
|
||||
)
|
||||
|
||||
authenticated_count = sum(bool(p.is_authenticated)
|
||||
for p in providers)
|
||||
authenticated_count = sum(
|
||||
bool(provider.is_authenticated) for provider in providers
|
||||
)
|
||||
logger.info(
|
||||
"calendar_get_providers_success",
|
||||
total_providers=len(providers),
|
||||
@@ -175,18 +183,20 @@ class CalendarMixin:
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.InitiateOAuthResponse:
|
||||
"""Start OAuth flow for a calendar provider."""
|
||||
service = await _require_calendar_service(self, context, "oauth_initiate")
|
||||
service = await require_calendar_service(self, context, "oauth_initiate")
|
||||
|
||||
logger.debug(
|
||||
"oauth_initiate_request",
|
||||
provider=request.provider,
|
||||
has_redirect_uri=bool(request.redirect_uri),
|
||||
)
|
||||
workspace_id = self.get_operation_context(context).workspace_id
|
||||
|
||||
try:
|
||||
auth_url, state = await service.initiate_oauth(
|
||||
provider=request.provider,
|
||||
redirect_uri=request.redirect_uri or None,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
except CalendarServiceError as e:
|
||||
logger.error(
|
||||
@@ -214,19 +224,21 @@ class CalendarMixin:
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.CompleteOAuthResponse:
|
||||
"""Complete OAuth flow with authorization code."""
|
||||
service = await _require_calendar_service(self, context, "oauth_complete")
|
||||
service = await require_calendar_service(self, context, "oauth_complete")
|
||||
|
||||
logger.debug(
|
||||
"oauth_complete_request",
|
||||
provider=request.provider,
|
||||
state=request.state,
|
||||
)
|
||||
workspace_id = self.get_operation_context(context).workspace_id
|
||||
|
||||
try:
|
||||
integration_id = await service.complete_oauth(
|
||||
provider=request.provider,
|
||||
code=request.code,
|
||||
state=request.state,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
except CalendarServiceError as e:
|
||||
logger.warning(
|
||||
@@ -239,7 +251,7 @@ class CalendarMixin:
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
status = await service.get_connection_status(request.provider)
|
||||
status = await service.get_connection_status(request.provider, workspace_id=workspace_id)
|
||||
|
||||
logger.info(
|
||||
"oauth_complete_success",
|
||||
@@ -260,7 +272,7 @@ class CalendarMixin:
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.GetOAuthConnectionStatusResponse:
|
||||
"""Get OAuth connection status for a provider."""
|
||||
service = await _require_calendar_service(self, context, "oauth_status")
|
||||
service = await require_calendar_service(self, context, "oauth_status")
|
||||
|
||||
logger.debug(
|
||||
"oauth_status_request",
|
||||
@@ -268,7 +280,10 @@ class CalendarMixin:
|
||||
integration_type=request.integration_type or CALENDAR,
|
||||
)
|
||||
|
||||
info = await service.get_connection_status(request.provider)
|
||||
info = await service.get_connection_status(
|
||||
request.provider,
|
||||
workspace_id=self.get_operation_context(context).workspace_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"oauth_status_retrieved",
|
||||
@@ -288,11 +303,14 @@ class CalendarMixin:
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.DisconnectOAuthResponse:
|
||||
"""Disconnect OAuth integration and revoke tokens."""
|
||||
service = await _require_calendar_service(self, context, "oauth_disconnect")
|
||||
service = await require_calendar_service(self, context, "oauth_disconnect")
|
||||
|
||||
logger.debug("oauth_disconnect_request", provider=request.provider)
|
||||
|
||||
success = await service.disconnect(request.provider)
|
||||
success = await service.disconnect(
|
||||
request.provider,
|
||||
workspace_id=self.get_operation_context(context).workspace_id,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("oauth_disconnect_success", provider=request.provider)
|
||||
|
||||
134
src/noteflow/grpc/mixins/calendar_oauth_config.py
Normal file
134
src/noteflow/grpc/mixins/calendar_oauth_config.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""OAuth client config mixin for calendar integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.application.services.calendar import CalendarServiceError
|
||||
from noteflow.config.constants.errors import ERROR_WORKSPACE_ADMIN_REQUIRED
|
||||
from noteflow.domain.constants.fields import ENTITY_WORKSPACE
|
||||
from noteflow.domain.value_objects import OAuthClientConfig
|
||||
|
||||
from ..proto import noteflow_pb2
|
||||
from .calendar import require_calendar_service
|
||||
from .errors import (
|
||||
abort_invalid_argument,
|
||||
abort_not_found,
|
||||
abort_permission_denied,
|
||||
parse_workspace_id,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._types import GrpcContext
|
||||
from .protocols import ServicerHost
|
||||
|
||||
def _build_oauth_client_config(
|
||||
config: noteflow_pb2.OAuthClientConfig,
|
||||
) -> OAuthClientConfig:
|
||||
client_secret = config.client_secret if config.HasField("client_secret") else ""
|
||||
return OAuthClientConfig(
|
||||
client_id=config.client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=config.redirect_uri,
|
||||
scopes=tuple(config.scopes),
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_workspace_id(
|
||||
host: ServicerHost,
|
||||
context: GrpcContext,
|
||||
request_workspace_id: str,
|
||||
) -> UUID:
|
||||
if request_workspace_id:
|
||||
return await parse_workspace_id(request_workspace_id, context)
|
||||
return host.get_operation_context(context).workspace_id
|
||||
|
||||
|
||||
async def _require_admin_access(
|
||||
host: ServicerHost,
|
||||
context: GrpcContext,
|
||||
workspace_id: UUID,
|
||||
) -> None:
|
||||
async with host.create_repository_provider() as uow:
|
||||
if not uow.supports_workspaces:
|
||||
return
|
||||
|
||||
user_ctx = await host.identity_service.get_or_create_default_user(uow)
|
||||
workspace = await uow.workspaces.get(workspace_id)
|
||||
if not workspace:
|
||||
await abort_not_found(context, ENTITY_WORKSPACE, str(workspace_id))
|
||||
raise AssertionError("unreachable") from None
|
||||
|
||||
membership = await uow.workspaces.get_membership(workspace_id, user_ctx.user_id)
|
||||
if not membership:
|
||||
await abort_not_found(context, "Workspace membership", str(workspace_id))
|
||||
raise AssertionError("unreachable") from None
|
||||
|
||||
if not membership.role.can_admin():
|
||||
await abort_permission_denied(context, ERROR_WORKSPACE_ADMIN_REQUIRED)
|
||||
raise AssertionError("unreachable") from None
|
||||
|
||||
|
||||
class CalendarOAuthConfigMixin:
|
||||
"""Mixin providing OAuth client config endpoints."""
|
||||
|
||||
async def GetOAuthClientConfig(
|
||||
self: ServicerHost,
|
||||
request: noteflow_pb2.GetOAuthClientConfigRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.GetOAuthClientConfigResponse:
|
||||
"""Get OAuth override config for a calendar provider."""
|
||||
service = await require_calendar_service(self, context, "oauth_client_config_get")
|
||||
workspace_id = await _resolve_workspace_id(self, context, request.workspace_id)
|
||||
|
||||
try:
|
||||
config, override_enabled, has_secret = await service.get_oauth_client_config(
|
||||
request.provider,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
except CalendarServiceError as e:
|
||||
await abort_invalid_argument(context, str(e))
|
||||
raise AssertionError("unreachable") from None
|
||||
|
||||
return noteflow_pb2.GetOAuthClientConfigResponse(
|
||||
config=noteflow_pb2.OAuthClientConfig(
|
||||
client_id=config.client_id,
|
||||
redirect_uri=config.redirect_uri,
|
||||
scopes=list(config.scopes),
|
||||
override_enabled=override_enabled,
|
||||
has_client_secret=has_secret,
|
||||
)
|
||||
)
|
||||
|
||||
async def SetOAuthClientConfig(
|
||||
self: ServicerHost,
|
||||
request: noteflow_pb2.SetOAuthClientConfigRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.SetOAuthClientConfigResponse:
|
||||
"""Set OAuth override config for a calendar provider."""
|
||||
service = await require_calendar_service(self, context, "oauth_client_config_set")
|
||||
workspace_id = await _resolve_workspace_id(self, context, request.workspace_id)
|
||||
await _require_admin_access(self, context, workspace_id)
|
||||
|
||||
if not request.provider:
|
||||
await abort_invalid_argument(context, "Provider is required")
|
||||
raise AssertionError("unreachable") from None
|
||||
if not request.HasField("config"):
|
||||
await abort_invalid_argument(context, "OAuth config is required")
|
||||
raise AssertionError("unreachable") from None
|
||||
|
||||
client_config = _build_oauth_client_config(request.config)
|
||||
|
||||
try:
|
||||
await service.set_oauth_client_config(
|
||||
provider=request.provider,
|
||||
client_config=client_config,
|
||||
override_enabled=request.config.override_enabled,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
except CalendarServiceError as e:
|
||||
await abort_invalid_argument(context, str(e))
|
||||
raise AssertionError("unreachable") from None
|
||||
|
||||
return noteflow_pb2.SetOAuthClientConfigResponse(success=True)
|
||||
@@ -5,7 +5,10 @@ from __future__ import annotations
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from noteflow.config.constants.errors import ERROR_WORKSPACE_ID_REQUIRED
|
||||
from noteflow.config.constants.errors import (
|
||||
ERROR_WORKSPACE_ADMIN_REQUIRED,
|
||||
ERROR_WORKSPACE_ID_REQUIRED,
|
||||
)
|
||||
from noteflow.domain.constants.fields import ENTITY_WORKSPACE
|
||||
from noteflow.domain.entities.integration import IntegrationType
|
||||
from noteflow.domain.ports.unit_of_work import UnitOfWork
|
||||
@@ -274,7 +277,7 @@ class IdentityMixin:
|
||||
)
|
||||
|
||||
if not membership.role.can_admin():
|
||||
await abort_permission_denied(context, "Workspace admin role required")
|
||||
await abort_permission_denied(context, ERROR_WORKSPACE_ADMIN_REQUIRED)
|
||||
raise # Unreachable but helps type checker
|
||||
|
||||
updates = proto_to_workspace_settings(request.settings)
|
||||
|
||||
@@ -75,6 +75,8 @@ service NoteFlowService {
|
||||
rpc CompleteOAuth(CompleteOAuthRequest) returns (CompleteOAuthResponse);
|
||||
rpc GetOAuthConnectionStatus(GetOAuthConnectionStatusRequest) returns (GetOAuthConnectionStatusResponse);
|
||||
rpc DisconnectOAuth(DisconnectOAuthRequest) returns (DisconnectOAuthResponse);
|
||||
rpc GetOAuthClientConfig(GetOAuthClientConfigRequest) returns (GetOAuthClientConfigResponse);
|
||||
rpc SetOAuthClientConfig(SetOAuthClientConfigRequest) returns (SetOAuthClientConfigResponse);
|
||||
|
||||
// Webhook management (Sprint 6)
|
||||
rpc RegisterWebhook(RegisterWebhookRequest) returns (WebhookConfigProto);
|
||||
@@ -649,6 +651,7 @@ enum AsrDevice {
|
||||
ASR_DEVICE_UNSPECIFIED = 0;
|
||||
ASR_DEVICE_CPU = 1;
|
||||
ASR_DEVICE_CUDA = 2;
|
||||
ASR_DEVICE_ROCM = 3;
|
||||
}
|
||||
|
||||
// Valid ASR compute types
|
||||
@@ -681,6 +684,12 @@ message AsrConfiguration {
|
||||
|
||||
// Available compute types for current device
|
||||
repeated AsrComputeType available_compute_types = 7;
|
||||
|
||||
// Whether ROCm is available on this server
|
||||
bool rocm_available = 8;
|
||||
|
||||
// Current GPU backend (none, cuda, rocm, mps)
|
||||
string gpu_backend = 9;
|
||||
}
|
||||
|
||||
message GetAsrConfigurationRequest {}
|
||||
@@ -1319,6 +1328,60 @@ message DisconnectOAuthResponse {
|
||||
string error_message = 2;
|
||||
}
|
||||
|
||||
// OAuth client override configuration
|
||||
message OAuthClientConfig {
|
||||
// OAuth client ID
|
||||
string client_id = 1;
|
||||
|
||||
// Optional client secret (request only)
|
||||
optional string client_secret = 2;
|
||||
|
||||
// Redirect URI for OAuth callback
|
||||
string redirect_uri = 3;
|
||||
|
||||
// OAuth scopes to request
|
||||
repeated string scopes = 4;
|
||||
|
||||
// Whether override should be used
|
||||
bool override_enabled = 5;
|
||||
|
||||
// Whether a client secret is stored (response only)
|
||||
bool has_client_secret = 6;
|
||||
}
|
||||
|
||||
message GetOAuthClientConfigRequest {
|
||||
// Provider to configure: google, outlook
|
||||
string provider = 1;
|
||||
|
||||
// Optional integration type
|
||||
string integration_type = 2;
|
||||
|
||||
// Optional workspace ID override
|
||||
string workspace_id = 3;
|
||||
}
|
||||
|
||||
message GetOAuthClientConfigResponse {
|
||||
OAuthClientConfig config = 1;
|
||||
}
|
||||
|
||||
message SetOAuthClientConfigRequest {
|
||||
// Provider to configure: google, outlook
|
||||
string provider = 1;
|
||||
|
||||
// Optional integration type
|
||||
string integration_type = 2;
|
||||
|
||||
// Optional workspace ID override
|
||||
string workspace_id = 3;
|
||||
|
||||
// OAuth client configuration
|
||||
OAuthClientConfig config = 4;
|
||||
}
|
||||
|
||||
message SetOAuthClientConfigResponse {
|
||||
bool success = 1;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Webhook Management Messages (Sprint 6)
|
||||
// =============================================================================
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -42,6 +42,7 @@ class AsrDevice(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
ASR_DEVICE_UNSPECIFIED: _ClassVar[AsrDevice]
|
||||
ASR_DEVICE_CPU: _ClassVar[AsrDevice]
|
||||
ASR_DEVICE_CUDA: _ClassVar[AsrDevice]
|
||||
ASR_DEVICE_ROCM: _ClassVar[AsrDevice]
|
||||
|
||||
class AsrComputeType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
@@ -110,6 +111,7 @@ PRIORITY_HIGH: Priority
|
||||
ASR_DEVICE_UNSPECIFIED: AsrDevice
|
||||
ASR_DEVICE_CPU: AsrDevice
|
||||
ASR_DEVICE_CUDA: AsrDevice
|
||||
ASR_DEVICE_ROCM: AsrDevice
|
||||
ASR_COMPUTE_TYPE_UNSPECIFIED: AsrComputeType
|
||||
ASR_COMPUTE_TYPE_INT8: AsrComputeType
|
||||
ASR_COMPUTE_TYPE_FLOAT16: AsrComputeType
|
||||
@@ -577,7 +579,7 @@ class ServerInfo(_message.Message):
|
||||
def __init__(self, version: _Optional[str] = ..., asr_model: _Optional[str] = ..., asr_ready: bool = ..., supported_sample_rates: _Optional[_Iterable[int]] = ..., max_chunk_size: _Optional[int] = ..., uptime_seconds: _Optional[float] = ..., active_meetings: _Optional[int] = ..., diarization_enabled: bool = ..., diarization_ready: bool = ..., state_version: _Optional[int] = ..., system_ram_total_bytes: _Optional[int] = ..., system_ram_available_bytes: _Optional[int] = ..., gpu_vram_total_bytes: _Optional[int] = ..., gpu_vram_available_bytes: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class AsrConfiguration(_message.Message):
|
||||
__slots__ = ("model_size", "device", "compute_type", "is_ready", "cuda_available", "available_model_sizes", "available_compute_types")
|
||||
__slots__ = ("model_size", "device", "compute_type", "is_ready", "cuda_available", "available_model_sizes", "available_compute_types", "rocm_available", "gpu_backend")
|
||||
MODEL_SIZE_FIELD_NUMBER: _ClassVar[int]
|
||||
DEVICE_FIELD_NUMBER: _ClassVar[int]
|
||||
COMPUTE_TYPE_FIELD_NUMBER: _ClassVar[int]
|
||||
@@ -585,6 +587,8 @@ class AsrConfiguration(_message.Message):
|
||||
CUDA_AVAILABLE_FIELD_NUMBER: _ClassVar[int]
|
||||
AVAILABLE_MODEL_SIZES_FIELD_NUMBER: _ClassVar[int]
|
||||
AVAILABLE_COMPUTE_TYPES_FIELD_NUMBER: _ClassVar[int]
|
||||
ROCM_AVAILABLE_FIELD_NUMBER: _ClassVar[int]
|
||||
GPU_BACKEND_FIELD_NUMBER: _ClassVar[int]
|
||||
model_size: str
|
||||
device: AsrDevice
|
||||
compute_type: AsrComputeType
|
||||
@@ -592,7 +596,9 @@ class AsrConfiguration(_message.Message):
|
||||
cuda_available: bool
|
||||
available_model_sizes: _containers.RepeatedScalarFieldContainer[str]
|
||||
available_compute_types: _containers.RepeatedScalarFieldContainer[AsrComputeType]
|
||||
def __init__(self, model_size: _Optional[str] = ..., device: _Optional[_Union[AsrDevice, str]] = ..., compute_type: _Optional[_Union[AsrComputeType, str]] = ..., is_ready: bool = ..., cuda_available: bool = ..., available_model_sizes: _Optional[_Iterable[str]] = ..., available_compute_types: _Optional[_Iterable[_Union[AsrComputeType, str]]] = ...) -> None: ...
|
||||
rocm_available: bool
|
||||
gpu_backend: str
|
||||
def __init__(self, model_size: _Optional[str] = ..., device: _Optional[_Union[AsrDevice, str]] = ..., compute_type: _Optional[_Union[AsrComputeType, str]] = ..., is_ready: bool = ..., cuda_available: bool = ..., available_model_sizes: _Optional[_Iterable[str]] = ..., available_compute_types: _Optional[_Iterable[_Union[AsrComputeType, str]]] = ..., rocm_available: bool = ..., gpu_backend: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class GetAsrConfigurationRequest(_message.Message):
|
||||
__slots__ = ()
|
||||
@@ -1124,6 +1130,56 @@ class DisconnectOAuthResponse(_message.Message):
|
||||
error_message: str
|
||||
def __init__(self, success: bool = ..., error_message: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class OAuthClientConfig(_message.Message):
|
||||
__slots__ = ("client_id", "client_secret", "redirect_uri", "scopes", "override_enabled", "has_client_secret")
|
||||
CLIENT_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
CLIENT_SECRET_FIELD_NUMBER: _ClassVar[int]
|
||||
REDIRECT_URI_FIELD_NUMBER: _ClassVar[int]
|
||||
SCOPES_FIELD_NUMBER: _ClassVar[int]
|
||||
OVERRIDE_ENABLED_FIELD_NUMBER: _ClassVar[int]
|
||||
HAS_CLIENT_SECRET_FIELD_NUMBER: _ClassVar[int]
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
scopes: _containers.RepeatedScalarFieldContainer[str]
|
||||
override_enabled: bool
|
||||
has_client_secret: bool
|
||||
def __init__(self, client_id: _Optional[str] = ..., client_secret: _Optional[str] = ..., redirect_uri: _Optional[str] = ..., scopes: _Optional[_Iterable[str]] = ..., override_enabled: bool = ..., has_client_secret: bool = ...) -> None: ...
|
||||
|
||||
class GetOAuthClientConfigRequest(_message.Message):
|
||||
__slots__ = ("provider", "integration_type", "workspace_id")
|
||||
PROVIDER_FIELD_NUMBER: _ClassVar[int]
|
||||
INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
|
||||
WORKSPACE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
provider: str
|
||||
integration_type: str
|
||||
workspace_id: str
|
||||
def __init__(self, provider: _Optional[str] = ..., integration_type: _Optional[str] = ..., workspace_id: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class GetOAuthClientConfigResponse(_message.Message):
|
||||
__slots__ = ("config",)
|
||||
CONFIG_FIELD_NUMBER: _ClassVar[int]
|
||||
config: OAuthClientConfig
|
||||
def __init__(self, config: _Optional[_Union[OAuthClientConfig, _Mapping]] = ...) -> None: ...
|
||||
|
||||
class SetOAuthClientConfigRequest(_message.Message):
|
||||
__slots__ = ("provider", "integration_type", "workspace_id", "config")
|
||||
PROVIDER_FIELD_NUMBER: _ClassVar[int]
|
||||
INTEGRATION_TYPE_FIELD_NUMBER: _ClassVar[int]
|
||||
WORKSPACE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
CONFIG_FIELD_NUMBER: _ClassVar[int]
|
||||
provider: str
|
||||
integration_type: str
|
||||
workspace_id: str
|
||||
config: OAuthClientConfig
|
||||
def __init__(self, provider: _Optional[str] = ..., integration_type: _Optional[str] = ..., workspace_id: _Optional[str] = ..., config: _Optional[_Union[OAuthClientConfig, _Mapping]] = ...) -> None: ...
|
||||
|
||||
class SetOAuthClientConfigResponse(_message.Message):
|
||||
__slots__ = ("success",)
|
||||
SUCCESS_FIELD_NUMBER: _ClassVar[int]
|
||||
success: bool
|
||||
def __init__(self, success: bool = ...) -> None: ...
|
||||
|
||||
class RegisterWebhookRequest(_message.Message):
|
||||
__slots__ = ("workspace_id", "url", "events", "name", "secret", "timeout_ms", "max_retries")
|
||||
WORKSPACE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from . import noteflow_pb2 as noteflow__pb2
|
||||
import noteflow_pb2 as noteflow__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.76.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.76.0'
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
@@ -18,7 +17,8 @@ except ImportError:
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION}, but the generated code in noteflow_pb2_grpc.py depends on'
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ ' but the generated code in noteflow_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
@@ -238,6 +238,16 @@ class NoteFlowServiceStub(object):
|
||||
request_serializer=noteflow__pb2.DisconnectOAuthRequest.SerializeToString,
|
||||
response_deserializer=noteflow__pb2.DisconnectOAuthResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.GetOAuthClientConfig = channel.unary_unary(
|
||||
'/noteflow.NoteFlowService/GetOAuthClientConfig',
|
||||
request_serializer=noteflow__pb2.GetOAuthClientConfigRequest.SerializeToString,
|
||||
response_deserializer=noteflow__pb2.GetOAuthClientConfigResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.SetOAuthClientConfig = channel.unary_unary(
|
||||
'/noteflow.NoteFlowService/SetOAuthClientConfig',
|
||||
request_serializer=noteflow__pb2.SetOAuthClientConfigRequest.SerializeToString,
|
||||
response_deserializer=noteflow__pb2.SetOAuthClientConfigResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.RegisterWebhook = channel.unary_unary(
|
||||
'/noteflow.NoteFlowService/RegisterWebhook',
|
||||
request_serializer=noteflow__pb2.RegisterWebhookRequest.SerializeToString,
|
||||
@@ -730,6 +740,18 @@ class NoteFlowServiceServicer(object):
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetOAuthClientConfig(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SetOAuthClientConfig(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def RegisterWebhook(self, request, context):
|
||||
"""Webhook management (Sprint 6)
|
||||
"""
|
||||
@@ -1221,6 +1243,16 @@ def add_NoteFlowServiceServicer_to_server(servicer, server):
|
||||
request_deserializer=noteflow__pb2.DisconnectOAuthRequest.FromString,
|
||||
response_serializer=noteflow__pb2.DisconnectOAuthResponse.SerializeToString,
|
||||
),
|
||||
'GetOAuthClientConfig': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetOAuthClientConfig,
|
||||
request_deserializer=noteflow__pb2.GetOAuthClientConfigRequest.FromString,
|
||||
response_serializer=noteflow__pb2.GetOAuthClientConfigResponse.SerializeToString,
|
||||
),
|
||||
'SetOAuthClientConfig': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SetOAuthClientConfig,
|
||||
request_deserializer=noteflow__pb2.SetOAuthClientConfigRequest.FromString,
|
||||
response_serializer=noteflow__pb2.SetOAuthClientConfigResponse.SerializeToString,
|
||||
),
|
||||
'RegisterWebhook': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.RegisterWebhook,
|
||||
request_deserializer=noteflow__pb2.RegisterWebhookRequest.FromString,
|
||||
@@ -2546,6 +2578,60 @@ class NoteFlowService(object):
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetOAuthClientConfig(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/noteflow.NoteFlowService/GetOAuthClientConfig',
|
||||
noteflow__pb2.GetOAuthClientConfigRequest.SerializeToString,
|
||||
noteflow__pb2.GetOAuthClientConfigResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SetOAuthClientConfig(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/noteflow.NoteFlowService/SetOAuthClientConfig',
|
||||
noteflow__pb2.SetOAuthClientConfigRequest.SerializeToString,
|
||||
noteflow__pb2.SetOAuthClientConfigResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def RegisterWebhook(request,
|
||||
target,
|
||||
|
||||
@@ -31,6 +31,7 @@ from .identity.singleton import default_identity_service
|
||||
from .mixins import (
|
||||
AnnotationMixin,
|
||||
AsrConfigMixin,
|
||||
CalendarOAuthConfigMixin,
|
||||
CalendarMixin,
|
||||
DiarizationJobMixin,
|
||||
DiarizationMixin,
|
||||
@@ -93,6 +94,7 @@ class NoteFlowServicer(
|
||||
ExportMixin,
|
||||
EntitiesMixin,
|
||||
CalendarMixin,
|
||||
CalendarOAuthConfigMixin,
|
||||
WebhooksMixin,
|
||||
SyncMixin,
|
||||
ObservabilityMixin,
|
||||
|
||||
@@ -9,6 +9,10 @@ import asyncio
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import TYPE_CHECKING, Final, Protocol, TypedDict, Unpack, cast
|
||||
|
||||
from noteflow.application.services.asr_config.types import (
|
||||
INVALID_MODEL_SIZE_PREFIX,
|
||||
VALID_SIZES_SUFFIX,
|
||||
)
|
||||
from noteflow.infrastructure.asr.dto import AsrResult, WordTiming
|
||||
from noteflow.infrastructure.logging import get_logger, log_timing
|
||||
|
||||
@@ -116,7 +120,7 @@ class FasterWhisperEngine:
|
||||
|
||||
if model_size not in VALID_MODEL_SIZES:
|
||||
raise ValueError(
|
||||
f"Invalid model size: {model_size}. Valid sizes: {', '.join(VALID_MODEL_SIZES)}"
|
||||
f"{INVALID_MODEL_SIZE_PREFIX}{model_size}{VALID_SIZES_SUFFIX}{', '.join(VALID_MODEL_SIZES)}"
|
||||
)
|
||||
|
||||
with log_timing(
|
||||
|
||||
207
src/noteflow/infrastructure/asr/factory.py
Normal file
207
src/noteflow/infrastructure/asr/factory.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""ASR engine factory for backend selection.
|
||||
|
||||
Provides factory functions to create ASR engines based on available
|
||||
GPU backends and user preferences. Handles automatic device detection
|
||||
and fallback logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.application.services.asr_config.types import AsrComputeType
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
from noteflow.infrastructure.gpu.detection import (
|
||||
detect_gpu_backend,
|
||||
get_gpu_info,
|
||||
is_ctranslate2_rocm_available,
|
||||
is_rocm_architecture_supported,
|
||||
)
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.infrastructure.asr.protocols import AsrEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EngineCreationError(Exception):
|
||||
"""Raised when ASR engine creation fails."""
|
||||
|
||||
|
||||
def create_asr_engine(
|
||||
device: str = "auto",
|
||||
compute_type: str = "int8",
|
||||
*,
|
||||
prefer_faster_whisper: bool = True,
|
||||
) -> AsrEngine:
|
||||
"""Create an ASR engine for the specified device.
|
||||
|
||||
Auto-detects GPU backends and selects appropriate implementation.
|
||||
Falls back to PyTorch Whisper when faster-whisper is unavailable.
|
||||
|
||||
Args:
|
||||
device: Target device ("auto", "cpu", "cuda", "rocm").
|
||||
compute_type: Compute precision ("int8", "float16", "float32").
|
||||
prefer_faster_whisper: Prefer CTranslate2-based faster-whisper.
|
||||
|
||||
Returns:
|
||||
An ASR engine implementing AsrEngine.
|
||||
|
||||
Raises:
|
||||
EngineCreationError: If engine creation fails.
|
||||
"""
|
||||
resolved_device = resolve_device(device)
|
||||
|
||||
logger.info(
|
||||
"Creating ASR engine",
|
||||
requested_device=device,
|
||||
resolved_device=resolved_device,
|
||||
compute_type=compute_type,
|
||||
prefer_faster_whisper=prefer_faster_whisper,
|
||||
)
|
||||
|
||||
if resolved_device == "cpu":
|
||||
return _create_cpu_engine(compute_type)
|
||||
|
||||
if resolved_device == GpuBackend.CUDA.value:
|
||||
return _create_cuda_engine(compute_type, prefer_faster_whisper)
|
||||
|
||||
if resolved_device == GpuBackend.ROCM.value:
|
||||
return _create_rocm_engine(compute_type, prefer_faster_whisper)
|
||||
|
||||
msg = f"Unsupported device: {resolved_device}"
|
||||
raise EngineCreationError(msg)
|
||||
|
||||
|
||||
def resolve_device(device: str) -> str:
|
||||
"""Resolve 'auto' device to actual backend.
|
||||
|
||||
Args:
|
||||
device: Requested device string.
|
||||
|
||||
Returns:
|
||||
Resolved device string ("cpu", "cuda", or "rocm").
|
||||
"""
|
||||
if device != "auto":
|
||||
return device
|
||||
|
||||
backend = detect_gpu_backend()
|
||||
|
||||
if backend == GpuBackend.CUDA:
|
||||
return GpuBackend.CUDA.value
|
||||
|
||||
if backend == GpuBackend.ROCM:
|
||||
# Check if ROCm architecture is supported for ASR
|
||||
gpu_info = get_gpu_info()
|
||||
if gpu_info and is_rocm_architecture_supported(gpu_info.architecture):
|
||||
return GpuBackend.ROCM.value
|
||||
logger.warning(
|
||||
"ROCm detected but architecture may not be supported, falling back to CPU",
|
||||
architecture=gpu_info.architecture if gpu_info else None,
|
||||
)
|
||||
return "cpu"
|
||||
|
||||
# MPS not supported by faster-whisper; PyTorch Whisper may work but is untested
|
||||
if backend == GpuBackend.MPS:
|
||||
logger.info("MPS detected but not supported for ASR, using CPU")
|
||||
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _create_cpu_engine(compute_type: str) -> AsrEngine:
|
||||
"""Create CPU engine (always uses faster-whisper).
|
||||
|
||||
Args:
|
||||
compute_type: Requested compute type.
|
||||
|
||||
Returns:
|
||||
ASR engine for CPU.
|
||||
"""
|
||||
from noteflow.infrastructure.asr.engine import FasterWhisperEngine
|
||||
|
||||
# CPU only supports int8 and float32
|
||||
if compute_type == AsrComputeType.FLOAT16.value:
|
||||
logger.debug("float16 not supported on CPU, using float32")
|
||||
compute_type = AsrComputeType.FLOAT32.value
|
||||
|
||||
return FasterWhisperEngine(device="cpu", compute_type=compute_type)
|
||||
|
||||
|
||||
def _create_cuda_engine(
|
||||
compute_type: str,
|
||||
prefer_faster_whisper: bool,
|
||||
) -> AsrEngine:
|
||||
"""Create CUDA engine.
|
||||
|
||||
Args:
|
||||
compute_type: Compute precision.
|
||||
prefer_faster_whisper: Whether to prefer faster-whisper.
|
||||
|
||||
Returns:
|
||||
ASR engine for CUDA.
|
||||
"""
|
||||
if prefer_faster_whisper:
|
||||
from noteflow.infrastructure.asr.engine import FasterWhisperEngine
|
||||
|
||||
return FasterWhisperEngine(device="cuda", compute_type=compute_type)
|
||||
|
||||
return _create_pytorch_engine("cuda", compute_type)
|
||||
|
||||
|
||||
def _create_rocm_engine(
|
||||
compute_type: str,
|
||||
prefer_faster_whisper: bool,
|
||||
) -> AsrEngine:
|
||||
"""Create ROCm engine.
|
||||
|
||||
Attempts to use CTranslate2-ROCm fork if available,
|
||||
falls back to PyTorch Whisper otherwise.
|
||||
|
||||
Args:
|
||||
compute_type: Compute precision.
|
||||
prefer_faster_whisper: Whether to prefer faster-whisper.
|
||||
|
||||
Returns:
|
||||
ASR engine for ROCm.
|
||||
"""
|
||||
if prefer_faster_whisper and is_ctranslate2_rocm_available():
|
||||
try:
|
||||
from noteflow.infrastructure.asr.rocm_engine import FasterWhisperRocmEngine
|
||||
|
||||
logger.info("Using CTranslate2-ROCm for ASR")
|
||||
return FasterWhisperRocmEngine(compute_type=compute_type)
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
"CTranslate2-ROCm import failed, falling back to PyTorch Whisper",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
logger.info("Using PyTorch Whisper for ROCm ASR")
|
||||
# ROCm uses "cuda" device string internally via HIP
|
||||
return _create_pytorch_engine("cuda", compute_type)
|
||||
|
||||
|
||||
def _create_pytorch_engine(device: str, compute_type: str) -> AsrEngine:
|
||||
"""Create PyTorch Whisper engine (universal fallback).
|
||||
|
||||
Args:
|
||||
device: Target device.
|
||||
compute_type: Compute precision.
|
||||
|
||||
Returns:
|
||||
PyTorch-based Whisper engine.
|
||||
|
||||
Raises:
|
||||
EngineCreationError: If openai-whisper is not installed.
|
||||
"""
|
||||
try:
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
return WhisperPyTorchEngine(device=device, compute_type=compute_type)
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"Neither CTranslate2 nor openai-whisper is available. "
|
||||
"Install one of: pip install faster-whisper OR pip install openai-whisper"
|
||||
)
|
||||
raise EngineCreationError(msg) from e
|
||||
@@ -61,6 +61,16 @@ class AsrEngine(Protocol):
|
||||
"""Return the loaded model size, or None if not loaded."""
|
||||
...
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Return the device this engine runs on (cpu, cuda, rocm)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def compute_type(self) -> str:
|
||||
"""Return the compute precision (int8, float16, float32)."""
|
||||
...
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload the model to free memory."""
|
||||
...
|
||||
|
||||
302
src/noteflow/infrastructure/asr/pytorch_engine.py
Normal file
302
src/noteflow/infrastructure/asr/pytorch_engine.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""PyTorch-based Whisper engine (universal fallback).
|
||||
|
||||
Provides a pure PyTorch implementation using the official openai-whisper
|
||||
package. Works on any PyTorch-supported device (CPU, CUDA, ROCm via HIP).
|
||||
|
||||
This engine is slower than CTranslate2-based engines but provides
|
||||
universal compatibility across all GPU backends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Protocol, TypedDict, cast
|
||||
|
||||
from noteflow.application.services.asr_config.types import (
|
||||
INVALID_MODEL_SIZE_PREFIX,
|
||||
VALID_SIZES_SUFFIX,
|
||||
AsrComputeType,
|
||||
)
|
||||
from noteflow.domain.constants.fields import START
|
||||
from noteflow.infrastructure.asr.dto import AsrResult, WordTiming
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class _WordDict(TypedDict, total=False):
|
||||
"""TypedDict for whisper word timing."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
probability: float
|
||||
|
||||
|
||||
class _SegmentDict(TypedDict, total=False):
|
||||
"""TypedDict for whisper segment."""
|
||||
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
words: list[_WordDict]
|
||||
avg_logprob: float
|
||||
no_speech_prob: float
|
||||
|
||||
|
||||
class _TranscriptionResult(TypedDict, total=False):
|
||||
"""TypedDict for whisper transcription result."""
|
||||
|
||||
segments: list[_SegmentDict]
|
||||
language: str
|
||||
|
||||
|
||||
class _WhisperModel(Protocol):
|
||||
"""Protocol for openai-whisper Whisper model type."""
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: object,
|
||||
**kwargs: object,
|
||||
) -> _TranscriptionResult:
|
||||
"""Transcribe audio."""
|
||||
...
|
||||
|
||||
def half(self) -> "_WhisperModel":
|
||||
"""Convert model to half precision."""
|
||||
...
|
||||
|
||||
# Valid model sizes for openai-whisper
|
||||
PYTORCH_VALID_MODEL_SIZES: tuple[str, ...] = (
|
||||
"tiny",
|
||||
"tiny.en",
|
||||
"base",
|
||||
"base.en",
|
||||
"small",
|
||||
"small.en",
|
||||
"medium",
|
||||
"medium.en",
|
||||
"large",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
"turbo",
|
||||
)
|
||||
|
||||
|
||||
class WhisperPyTorchEngine:
|
||||
"""Pure PyTorch Whisper implementation.
|
||||
|
||||
Uses the official openai-whisper package for transcription.
|
||||
Works on any PyTorch-supported device (CPU, CUDA, ROCm via HIP).
|
||||
|
||||
This engine is slower than CTranslate2-based engines but provides
|
||||
universal compatibility across all GPU backends.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: str = "cpu",
|
||||
compute_type: str = "float32",
|
||||
) -> None:
|
||||
"""Initialize PyTorch Whisper engine.
|
||||
|
||||
Args:
|
||||
device: Target device ("cpu" or "cuda").
|
||||
For ROCm, use "cuda" - HIP handles the translation.
|
||||
compute_type: Compute precision. Only "float16" and "float32"
|
||||
are supported. "int8" will be treated as "float32".
|
||||
"""
|
||||
self._device = device
|
||||
self._compute_type = self._normalize_compute_type(compute_type)
|
||||
self._model_size: str | None = None
|
||||
self._model: _WhisperModel | None = None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_compute_type(compute_type: str) -> str:
|
||||
"""Normalize compute type for PyTorch.
|
||||
|
||||
PyTorch Whisper doesn't support int8, map to float32.
|
||||
"""
|
||||
if compute_type == "int8":
|
||||
logger.debug("int8 not supported in PyTorch Whisper, using float32")
|
||||
return "float32"
|
||||
return compute_type
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Return the device this engine runs on."""
|
||||
return self._device
|
||||
|
||||
@property
|
||||
def compute_type(self) -> str:
|
||||
"""Return the compute precision."""
|
||||
return self._compute_type
|
||||
|
||||
@property
|
||||
def model_size(self) -> str | None:
|
||||
"""Return the loaded model size."""
|
||||
return self._model_size
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Return True if model is loaded."""
|
||||
return self._model is not None
|
||||
|
||||
def load_model(self, model_size: str = "base") -> None:
|
||||
"""Load the specified Whisper model.
|
||||
|
||||
Args:
|
||||
model_size: Whisper model size (e.g., "base", "small", "large-v3").
|
||||
|
||||
Raises:
|
||||
ValueError: If model_size is invalid.
|
||||
RuntimeError: If model loading fails.
|
||||
"""
|
||||
import whisper
|
||||
|
||||
if model_size not in PYTORCH_VALID_MODEL_SIZES:
|
||||
valid_sizes = ", ".join(PYTORCH_VALID_MODEL_SIZES)
|
||||
msg = f"{INVALID_MODEL_SIZE_PREFIX}{model_size}{VALID_SIZES_SUFFIX}{valid_sizes}"
|
||||
raise ValueError(msg)
|
||||
|
||||
logger.info(
|
||||
"Loading PyTorch Whisper model",
|
||||
model_size=model_size,
|
||||
device=self._device,
|
||||
compute_type=self._compute_type,
|
||||
)
|
||||
|
||||
try:
|
||||
# Load model - cast untyped whisper.load_model return
|
||||
load_fn = getattr(whisper, "load_model")
|
||||
model = cast(_WhisperModel, load_fn(model_size, device=self._device))
|
||||
|
||||
# Apply compute type (half precision for GPU)
|
||||
if self._compute_type == AsrComputeType.FLOAT16.value and self._device != "cpu":
|
||||
model = model.half()
|
||||
|
||||
self._model = model
|
||||
self._model_size = model_size
|
||||
|
||||
logger.info("PyTorch Whisper model loaded successfully")
|
||||
except (RuntimeError, OSError, ValueError) as e:
|
||||
msg = f"Failed to load model: {e}"
|
||||
raise RuntimeError(msg) from e
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload the model and free resources."""
|
||||
if self._model is not None:
|
||||
import torch
|
||||
|
||||
del self._model
|
||||
self._model = None
|
||||
self._model_size = None
|
||||
|
||||
# Force garbage collection and clear GPU cache
|
||||
gc.collect()
|
||||
if self._device != "cpu" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logger.debug("PyTorch Whisper model unloaded")
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: NDArray[np.float32],
|
||||
language: str | None = None,
|
||||
) -> Iterator[AsrResult]:
|
||||
"""Transcribe audio samples.
|
||||
|
||||
Args:
|
||||
audio: Audio samples as float32 array, 16kHz mono.
|
||||
language: Optional language code.
|
||||
|
||||
Yields:
|
||||
AsrResult for each detected segment.
|
||||
"""
|
||||
if self._model is None:
|
||||
msg = "Model not loaded. Call load_model() first."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Build transcription options
|
||||
options: dict[str, object] = {
|
||||
"word_timestamps": True,
|
||||
"fp16": self._compute_type == AsrComputeType.FLOAT16.value and self._device != "cpu",
|
||||
}
|
||||
|
||||
if language is not None:
|
||||
options["language"] = language
|
||||
|
||||
# Transcribe
|
||||
result = self._model.transcribe(audio, **options)
|
||||
|
||||
# Convert to our segment format
|
||||
segments = result.get("segments", [])
|
||||
detected_language = result.get("language", "en")
|
||||
|
||||
for segment in segments:
|
||||
words = self._extract_word_timings(segment)
|
||||
|
||||
yield AsrResult(
|
||||
text=segment.get("text", "").strip(),
|
||||
start=segment.get(START, 0.0),
|
||||
end=segment.get("end", 0.0),
|
||||
words=tuple(words),
|
||||
language=detected_language,
|
||||
language_probability=1.0, # Not available in base whisper
|
||||
avg_logprob=segment.get("avg_logprob", 0.0),
|
||||
no_speech_prob=segment.get("no_speech_prob", 0.0),
|
||||
)
|
||||
|
||||
def _extract_word_timings(self, segment: _SegmentDict) -> list[WordTiming]:
|
||||
"""Extract word timings from a segment.
|
||||
|
||||
Args:
|
||||
segment: Whisper segment dictionary.
|
||||
|
||||
Returns:
|
||||
List of WordTiming objects.
|
||||
"""
|
||||
words_data = segment.get("words", [])
|
||||
if not words_data:
|
||||
return []
|
||||
|
||||
return [
|
||||
WordTiming(
|
||||
word=w.get("word", ""),
|
||||
start=w.get(START, 0.0),
|
||||
end=w.get("end", 0.0),
|
||||
probability=w.get("probability", 0.0),
|
||||
)
|
||||
for w in words_data
|
||||
]
|
||||
|
||||
def transcribe_file(
|
||||
self,
|
||||
audio_path: Path,
|
||||
*,
|
||||
language: str | None = None,
|
||||
) -> Iterator[AsrResult]:
|
||||
"""Transcribe audio file.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file.
|
||||
language: Optional language code.
|
||||
|
||||
Yields:
|
||||
AsrResult for each detected segment.
|
||||
"""
|
||||
import whisper
|
||||
|
||||
# Load audio using whisper's utility - cast untyped return
|
||||
load_audio_fn = getattr(whisper, "load_audio")
|
||||
audio: NDArray[np.float32] = load_audio_fn(str(audio_path))
|
||||
|
||||
yield from self.transcribe(audio, language=language)
|
||||
246
src/noteflow/infrastructure/asr/rocm_engine.py
Normal file
246
src/noteflow/infrastructure/asr/rocm_engine.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""ROCm-specific faster-whisper engine.
|
||||
|
||||
Provides a faster-whisper engine optimized for AMD GPUs using the
|
||||
CTranslate2-ROCm community fork. Falls back to the standard FasterWhisperEngine
|
||||
when CTranslate2-ROCm is not available.
|
||||
|
||||
Requirements:
|
||||
- PyTorch with ROCm support
|
||||
- CTranslate2-ROCm fork: pip install git+https://github.com/arlo-phoenix/CTranslate2-rocm.git
|
||||
- faster-whisper: pip install faster-whisper
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from noteflow.application.services.asr_config.types import (
|
||||
INVALID_MODEL_SIZE_PREFIX,
|
||||
VALID_SIZES_SUFFIX,
|
||||
AsrComputeType,
|
||||
)
|
||||
from noteflow.infrastructure.asr.dto import AsrResult, WordTiming
|
||||
from noteflow.infrastructure.asr.pytorch_engine import (
|
||||
PYTORCH_VALID_MODEL_SIZES,
|
||||
)
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Protocol classes for faster-whisper types (untyped library)
|
||||
class _WhisperWord(Protocol):
|
||||
"""Protocol for faster-whisper Word type."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
probability: float
|
||||
|
||||
|
||||
class _WhisperSegment(Protocol):
|
||||
"""Protocol for faster-whisper Segment type."""
|
||||
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
words: Iterable[_WhisperWord] | None
|
||||
avg_logprob: float
|
||||
no_speech_prob: float
|
||||
|
||||
|
||||
class _WhisperInfo(Protocol):
|
||||
"""Protocol for faster-whisper TranscriptionInfo type."""
|
||||
|
||||
language: str
|
||||
language_probability: float
|
||||
|
||||
|
||||
|
||||
class FasterWhisperRocmEngine:
|
||||
"""ROCm-specific faster-whisper engine.
|
||||
|
||||
Uses the CTranslate2-ROCm fork for AMD GPU acceleration.
|
||||
Provides the same interface as FasterWhisperEngine but with
|
||||
ROCm-specific optimizations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compute_type: str = AsrComputeType.FLOAT16.value,
|
||||
num_workers: int = 1,
|
||||
) -> None:
|
||||
"""Initialize ROCm engine.
|
||||
|
||||
Args:
|
||||
compute_type: Computation type ("int8", "float16", "float32").
|
||||
num_workers: Number of worker threads.
|
||||
"""
|
||||
self._compute_type = compute_type
|
||||
self._num_workers = num_workers
|
||||
self._model: object | None = None
|
||||
self._model_size: str | None = None
|
||||
|
||||
# Verify ROCm is available
|
||||
self._verify_rocm_available()
|
||||
|
||||
def _verify_rocm_available(self) -> None:
|
||||
"""Verify ROCm/HIP is available."""
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
msg = "ROCm/CUDA not available"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if not (hasattr(torch.version, "hip") and torch.version.hip):
|
||||
logger.warning(
|
||||
"Running on CUDA instead of ROCm. "
|
||||
"For optimal performance on AMD GPUs, use PyTorch with ROCm support."
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Return the device (always 'rocm' for this engine)."""
|
||||
return "rocm"
|
||||
|
||||
@property
|
||||
def compute_type(self) -> str:
|
||||
"""Return the compute type."""
|
||||
return self._compute_type
|
||||
|
||||
@property
|
||||
def model_size(self) -> str | None:
|
||||
"""Return the loaded model size."""
|
||||
return self._model_size
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Return True if model is loaded."""
|
||||
return self._model is not None
|
||||
|
||||
def load_model(self, model_size: str = "base") -> None:
|
||||
"""Load the ASR model.
|
||||
|
||||
Args:
|
||||
model_size: Model size (e.g., "tiny", "base", "small").
|
||||
|
||||
Raises:
|
||||
ValueError: If model_size is invalid.
|
||||
RuntimeError: If model loading fails.
|
||||
"""
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_size not in PYTORCH_VALID_MODEL_SIZES:
|
||||
valid_sizes = ", ".join(PYTORCH_VALID_MODEL_SIZES)
|
||||
msg = f"{INVALID_MODEL_SIZE_PREFIX}{model_size}{VALID_SIZES_SUFFIX}{valid_sizes}"
|
||||
raise ValueError(msg)
|
||||
|
||||
logger.info(
|
||||
"Loading faster-whisper model for ROCm",
|
||||
model_size=model_size,
|
||||
compute_type=self._compute_type,
|
||||
)
|
||||
|
||||
try:
|
||||
# Use "cuda" device string - HIP maps this to ROCm
|
||||
self._model = WhisperModel(
|
||||
model_size,
|
||||
device="cuda", # HIP uses CUDA device string
|
||||
compute_type=self._compute_type,
|
||||
num_workers=self._num_workers,
|
||||
)
|
||||
self._model_size = model_size
|
||||
|
||||
logger.info("ROCm faster-whisper model loaded successfully")
|
||||
except (RuntimeError, OSError, ValueError) as e:
|
||||
msg = f"Failed to load model on ROCm: {e}"
|
||||
raise RuntimeError(msg) from e
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload the model to free memory."""
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
self._model = None
|
||||
self._model_size = None
|
||||
|
||||
# Force garbage collection and clear GPU cache
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logger.info("ROCm model unloaded")
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
audio: NDArray[np.float32],
|
||||
language: str | None = None,
|
||||
) -> Iterator[AsrResult]:
|
||||
"""Transcribe audio and yield results.
|
||||
|
||||
Args:
|
||||
audio: Audio samples as float32 array (16kHz mono, normalized).
|
||||
language: Optional language code (e.g., "en").
|
||||
|
||||
Yields:
|
||||
AsrResult segments with word-level timestamps.
|
||||
"""
|
||||
if self._model is None:
|
||||
msg = "Model not loaded. Call load_model() first."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Call transcribe on untyped faster-whisper model
|
||||
transcribe_fn = getattr(self._model, "transcribe")
|
||||
result = transcribe_fn(
|
||||
audio,
|
||||
language=language,
|
||||
word_timestamps=True,
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
)
|
||||
segments: Iterable[_WhisperSegment] = result[0]
|
||||
info: _WhisperInfo = result[1]
|
||||
|
||||
logger.debug(
|
||||
"Detected language: %s (prob: %.2f)",
|
||||
info.language,
|
||||
info.language_probability,
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
words = self._extract_word_timings(segment.words)
|
||||
yield AsrResult(
|
||||
text=segment.text.strip(),
|
||||
start=segment.start,
|
||||
end=segment.end,
|
||||
words=tuple(words),
|
||||
language=info.language,
|
||||
language_probability=info.language_probability,
|
||||
avg_logprob=segment.avg_logprob,
|
||||
no_speech_prob=segment.no_speech_prob,
|
||||
)
|
||||
|
||||
def _extract_word_timings(
|
||||
self,
|
||||
words: Iterable[_WhisperWord] | None,
|
||||
) -> list[WordTiming]:
|
||||
"""Extract word timings from segment words."""
|
||||
if not words:
|
||||
return []
|
||||
|
||||
return [
|
||||
WordTiming(
|
||||
word=word.word,
|
||||
start=word.start,
|
||||
end=word.end,
|
||||
probability=word.probability,
|
||||
)
|
||||
for word in words
|
||||
]
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from noteflow.domain.value_objects import OAuthProvider, OAuthState, OAuthTokens
|
||||
from noteflow.domain.value_objects import OAuthClientConfig, OAuthProvider, OAuthState, OAuthTokens
|
||||
from noteflow.infrastructure.calendar.oauth_flow import (
|
||||
OAuthFlowConfig,
|
||||
get_scopes,
|
||||
@@ -28,12 +28,12 @@ class OAuthManagerFlowMixin(OAuthManagerBase):
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
redirect_uri: str,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Generate OAuth authorization URL with PKCE."""
|
||||
self._cleanup_expired_states()
|
||||
self._validate_provider_config(provider)
|
||||
self._validate_provider_config(provider, client_config)
|
||||
self._check_rate_limit(provider)
|
||||
|
||||
with self._state_lock:
|
||||
if len(self._pending_states) >= self.MAX_PENDING_STATES:
|
||||
logger.warning(
|
||||
@@ -45,11 +45,15 @@ class OAuthManagerFlowMixin(OAuthManagerBase):
|
||||
|
||||
raise OAuthError("Too many pending OAuth flows. Please try again later.")
|
||||
|
||||
client_id, _ = self._get_credentials(provider)
|
||||
scopes = get_scopes(
|
||||
provider,
|
||||
google_scopes=self.GOOGLE_SCOPES,
|
||||
outlook_scopes=self.OUTLOOK_SCOPES,
|
||||
client_id, _ = self._get_credentials(provider, client_config)
|
||||
scopes = (
|
||||
list(client_config.scopes)
|
||||
if client_config and client_config.scopes
|
||||
else get_scopes(
|
||||
provider,
|
||||
google_scopes=self.GOOGLE_SCOPES,
|
||||
outlook_scopes=self.OUTLOOK_SCOPES,
|
||||
)
|
||||
)
|
||||
state_token, oauth_state, auth_url = prepare_oauth_flow(
|
||||
OAuthFlowConfig(
|
||||
@@ -63,7 +67,6 @@ class OAuthManagerFlowMixin(OAuthManagerBase):
|
||||
)
|
||||
)
|
||||
self._pending_states[state_token] = oauth_state
|
||||
|
||||
logger.info(
|
||||
"oauth_initiated",
|
||||
provider=provider.value,
|
||||
@@ -77,6 +80,7 @@ class OAuthManagerFlowMixin(OAuthManagerBase):
|
||||
provider: OAuthProvider,
|
||||
code: str,
|
||||
state: str,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange authorization code for tokens."""
|
||||
with self._state_lock:
|
||||
@@ -98,8 +102,8 @@ class OAuthManagerFlowMixin(OAuthManagerBase):
|
||||
tokens = await self._exchange_code(
|
||||
provider=provider,
|
||||
code=code,
|
||||
redirect_uri=oauth_state.redirect_uri,
|
||||
code_verifier=oauth_state.code_verifier,
|
||||
oauth_state=oauth_state,
|
||||
client_config=client_config,
|
||||
)
|
||||
|
||||
logger.info("Completed OAuth flow for provider=%s", provider.value)
|
||||
|
||||
@@ -9,7 +9,12 @@ import httpx
|
||||
|
||||
from noteflow.config.constants import HTTP_STATUS_OK
|
||||
from noteflow.domain.constants.fields import CODE
|
||||
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
|
||||
from noteflow.domain.value_objects import (
|
||||
OAuthClientConfig,
|
||||
OAuthProvider,
|
||||
OAuthState,
|
||||
OAuthTokens,
|
||||
)
|
||||
from noteflow.infrastructure.calendar.oauth_flow import (
|
||||
get_token_url,
|
||||
parse_token_response,
|
||||
@@ -30,9 +35,13 @@ class OAuthManagerHelpersMixin(OAuthManagerBase):
|
||||
_settings: CalendarIntegrationSettings
|
||||
_auth_attempts: dict[str, list[datetime]]
|
||||
|
||||
def _validate_provider_config(self, provider: OAuthProvider) -> None:
|
||||
def _validate_provider_config(
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> None:
|
||||
"""Validate that provider credentials are configured."""
|
||||
client_id, client_secret = self._get_credentials(provider)
|
||||
client_id, client_secret = self._get_credentials(provider, client_config)
|
||||
if not client_id or not client_secret:
|
||||
from .oauth_manager import OAuthError
|
||||
|
||||
@@ -40,8 +49,14 @@ class OAuthManagerHelpersMixin(OAuthManagerBase):
|
||||
f"OAuth credentials not configured for {provider.value}"
|
||||
)
|
||||
|
||||
def _get_credentials(self, provider: OAuthProvider) -> tuple[str, str]:
|
||||
def _get_credentials(
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Get client credentials for provider."""
|
||||
if client_config is not None:
|
||||
return client_config.client_id, client_config.client_secret
|
||||
if provider == OAuthProvider.GOOGLE:
|
||||
return (
|
||||
self._settings.google_client_id,
|
||||
@@ -56,8 +71,8 @@ class OAuthManagerHelpersMixin(OAuthManagerBase):
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str,
|
||||
oauth_state: OAuthState,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange authorization code for tokens."""
|
||||
token_url = get_token_url(
|
||||
@@ -65,14 +80,14 @@ class OAuthManagerHelpersMixin(OAuthManagerBase):
|
||||
google_url=self.GOOGLE_TOKEN_URL,
|
||||
outlook_url=self.OUTLOOK_TOKEN_URL,
|
||||
)
|
||||
client_id, client_secret = self._get_credentials(provider)
|
||||
client_id, client_secret = self._get_credentials(provider, client_config)
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
CODE: code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"redirect_uri": oauth_state.redirect_uri,
|
||||
"client_id": client_id,
|
||||
"code_verifier": code_verifier,
|
||||
"code_verifier": oauth_state.code_verifier,
|
||||
}
|
||||
|
||||
if provider == OAuthProvider.GOOGLE:
|
||||
|
||||
@@ -10,7 +10,7 @@ from noteflow.config.constants import (
|
||||
HTTP_STATUS_OK,
|
||||
OAUTH_FIELD_REFRESH_TOKEN,
|
||||
)
|
||||
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
|
||||
from noteflow.domain.value_objects import OAuthClientConfig, OAuthProvider, OAuthTokens
|
||||
from noteflow.infrastructure.calendar.oauth_flow import (
|
||||
get_revoke_url,
|
||||
get_token_url,
|
||||
@@ -30,6 +30,7 @@ class OAuthManagerTokenMixin(OAuthManagerBase):
|
||||
self,
|
||||
provider: OAuthProvider,
|
||||
refresh_token: str,
|
||||
client_config: OAuthClientConfig | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Refresh expired access token."""
|
||||
token_url = get_token_url(
|
||||
@@ -37,8 +38,7 @@ class OAuthManagerTokenMixin(OAuthManagerBase):
|
||||
google_url=self.GOOGLE_TOKEN_URL,
|
||||
outlook_url=self.OUTLOOK_TOKEN_URL,
|
||||
)
|
||||
client_id, client_secret = self._get_credentials(provider)
|
||||
|
||||
client_id, client_secret = self._get_credentials(provider, client_config)
|
||||
data = {
|
||||
"grant_type": OAUTH_FIELD_REFRESH_TOKEN,
|
||||
OAUTH_FIELD_REFRESH_TOKEN: refresh_token,
|
||||
|
||||
@@ -32,5 +32,8 @@ class DiarizationEngineDeviceMixin(DiarizationEngineBase):
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Explicitly log ROCm detection for debugging clarity
|
||||
if hasattr(torch.version, "hip") and torch.version.hip:
|
||||
logger.info("Detected ROCm environment (using cuda backend)")
|
||||
return "cuda"
|
||||
return "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
25
src/noteflow/infrastructure/gpu/__init__.py
Normal file
25
src/noteflow/infrastructure/gpu/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""GPU detection and utilities for multi-backend support.
|
||||
|
||||
This module provides GPU backend detection (CUDA, ROCm, MPS) and utilities
|
||||
for determining the best compute device for ASR and diarization workloads.
|
||||
"""
|
||||
|
||||
from noteflow.infrastructure.gpu.detection import (
|
||||
SUPPORTED_AMD_ARCHITECTURES,
|
||||
GpuDetectionError,
|
||||
detect_gpu_backend,
|
||||
get_gpu_info,
|
||||
get_rocm_environment_info,
|
||||
is_ctranslate2_rocm_available,
|
||||
is_rocm_architecture_supported,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SUPPORTED_AMD_ARCHITECTURES",
|
||||
"GpuDetectionError",
|
||||
"detect_gpu_backend",
|
||||
"get_gpu_info",
|
||||
"get_rocm_environment_info",
|
||||
"is_ctranslate2_rocm_available",
|
||||
"is_rocm_architecture_supported",
|
||||
]
|
||||
237
src/noteflow/infrastructure/gpu/detection.py
Normal file
237
src/noteflow/infrastructure/gpu/detection.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""GPU backend detection utilities.
|
||||
|
||||
Provides functions to detect available GPU backends (CUDA, ROCm, MPS)
|
||||
and gather hardware information for compute device selection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import cache
|
||||
from typing import Final, Protocol, cast
|
||||
|
||||
from noteflow.domain.constants.fields import UNKNOWN
|
||||
from noteflow.domain.ports.gpu import GpuBackend, GpuInfo
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class _DeviceProperties(Protocol):
|
||||
"""Protocol for torch.cuda.get_device_properties return type."""
|
||||
|
||||
name: str
|
||||
total_memory: int
|
||||
major: int
|
||||
minor: int
|
||||
|
||||
|
||||
# Officially supported AMD GPU architectures for ROCm.
|
||||
# Keep in sync with AMD ROCm compatibility matrix.
|
||||
# See: https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html
|
||||
SUPPORTED_AMD_ARCHITECTURES: Final[frozenset[str]] = frozenset({
|
||||
# CDNA (Instinct datacenter GPUs)
|
||||
"gfx906", # MI50
|
||||
"gfx908", # MI100
|
||||
"gfx90a", # MI210, MI250, MI250X
|
||||
"gfx942", # MI300X, MI300A
|
||||
# RDNA 2 (Consumer/Workstation)
|
||||
"gfx1030", # RX 6800, 6800 XT, 6900 XT
|
||||
"gfx1031", # RX 6700 XT
|
||||
"gfx1032", # RX 6600, 6600 XT
|
||||
# RDNA 3 (Consumer/Workstation)
|
||||
"gfx1100", # RX 7900 XTX, 7900 XT
|
||||
"gfx1101", # RX 7800 XT, 7700 XT
|
||||
"gfx1102", # RX 7600
|
||||
})
|
||||
|
||||
|
||||
@cache
|
||||
def detect_gpu_backend() -> GpuBackend:
|
||||
"""Detect the available GPU backend.
|
||||
|
||||
Results are cached for performance. The cache is cleared when
|
||||
the module is reloaded.
|
||||
|
||||
Returns:
|
||||
GpuBackend enum indicating the detected backend.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
logger.debug("PyTorch not installed, no GPU backend available")
|
||||
return GpuBackend.NONE
|
||||
|
||||
# Check CUDA/ROCm availability
|
||||
if torch.cuda.is_available():
|
||||
# Distinguish between CUDA and ROCm via HIP version
|
||||
if hasattr(torch.version, "hip") and torch.version.hip:
|
||||
logger.info("ROCm/HIP backend detected", version=torch.version.hip)
|
||||
return GpuBackend.ROCM
|
||||
|
||||
cuda_version = torch.version.cuda or UNKNOWN
|
||||
logger.info("CUDA backend detected", version=cuda_version)
|
||||
return GpuBackend.CUDA
|
||||
|
||||
# Check Apple Metal Performance Shaders
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
logger.info("MPS backend detected")
|
||||
return GpuBackend.MPS
|
||||
|
||||
logger.debug("No GPU backend available, using CPU")
|
||||
return GpuBackend.NONE
|
||||
|
||||
|
||||
class GpuDetectionError(RuntimeError):
|
||||
"""Raised when GPU detection fails."""
|
||||
|
||||
|
||||
def _get_cuda_rocm_info(backend: GpuBackend) -> GpuInfo:
|
||||
"""Get GPU info for CUDA or ROCm backend."""
|
||||
import torch
|
||||
|
||||
get_props = getattr(torch.cuda, "get_device_properties")
|
||||
props = cast(_DeviceProperties, get_props(0))
|
||||
vram_mb = props.total_memory // (1024 * 1024)
|
||||
|
||||
if backend == GpuBackend.ROCM:
|
||||
hip_version = getattr(torch.version, "hip", None)
|
||||
driver = str(hip_version) if hip_version else UNKNOWN
|
||||
arch = _extract_rocm_architecture(props)
|
||||
else:
|
||||
cuda_version = getattr(torch.version, "cuda", None)
|
||||
driver = str(cuda_version) if cuda_version else UNKNOWN
|
||||
arch = f"sm_{props.major}{props.minor}"
|
||||
|
||||
return GpuInfo(
|
||||
backend=backend,
|
||||
device_name=str(props.name),
|
||||
vram_total_mb=vram_mb,
|
||||
driver_version=driver,
|
||||
architecture=arch,
|
||||
)
|
||||
|
||||
|
||||
def get_gpu_info() -> GpuInfo | None:
|
||||
"""Get detailed GPU information.
|
||||
|
||||
Returns:
|
||||
GpuInfo if a GPU is available, None otherwise.
|
||||
|
||||
Raises:
|
||||
GpuDetectionError: If GPU is detected but properties cannot be retrieved.
|
||||
"""
|
||||
backend = detect_gpu_backend()
|
||||
|
||||
if backend == GpuBackend.NONE:
|
||||
return None
|
||||
|
||||
if backend in (GpuBackend.CUDA, GpuBackend.ROCM):
|
||||
try:
|
||||
return _get_cuda_rocm_info(backend)
|
||||
except RuntimeError as e:
|
||||
msg = f"Failed to get GPU properties for {backend.value} device"
|
||||
raise GpuDetectionError(msg) from e
|
||||
|
||||
if backend == GpuBackend.MPS:
|
||||
return GpuInfo(
|
||||
backend=backend,
|
||||
device_name="Apple Metal",
|
||||
vram_total_mb=0,
|
||||
driver_version="mps",
|
||||
architecture=None,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_rocm_architecture(props: _DeviceProperties) -> str | None:
|
||||
"""Extract ROCm GPU architecture from device properties.
|
||||
|
||||
Args:
|
||||
props: PyTorch device properties object.
|
||||
|
||||
Returns:
|
||||
Architecture string (e.g., "gfx1100") or None.
|
||||
"""
|
||||
# Try gcnArchName attribute first (available in newer PyTorch ROCm builds)
|
||||
gcn_arch = getattr(props, "gcnArchName", None)
|
||||
if gcn_arch is not None:
|
||||
return str(gcn_arch)
|
||||
|
||||
# Fall back to parsing device name if it contains gfx ID
|
||||
device_name = props.name
|
||||
if device_name.startswith("gfx"):
|
||||
return device_name.split()[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_rocm_architecture_supported(architecture: str | None) -> bool:
|
||||
"""Check if AMD GPU architecture is officially supported for ROCm.
|
||||
|
||||
Args:
|
||||
architecture: GPU architecture string (e.g., "gfx1100").
|
||||
|
||||
Returns:
|
||||
True if supported, False otherwise.
|
||||
"""
|
||||
if architecture is None:
|
||||
return False
|
||||
|
||||
# Check for user override (allows unofficial GPUs)
|
||||
if os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
|
||||
return True
|
||||
|
||||
return architecture in SUPPORTED_AMD_ARCHITECTURES
|
||||
|
||||
|
||||
def is_ctranslate2_rocm_available() -> bool:
|
||||
"""Check if CTranslate2-ROCm fork is installed.
|
||||
|
||||
The CTranslate2-ROCm fork is required for faster-whisper ROCm support.
|
||||
If not available, the system falls back to PyTorch Whisper.
|
||||
|
||||
Returns:
|
||||
True if the ROCm fork is available.
|
||||
"""
|
||||
if detect_gpu_backend() != GpuBackend.ROCM:
|
||||
return False
|
||||
|
||||
try:
|
||||
import ctranslate2
|
||||
|
||||
# The ROCm fork should work with HIP
|
||||
# Verify by checking if we can query compute types
|
||||
if not hasattr(ctranslate2, "get_supported_compute_types"):
|
||||
return False
|
||||
|
||||
# Try to get compute types for cuda (which maps to HIP on ROCm)
|
||||
get_types = getattr(ctranslate2, "get_supported_compute_types")
|
||||
try:
|
||||
get_types("cuda")
|
||||
return True
|
||||
except (ValueError, RuntimeError):
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_rocm_environment_info() -> dict[str, str]:
|
||||
"""Get ROCm-related environment variables for debugging.
|
||||
|
||||
Returns:
|
||||
Dictionary of relevant environment variables and their values.
|
||||
"""
|
||||
rocm_vars = [
|
||||
"HSA_OVERRIDE_GFX_VERSION",
|
||||
"HIP_VISIBLE_DEVICES",
|
||||
"ROCM_PATH",
|
||||
"MIOPEN_USER_DB_PATH",
|
||||
"MIOPEN_FIND_MODE",
|
||||
"AMD_LOG_LEVEL",
|
||||
"GPU_MAX_HW_QUEUES",
|
||||
]
|
||||
|
||||
return {var: os.environ[var] for var in rocm_vars if var in os.environ}
|
||||
@@ -30,6 +30,7 @@ class InMemoryIntegrationRepository:
|
||||
self,
|
||||
provider: str,
|
||||
integration_type: str | None = None,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> Integration | None:
|
||||
"""Retrieve an integration by provider name."""
|
||||
for integration in self._integrations.values():
|
||||
@@ -38,7 +39,8 @@ class InMemoryIntegrationRepository:
|
||||
or provider.lower() in integration.name.lower()
|
||||
)
|
||||
type_match = integration_type is None or integration.type.value == integration_type
|
||||
if provider_match and type_match:
|
||||
workspace_match = workspace_id is None or integration.workspace_id == workspace_id
|
||||
if provider_match and type_match and workspace_match:
|
||||
return integration
|
||||
return None
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
from noteflow.domain.constants.fields import PROVIDER
|
||||
@@ -15,18 +15,20 @@ from noteflow.infrastructure.persistence.models.integrations import IntegrationM
|
||||
|
||||
|
||||
async def get_by_provider(
|
||||
session: AsyncSession,
|
||||
execute_scalar_func: Callable[[Select[tuple[IntegrationModel]]], Awaitable[IntegrationModel | None]],
|
||||
execute_scalar_func: Callable[
|
||||
[Select[tuple[IntegrationModel]]], Awaitable[IntegrationModel | None]
|
||||
],
|
||||
provider: str,
|
||||
integration_type: str | None = None,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> Integration | None:
|
||||
"""Retrieve an integration by provider name.
|
||||
|
||||
Args:
|
||||
session: Database session.
|
||||
execute_scalar_func: Function to execute scalar query.
|
||||
provider: Provider name (stored in config['provider'] or name).
|
||||
integration_type: Optional type filter.
|
||||
workspace_id: Optional workspace filter.
|
||||
|
||||
Returns:
|
||||
Integration if found, None otherwise.
|
||||
@@ -36,6 +38,8 @@ async def get_by_provider(
|
||||
)
|
||||
if integration_type:
|
||||
stmt = stmt.where(IntegrationModel.type == integration_type)
|
||||
if workspace_id:
|
||||
stmt = stmt.where(IntegrationModel.workspace_id == workspace_id)
|
||||
|
||||
model = await execute_scalar_func(stmt)
|
||||
if model:
|
||||
@@ -47,20 +51,22 @@ async def get_by_provider(
|
||||
)
|
||||
if integration_type:
|
||||
fallback_stmt = fallback_stmt.where(IntegrationModel.type == integration_type)
|
||||
if workspace_id:
|
||||
fallback_stmt = fallback_stmt.where(IntegrationModel.workspace_id == workspace_id)
|
||||
|
||||
fallback_model = await execute_scalar_func(fallback_stmt)
|
||||
return IntegrationConverter.orm_to_domain(fallback_model) if fallback_model else None
|
||||
|
||||
|
||||
async def list_by_type(
|
||||
session: AsyncSession,
|
||||
execute_scalars_func: Callable[[Select[tuple[IntegrationModel]]], Awaitable[list[IntegrationModel]]],
|
||||
execute_scalars_func: Callable[
|
||||
[Select[tuple[IntegrationModel]]], Awaitable[list[IntegrationModel]]
|
||||
],
|
||||
integration_type: str,
|
||||
) -> Sequence[Integration]:
|
||||
"""List integrations by type.
|
||||
|
||||
Args:
|
||||
session: Database session.
|
||||
execute_scalars_func: Function to execute scalars query.
|
||||
integration_type: Integration type (e.g., 'calendar', 'email').
|
||||
|
||||
@@ -77,13 +83,13 @@ async def list_by_type(
|
||||
|
||||
|
||||
async def list_all(
|
||||
session: AsyncSession,
|
||||
execute_scalars_func: Callable[[Select[tuple[IntegrationModel]]], Awaitable[list[IntegrationModel]]],
|
||||
execute_scalars_func: Callable[
|
||||
[Select[tuple[IntegrationModel]]], Awaitable[list[IntegrationModel]]
|
||||
],
|
||||
) -> Sequence[Integration]:
|
||||
"""List all integrations for the current workspace context.
|
||||
|
||||
Args:
|
||||
session: Database session.
|
||||
execute_scalars_func: Function to execute scalars query.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -62,18 +62,23 @@ class SqlAlchemyIntegrationRepository(
|
||||
self,
|
||||
provider: str,
|
||||
integration_type: str | None = None,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> Integration | None:
|
||||
"""Retrieve an integration by provider name.
|
||||
|
||||
Args:
|
||||
provider: Provider name (stored in config['provider'] or name).
|
||||
integration_type: Optional type filter.
|
||||
workspace_id: Optional workspace filter.
|
||||
|
||||
Returns:
|
||||
Integration if found, None otherwise.
|
||||
"""
|
||||
return await get_by_provider(
|
||||
self._session, self._execute_scalar, provider, integration_type
|
||||
self._execute_scalar,
|
||||
provider,
|
||||
integration_type,
|
||||
workspace_id,
|
||||
)
|
||||
|
||||
async def create(self, integration: Integration) -> Integration:
|
||||
@@ -181,7 +186,7 @@ class SqlAlchemyIntegrationRepository(
|
||||
Returns:
|
||||
List of integrations of the specified type.
|
||||
"""
|
||||
return await list_by_type(self._session, self._execute_scalars, integration_type)
|
||||
return await list_by_type(self._execute_scalars, integration_type)
|
||||
|
||||
async def list_all(self) -> Sequence[Integration]:
|
||||
"""List all integrations for the current workspace context.
|
||||
@@ -189,7 +194,7 @@ class SqlAlchemyIntegrationRepository(
|
||||
Returns:
|
||||
All integrations the user has access to.
|
||||
"""
|
||||
return await list_all(self._session, self._execute_scalars)
|
||||
return await list_all(self._execute_scalars)
|
||||
|
||||
# Sync run operations
|
||||
|
||||
|
||||
@@ -10,15 +10,25 @@ import pytest
|
||||
|
||||
from noteflow.application.services.calendar import CalendarService, CalendarServiceError
|
||||
from noteflow.config.settings import CalendarIntegrationSettings
|
||||
from noteflow.domain.constants.fields import (
|
||||
OAUTH_OVERRIDE_CLIENT_ID,
|
||||
OAUTH_OVERRIDE_CLIENT_SECRET,
|
||||
OAUTH_OVERRIDE_ENABLED,
|
||||
OAUTH_OVERRIDE_REDIRECT_URI,
|
||||
OAUTH_OVERRIDE_SCOPES,
|
||||
PROVIDER,
|
||||
)
|
||||
from noteflow.domain.entities import Integration, IntegrationStatus, IntegrationType
|
||||
from noteflow.domain.ports.calendar import CalendarEventInfo
|
||||
from noteflow.domain.value_objects import OAuthTokens
|
||||
from noteflow.domain.value_objects import OAuthClientConfig, OAuthTokens
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_calendar_oauth_manager() -> MagicMock:
|
||||
"""Create mock OAuth manager."""
|
||||
manager = MagicMock()
|
||||
manager.GOOGLE_SCOPES = ["calendar.read"]
|
||||
manager.OUTLOOK_SCOPES = ["Calendars.Read"]
|
||||
manager.initiate_auth.return_value = ("https://auth.example.com", "state-123")
|
||||
manager.complete_auth = AsyncMock(
|
||||
return_value=OAuthTokens(
|
||||
@@ -75,6 +85,7 @@ def mock_outlook_adapter() -> MagicMock:
|
||||
def calendar_mock_uow(mock_uow: MagicMock) -> MagicMock:
|
||||
"""Configure mock_uow with calendar service integrations behavior."""
|
||||
mock_uow.integrations.get_by_type_and_provider = AsyncMock(return_value=None)
|
||||
mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
|
||||
mock_uow.integrations.add = AsyncMock()
|
||||
mock_uow.integrations.get_secrets = AsyncMock(return_value=None)
|
||||
mock_uow.integrations.set_secrets = AsyncMock()
|
||||
@@ -335,3 +346,158 @@ class TestCalendarServiceListEvents:
|
||||
|
||||
with pytest.raises(CalendarServiceError, match="not connected"):
|
||||
await calendar_service.list_calendar_events(provider="google")
|
||||
|
||||
|
||||
class TestCalendarServiceOAuthOverrideConfig:
|
||||
"""CalendarService OAuth override config tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_oauth_client_config_defaults_when_missing(
|
||||
self,
|
||||
calendar_service: CalendarService,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_calendar_oauth_manager: MagicMock,
|
||||
) -> None:
|
||||
"""get_oauth_client_config should return defaults when no integration exists."""
|
||||
config, override_enabled, has_secret = await calendar_service.get_oauth_client_config(
|
||||
"google"
|
||||
)
|
||||
|
||||
assert config.client_id == "", "Client ID should default to empty string"
|
||||
assert (
|
||||
config.redirect_uri == calendar_settings.redirect_uri
|
||||
), "Redirect URI should use settings default"
|
||||
assert config.scopes == tuple(
|
||||
mock_calendar_oauth_manager.GOOGLE_SCOPES
|
||||
), "Scopes should fall back to default provider scopes"
|
||||
assert override_enabled is False, "Override should be disabled by default"
|
||||
assert has_secret is False, "Stored secret flag should be false by default"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_oauth_client_config_returns_override_values(
|
||||
self,
|
||||
calendar_service: CalendarService,
|
||||
calendar_mock_uow: MagicMock,
|
||||
) -> None:
|
||||
"""get_oauth_client_config should return stored override values."""
|
||||
integration = Integration.create(
|
||||
workspace_id=uuid4(),
|
||||
name="Google Calendar",
|
||||
integration_type=IntegrationType.CALENDAR,
|
||||
config={
|
||||
PROVIDER: "google",
|
||||
OAUTH_OVERRIDE_ENABLED: True,
|
||||
OAUTH_OVERRIDE_CLIENT_ID: "client-123",
|
||||
OAUTH_OVERRIDE_REDIRECT_URI: "http://localhost/callback",
|
||||
OAUTH_OVERRIDE_SCOPES: ["scope-a"],
|
||||
},
|
||||
)
|
||||
calendar_mock_uow.integrations.get_by_provider = AsyncMock(return_value=integration)
|
||||
calendar_mock_uow.integrations.get_secrets = AsyncMock(
|
||||
return_value={OAUTH_OVERRIDE_CLIENT_SECRET: "secret-xyz"}
|
||||
)
|
||||
|
||||
config, override_enabled, has_secret = await calendar_service.get_oauth_client_config(
|
||||
"google"
|
||||
)
|
||||
|
||||
assert config.client_id == "client-123", "Client ID should match stored override"
|
||||
assert (
|
||||
config.redirect_uri == "http://localhost/callback"
|
||||
), "Redirect URI should match stored override"
|
||||
assert config.scopes == ("scope-a",), "Scopes should match stored override"
|
||||
assert override_enabled is True, "Override should be enabled when stored"
|
||||
assert has_secret is True, "Stored secret flag should reflect stored secret"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_oauth_client_config_persists_config(
|
||||
self,
|
||||
calendar_service: CalendarService,
|
||||
calendar_mock_uow: MagicMock,
|
||||
) -> None:
|
||||
"""set_oauth_client_config should persist config."""
|
||||
calendar_mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
|
||||
calendar_mock_uow.integrations.create = AsyncMock()
|
||||
calendar_mock_uow.integrations.update = AsyncMock()
|
||||
calendar_mock_uow.integrations.get_secrets = AsyncMock(return_value={})
|
||||
client_config = OAuthClientConfig(
|
||||
client_id="client-456",
|
||||
client_secret="secret-abc",
|
||||
redirect_uri="http://localhost/custom",
|
||||
scopes=("scope-b",),
|
||||
)
|
||||
await calendar_service.set_oauth_client_config(
|
||||
provider="google",
|
||||
client_config=client_config,
|
||||
override_enabled=True,
|
||||
)
|
||||
calendar_mock_uow.integrations.update.assert_called_once()
|
||||
update_call = calendar_mock_uow.integrations.update.call_args[0][0]
|
||||
assert update_call.config[OAUTH_OVERRIDE_ENABLED] is True, (
|
||||
"Override enabled flag should be stored"
|
||||
)
|
||||
assert update_call.config[OAUTH_OVERRIDE_CLIENT_ID] == "client-456", (
|
||||
"Client ID should be stored in config"
|
||||
)
|
||||
assert update_call.config[OAUTH_OVERRIDE_REDIRECT_URI] == "http://localhost/custom", (
|
||||
"Redirect URI should be stored in config"
|
||||
)
|
||||
assert update_call.config[OAUTH_OVERRIDE_SCOPES] == ["scope-b"], (
|
||||
"Scopes should be stored in config"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_oauth_client_config_persists_secret(
|
||||
self,
|
||||
calendar_service: CalendarService,
|
||||
calendar_mock_uow: MagicMock,
|
||||
) -> None:
|
||||
"""set_oauth_client_config should persist secret."""
|
||||
calendar_mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
|
||||
calendar_mock_uow.integrations.create = AsyncMock()
|
||||
calendar_mock_uow.integrations.update = AsyncMock()
|
||||
calendar_mock_uow.integrations.get_secrets = AsyncMock(return_value={})
|
||||
|
||||
client_config = OAuthClientConfig(
|
||||
client_id="client-456",
|
||||
client_secret="secret-abc",
|
||||
redirect_uri="http://localhost/custom",
|
||||
scopes=("scope-b",),
|
||||
)
|
||||
await calendar_service.set_oauth_client_config(
|
||||
provider="google",
|
||||
client_config=client_config,
|
||||
override_enabled=True,
|
||||
)
|
||||
|
||||
secrets_call = calendar_mock_uow.integrations.set_secrets.call_args.kwargs
|
||||
assert (
|
||||
secrets_call["secrets"][OAUTH_OVERRIDE_CLIENT_SECRET] == "secret-abc"
|
||||
), "Client secret should be persisted in secrets store"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_oauth_client_config_requires_secret_when_enabled(
|
||||
self,
|
||||
calendar_service: CalendarService,
|
||||
calendar_mock_uow: MagicMock,
|
||||
) -> None:
|
||||
"""set_oauth_client_config should require secret when override enabled."""
|
||||
calendar_mock_uow.integrations.get_by_provider = AsyncMock(return_value=None)
|
||||
calendar_mock_uow.integrations.create = AsyncMock()
|
||||
calendar_mock_uow.integrations.update = AsyncMock()
|
||||
calendar_mock_uow.integrations.get_secrets = AsyncMock(return_value={})
|
||||
|
||||
with pytest.raises(
|
||||
CalendarServiceError,
|
||||
match="client secret is missing",
|
||||
):
|
||||
await calendar_service.set_oauth_client_config(
|
||||
provider="google",
|
||||
client_config=OAuthClientConfig(
|
||||
client_id="client-789",
|
||||
client_secret="",
|
||||
redirect_uri="http://localhost/override",
|
||||
scopes=("scope-c",),
|
||||
),
|
||||
override_enabled=True,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ import pytest
|
||||
|
||||
from noteflow.application.services.calendar import CalendarServiceError
|
||||
from noteflow.domain.entities.integration import IntegrationStatus
|
||||
from noteflow.domain.identity import DEFAULT_WORKSPACE_ID
|
||||
from noteflow.grpc.config.config import ServicesConfig
|
||||
from noteflow.grpc.proto import noteflow_pb2
|
||||
from noteflow.grpc.service import NoteFlowServicer
|
||||
@@ -219,12 +220,17 @@ def _create_mockcalendar_service(
|
||||
|
||||
service = MagicMock()
|
||||
|
||||
async def get_connection_status(provider: str) -> OAuthConnectionInfo:
|
||||
async def get_connection_status(
|
||||
provider: str,
|
||||
workspace_id: UUID | None = None,
|
||||
) -> OAuthConnectionInfo:
|
||||
is_connected = providers_connected.get(provider, False)
|
||||
email = provider_emails.get(provider)
|
||||
return _create_mock_connection_info(
|
||||
provider=provider,
|
||||
status=IntegrationStatus.CONNECTED.value if is_connected else IntegrationStatus.DISCONNECTED.value,
|
||||
status=IntegrationStatus.CONNECTED.value
|
||||
if is_connected
|
||||
else IntegrationStatus.DISCONNECTED.value,
|
||||
email=email,
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1) if is_connected else None,
|
||||
)
|
||||
@@ -232,7 +238,7 @@ def _create_mockcalendar_service(
|
||||
service.get_connection_status = AsyncMock(side_effect=get_connection_status)
|
||||
service.initiate_oauth = AsyncMock()
|
||||
service.complete_oauth = AsyncMock()
|
||||
service.disconnect = AsyncMock()
|
||||
service.disconnect = AsyncMock(return_value=True)
|
||||
|
||||
return service
|
||||
|
||||
@@ -293,7 +299,9 @@ class TestGetCalendarProviders:
|
||||
outlook = next(p for p in response.providers if p.name == "outlook")
|
||||
|
||||
assert google.display_name == "Google Calendar", "google should have correct display name"
|
||||
assert outlook.display_name == "Microsoft Outlook", "outlook should have correct display name"
|
||||
assert outlook.display_name == "Microsoft Outlook", (
|
||||
"outlook should have correct display name"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aborts_whencalendar_service_not_configured(self) -> None:
|
||||
@@ -349,6 +357,7 @@ class TestInitiateOAuth:
|
||||
service.initiate_oauth.assert_awaited_once_with(
|
||||
provider="outlook",
|
||||
redirect_uri=None,
|
||||
workspace_id=DEFAULT_WORKSPACE_ID,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -370,6 +379,7 @@ class TestInitiateOAuth:
|
||||
service.initiate_oauth.assert_awaited_once_with(
|
||||
provider="google",
|
||||
redirect_uri="noteflow://oauth/callback",
|
||||
workspace_id=DEFAULT_WORKSPACE_ID,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -454,15 +464,14 @@ class TestCompleteOAuth:
|
||||
provider="google",
|
||||
code="my-auth-code",
|
||||
state="my-state-token",
|
||||
workspace_id=DEFAULT_WORKSPACE_ID,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_error_on_invalid_state(self) -> None:
|
||||
"""Returns success=False with error message for invalid state."""
|
||||
service = _create_mockcalendar_service()
|
||||
service.complete_oauth.side_effect = CalendarServiceError(
|
||||
"Invalid or expired state token"
|
||||
)
|
||||
service.complete_oauth.side_effect = CalendarServiceError("Invalid or expired state token")
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await _call_complete_oauth(
|
||||
@@ -476,7 +485,9 @@ class TestCompleteOAuth:
|
||||
)
|
||||
|
||||
assert response.success is False, "should fail on invalid state"
|
||||
assert "Invalid or expired state" in response.error_message, "error should mention invalid state"
|
||||
assert "Invalid or expired state" in response.error_message, (
|
||||
"error should mention invalid state"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_error_on_invalid_code(self) -> None:
|
||||
@@ -498,7 +509,9 @@ class TestCompleteOAuth:
|
||||
)
|
||||
|
||||
assert response.success is False, "should fail on invalid code"
|
||||
assert "Token exchange failed" in response.error_message, "error should mention token exchange failure"
|
||||
assert "Token exchange failed" in response.error_message, (
|
||||
"error should mention token exchange failure"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aborts_when_complete_service_unavailable(self) -> None:
|
||||
@@ -539,15 +552,15 @@ class TestGetOAuthConnectionStatus:
|
||||
)
|
||||
|
||||
assert response.connection.provider == "google", "should return correct provider"
|
||||
assert response.connection.status == IntegrationStatus.CONNECTED.value, "status should be connected"
|
||||
assert response.connection.status == IntegrationStatus.CONNECTED.value, (
|
||||
"status should be connected"
|
||||
)
|
||||
assert response.connection.email == "user@gmail.com", "should return connected email"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_disconnected_status(self) -> None:
|
||||
"""Returns disconnected status when provider not connected."""
|
||||
service = _create_mockcalendar_service(
|
||||
providers_connected={"google": False}
|
||||
)
|
||||
service = _create_mockcalendar_service(providers_connected={"google": False})
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await _call_get_oauth_status(
|
||||
@@ -556,7 +569,9 @@ class TestGetOAuthConnectionStatus:
|
||||
_DummyContext(),
|
||||
)
|
||||
|
||||
assert response.connection.status == IntegrationStatus.DISCONNECTED.value, "status should be disconnected"
|
||||
assert response.connection.status == IntegrationStatus.DISCONNECTED.value, (
|
||||
"status should be disconnected"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_integration_type(self) -> None:
|
||||
@@ -573,7 +588,9 @@ class TestGetOAuthConnectionStatus:
|
||||
_DummyContext(),
|
||||
)
|
||||
|
||||
assert response.connection.integration_type == "calendar", "should return calendar integration type"
|
||||
assert response.connection.integration_type == "calendar", (
|
||||
"should return calendar integration type"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aborts_when_status_service_unavailable(self) -> None:
|
||||
@@ -597,9 +614,7 @@ class TestDisconnectOAuth:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_success_on_disconnect(self) -> None:
|
||||
"""Returns success=True when disconnection succeeds."""
|
||||
service = _create_mockcalendar_service(
|
||||
providers_connected={"google": True}
|
||||
)
|
||||
service = _create_mockcalendar_service(providers_connected={"google": True})
|
||||
service.disconnect.return_value = True
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
@@ -624,14 +639,12 @@ class TestDisconnectOAuth:
|
||||
_DummyContext(),
|
||||
)
|
||||
|
||||
service.disconnect.assert_awaited_once_with("outlook")
|
||||
service.disconnect.assert_awaited_once_with("outlook", workspace_id=DEFAULT_WORKSPACE_ID)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_not_connected(self) -> None:
|
||||
"""Returns success=False when provider was not connected."""
|
||||
service = _create_mockcalendar_service(
|
||||
providers_connected={"google": False}
|
||||
)
|
||||
service = _create_mockcalendar_service(providers_connected={"google": False})
|
||||
service.disconnect.return_value = False
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
@@ -674,7 +687,9 @@ class TestOAuthRoundTrip:
|
||||
"state-123",
|
||||
)
|
||||
|
||||
async def complete_oauth(provider: str, code: str, state: str) -> UUID:
|
||||
async def complete_oauth(
|
||||
provider: str, code: str, state: str, workspace_id: UUID | None = None
|
||||
) -> UUID:
|
||||
if state != "state-123":
|
||||
raise CalendarServiceError("Invalid state")
|
||||
connected_state[provider] = True
|
||||
@@ -683,17 +698,21 @@ class TestOAuthRoundTrip:
|
||||
|
||||
service.complete_oauth.side_effect = complete_oauth
|
||||
|
||||
async def get_status(provider: str) -> OAuthConnectionInfo:
|
||||
async def get_status(
|
||||
provider: str, workspace_id: UUID | None = None
|
||||
) -> OAuthConnectionInfo:
|
||||
is_connected = connected_state.get(provider, False)
|
||||
return _create_mock_connection_info(
|
||||
provider=provider,
|
||||
status=IntegrationStatus.CONNECTED.value if is_connected else IntegrationStatus.DISCONNECTED.value,
|
||||
status=IntegrationStatus.CONNECTED.value
|
||||
if is_connected
|
||||
else IntegrationStatus.DISCONNECTED.value,
|
||||
email=email_state.get(provider),
|
||||
)
|
||||
|
||||
service.get_connection_status.side_effect = get_status
|
||||
|
||||
async def disconnect(provider: str) -> bool:
|
||||
async def disconnect(provider: str, workspace_id: UUID | None = None) -> bool:
|
||||
if connected_state.get(provider, False):
|
||||
connected_state[provider] = False
|
||||
email_state.pop(provider, None)
|
||||
@@ -770,9 +789,7 @@ class TestOAuthRoundTrip:
|
||||
"""Completing OAuth with wrong state token fails gracefully."""
|
||||
service = _create_mockcalendar_service()
|
||||
service.initiate_oauth.return_value = ("https://auth.url", "correct-state")
|
||||
service.complete_oauth.side_effect = CalendarServiceError(
|
||||
"Invalid or expired state token"
|
||||
)
|
||||
service.complete_oauth.side_effect = CalendarServiceError("Invalid or expired state token")
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await _call_complete_oauth(
|
||||
@@ -786,7 +803,9 @@ class TestOAuthRoundTrip:
|
||||
)
|
||||
|
||||
assert response.success is False, "should fail with wrong state"
|
||||
assert "Invalid or expired state" in response.error_message, "error should mention invalid state"
|
||||
assert "Invalid or expired state" in response.error_message, (
|
||||
"error should mention invalid state"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_providers_independent(self) -> None:
|
||||
@@ -809,8 +828,12 @@ class TestOAuthRoundTrip:
|
||||
ctx,
|
||||
)
|
||||
|
||||
assert google_status.connection.status == IntegrationStatus.CONNECTED.value, "google should be connected"
|
||||
assert outlook_status.connection.status == IntegrationStatus.DISCONNECTED.value, "outlook should be disconnected"
|
||||
assert google_status.connection.status == IntegrationStatus.CONNECTED.value, (
|
||||
"google should be connected"
|
||||
)
|
||||
assert outlook_status.connection.status == IntegrationStatus.DISCONNECTED.value, (
|
||||
"outlook should be disconnected"
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthSecurityBehavior:
|
||||
@@ -839,9 +862,7 @@ class TestOAuthSecurityBehavior:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokens_revoked_on_disconnect(self) -> None:
|
||||
"""Disconnect should call service to revoke tokens."""
|
||||
service = _create_mockcalendar_service(
|
||||
providers_connected={"google": True}
|
||||
)
|
||||
service = _create_mockcalendar_service(providers_connected={"google": True})
|
||||
service.disconnect.return_value = True
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
@@ -851,7 +872,7 @@ class TestOAuthSecurityBehavior:
|
||||
_DummyContext(),
|
||||
)
|
||||
|
||||
service.disconnect.assert_awaited_once_with("google")
|
||||
service.disconnect.assert_awaited_once_with("google", workspace_id=DEFAULT_WORKSPACE_ID)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_sensitive_data_in_error_responses(self) -> None:
|
||||
|
||||
214
tests/infrastructure/asr/test_factory.py
Normal file
214
tests/infrastructure/asr/test_factory.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Tests for ASR engine factory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.domain.ports.gpu import GpuBackend
|
||||
from noteflow.infrastructure.asr.factory import (
|
||||
EngineCreationError,
|
||||
create_asr_engine,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class TestCreateAsrEngine:
|
||||
"""Test ASR engine factory."""
|
||||
|
||||
def test_cpu_engine_creation(self) -> None:
|
||||
"""Test CPU engine is created for cpu device."""
|
||||
with patch(
|
||||
"noteflow.infrastructure.asr.factory.resolve_device",
|
||||
return_value="cpu",
|
||||
):
|
||||
engine = create_asr_engine(device="cpu", compute_type="int8")
|
||||
assert engine.device == "cpu"
|
||||
|
||||
def test_cpu_forces_float32_for_float16(self) -> None:
|
||||
"""Test CPU engine converts float16 to float32."""
|
||||
with patch(
|
||||
"noteflow.infrastructure.asr.factory.resolve_device",
|
||||
return_value="cpu",
|
||||
):
|
||||
engine = create_asr_engine(device="cpu", compute_type="float16")
|
||||
# CPU doesn't support float16, should use float32
|
||||
assert engine.compute_type in ("float32", "float16")
|
||||
|
||||
def test_auto_device_resolution(self) -> None:
|
||||
"""Test auto device resolution."""
|
||||
# Mock GPU detection to return CUDA
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.detect_gpu_backend",
|
||||
return_value=GpuBackend.CUDA,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory._create_cuda_engine",
|
||||
) as mock_cuda,
|
||||
):
|
||||
mock_cuda.return_value = MagicMock()
|
||||
create_asr_engine(device="auto", compute_type="float16")
|
||||
mock_cuda.assert_called_once()
|
||||
|
||||
def test_unsupported_device_raises(self) -> None:
|
||||
"""Test unsupported device raises EngineCreationError."""
|
||||
with patch(
|
||||
"noteflow.infrastructure.asr.factory.resolve_device",
|
||||
return_value="invalid_device",
|
||||
):
|
||||
with pytest.raises(EngineCreationError, match="Unsupported device"):
|
||||
create_asr_engine(device="invalid_device")
|
||||
|
||||
|
||||
class TestDeviceResolution:
|
||||
"""Test device resolution logic."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
["cpu", "cuda", "rocm"],
|
||||
)
|
||||
def test_explicit_device_not_changed(self, device: str) -> None:
|
||||
"""Test explicit device string is not changed."""
|
||||
from noteflow.infrastructure.asr.factory import resolve_device
|
||||
|
||||
assert resolve_device(device) == device, f"Device {device} should remain unchanged"
|
||||
|
||||
def test_auto_with_cuda(self) -> None:
|
||||
"""Test auto resolves to cuda when CUDA is available."""
|
||||
from noteflow.infrastructure.asr.factory import resolve_device
|
||||
|
||||
with patch(
|
||||
"noteflow.infrastructure.asr.factory.detect_gpu_backend",
|
||||
return_value=GpuBackend.CUDA,
|
||||
):
|
||||
assert resolve_device("auto") == "cuda"
|
||||
|
||||
def test_auto_with_rocm_supported(self) -> None:
|
||||
"""Test auto resolves to rocm when ROCm is available and supported."""
|
||||
from noteflow.infrastructure.asr.factory import resolve_device
|
||||
|
||||
mock_gpu_info = MagicMock()
|
||||
mock_gpu_info.architecture = "gfx1100"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.detect_gpu_backend",
|
||||
return_value=GpuBackend.ROCM,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.get_gpu_info",
|
||||
return_value=mock_gpu_info,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.is_rocm_architecture_supported",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
assert resolve_device("auto") == "rocm"
|
||||
|
||||
def test_auto_with_rocm_unsupported_falls_to_cpu(self) -> None:
|
||||
"""Test auto falls back to CPU when ROCm arch is unsupported."""
|
||||
from noteflow.infrastructure.asr.factory import resolve_device
|
||||
|
||||
mock_gpu_info = MagicMock()
|
||||
mock_gpu_info.architecture = "gfx803"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.detect_gpu_backend",
|
||||
return_value=GpuBackend.ROCM,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.get_gpu_info",
|
||||
return_value=mock_gpu_info,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.is_rocm_architecture_supported",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
assert resolve_device("auto") == "cpu"
|
||||
|
||||
def test_auto_with_mps_falls_to_cpu(self) -> None:
|
||||
"""Test auto falls back to CPU for MPS (not supported for ASR)."""
|
||||
from noteflow.infrastructure.asr.factory import resolve_device
|
||||
|
||||
with patch(
|
||||
"noteflow.infrastructure.asr.factory.detect_gpu_backend",
|
||||
return_value=GpuBackend.MPS,
|
||||
):
|
||||
assert resolve_device("auto") == "cpu"
|
||||
|
||||
|
||||
class TestRocmEngineCreation:
|
||||
"""Test ROCm engine creation."""
|
||||
|
||||
def test_rocm_with_ctranslate2_available(self) -> None:
|
||||
"""Test ROCm uses CTranslate2 when available."""
|
||||
mock_engine = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.resolve_device",
|
||||
return_value="rocm",
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.is_ctranslate2_rocm_available",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.rocm_engine.FasterWhisperRocmEngine",
|
||||
return_value=mock_engine,
|
||||
),
|
||||
):
|
||||
engine = create_asr_engine(device="rocm", compute_type="float16")
|
||||
assert engine == mock_engine
|
||||
|
||||
def test_rocm_falls_back_to_pytorch(self) -> None:
|
||||
"""Test ROCm falls back to PyTorch Whisper when CTranslate2 unavailable."""
|
||||
mock_engine = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.resolve_device",
|
||||
return_value="rocm",
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.is_ctranslate2_rocm_available",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.pytorch_engine.WhisperPyTorchEngine",
|
||||
return_value=mock_engine,
|
||||
),
|
||||
):
|
||||
engine = create_asr_engine(device="rocm", compute_type="float16")
|
||||
assert engine == mock_engine
|
||||
|
||||
|
||||
class TestPytorchEngineFallback:
|
||||
"""Test PyTorch engine fallback."""
|
||||
|
||||
def test_pytorch_engine_import_error(self) -> None:
|
||||
"""Test import error raises EngineCreationError."""
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.resolve_device",
|
||||
return_value="rocm",
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.factory.is_ctranslate2_rocm_available",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"noteflow.infrastructure.asr.pytorch_engine.WhisperPyTorchEngine",
|
||||
side_effect=ImportError("No module"),
|
||||
),
|
||||
pytest.raises(EngineCreationError, match="Neither CTranslate2"),
|
||||
):
|
||||
create_asr_engine(device="rocm", compute_type="float16")
|
||||
1
tests/infrastructure/gpu/__init__.py
Normal file
1
tests/infrastructure/gpu/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""GPU detection tests."""
|
||||
345
tests/infrastructure/gpu/test_detection.py
Normal file
345
tests/infrastructure/gpu/test_detection.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Tests for GPU backend detection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.domain.ports.gpu import GpuBackend, GpuInfo
|
||||
from noteflow.infrastructure.gpu.detection import (
|
||||
SUPPORTED_AMD_ARCHITECTURES,
|
||||
GpuDetectionError,
|
||||
detect_gpu_backend,
|
||||
get_gpu_info,
|
||||
get_rocm_environment_info,
|
||||
is_ctranslate2_rocm_available,
|
||||
is_rocm_architecture_supported,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# Test constants
|
||||
VRAM_24GB_BYTES: Final[int] = 24 * 1024 * 1024 * 1024
|
||||
VRAM_24GB_MB: Final[int] = 24 * 1024
|
||||
VRAM_RX7900_MB: Final[int] = 24576
|
||||
|
||||
|
||||
class TestDetectGpuBackend:
|
||||
"""Test GPU backend detection."""
|
||||
|
||||
def test_no_pytorch_returns_none(self) -> None:
|
||||
"""Test that missing PyTorch returns NONE backend."""
|
||||
# Clear cache first
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
# Create a mock that only raises ImportError for 'torch'
|
||||
original_import = __builtins__["__import__"]
|
||||
|
||||
def mock_import(name: str, *args: object, **kwargs: object) -> object:
|
||||
if name == "torch":
|
||||
raise ImportError("No module named 'torch'")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch.dict("sys.modules", {"torch": None}):
|
||||
with patch("builtins.__import__", side_effect=mock_import):
|
||||
result = detect_gpu_backend()
|
||||
assert result == GpuBackend.NONE, "Missing PyTorch should return NONE"
|
||||
|
||||
# Clear cache after test
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
def test_cuda_detected(self) -> None:
|
||||
"""Test CUDA backend detection."""
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.cuda.is_available.return_value = True
|
||||
mock_torch.version.hip = None
|
||||
mock_torch.version.cuda = "12.1"
|
||||
|
||||
with patch.dict("sys.modules", {"torch": mock_torch}):
|
||||
# Need to reimport to use mocked torch
|
||||
from noteflow.infrastructure.gpu import detection
|
||||
|
||||
# Clear the function's cache
|
||||
detection.detect_gpu_backend.cache_clear()
|
||||
|
||||
result = detection.detect_gpu_backend()
|
||||
assert result == GpuBackend.CUDA, "CUDA should be detected when available"
|
||||
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
def test_rocm_detected(self) -> None:
|
||||
"""Test ROCm backend detection via HIP."""
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.cuda.is_available.return_value = True
|
||||
mock_torch.version.hip = "6.0"
|
||||
mock_torch.version.cuda = None
|
||||
|
||||
with patch.dict("sys.modules", {"torch": mock_torch}):
|
||||
from noteflow.infrastructure.gpu import detection
|
||||
|
||||
detection.detect_gpu_backend.cache_clear()
|
||||
|
||||
result = detection.detect_gpu_backend()
|
||||
assert result == GpuBackend.ROCM, "ROCm should be detected when HIP available"
|
||||
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
def test_mps_detected(self) -> None:
|
||||
"""Test MPS backend detection."""
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.cuda.is_available.return_value = False
|
||||
mock_torch.backends.mps.is_available.return_value = True
|
||||
|
||||
with patch.dict("sys.modules", {"torch": mock_torch}):
|
||||
from noteflow.infrastructure.gpu import detection
|
||||
|
||||
detection.detect_gpu_backend.cache_clear()
|
||||
|
||||
result = detection.detect_gpu_backend()
|
||||
assert result == GpuBackend.MPS, "MPS should be detected on Apple Silicon"
|
||||
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
|
||||
class TestSupportedArchitectures:
|
||||
"""Test supported AMD architecture list."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"architecture",
|
||||
["gfx906", "gfx908", "gfx90a", "gfx942"],
|
||||
ids=["MI50", "MI100", "MI210", "MI300X"],
|
||||
)
|
||||
def test_cdna_architectures_included(self, architecture: str) -> None:
|
||||
"""Test that CDNA datacenter architectures are supported."""
|
||||
assert architecture in SUPPORTED_AMD_ARCHITECTURES, f"{architecture} should be supported"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"architecture",
|
||||
["gfx1030", "gfx1031", "gfx1032"],
|
||||
ids=["RX6800", "RX6700XT", "RX6600"],
|
||||
)
|
||||
def test_rdna2_architectures_included(self, architecture: str) -> None:
|
||||
"""Test that RDNA 2 consumer architectures are supported."""
|
||||
assert architecture in SUPPORTED_AMD_ARCHITECTURES, f"{architecture} should be supported"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"architecture",
|
||||
["gfx1100", "gfx1101", "gfx1102"],
|
||||
ids=["RX7900XTX", "RX7800XT", "RX7600"],
|
||||
)
|
||||
def test_rdna3_architectures_included(self, architecture: str) -> None:
|
||||
"""Test that RDNA 3 consumer architectures are supported."""
|
||||
assert architecture in SUPPORTED_AMD_ARCHITECTURES, f"{architecture} should be supported"
|
||||
|
||||
|
||||
class TestIsRocmArchitectureSupported:
|
||||
"""Test ROCm architecture support checking."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"architecture",
|
||||
["gfx1100", "gfx1030", "gfx90a", "gfx942"],
|
||||
)
|
||||
def test_supported_architectures(self, architecture: str) -> None:
|
||||
"""Test officially supported architectures."""
|
||||
assert is_rocm_architecture_supported(architecture) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"architecture",
|
||||
["gfx803", "gfx900", "gfx1010", "unknown"],
|
||||
)
|
||||
def test_unsupported_architectures(self, architecture: str) -> None:
|
||||
"""Test unsupported architectures."""
|
||||
assert is_rocm_architecture_supported(architecture) is False
|
||||
|
||||
def test_none_architecture(self) -> None:
|
||||
"""Test None architecture returns False."""
|
||||
assert is_rocm_architecture_supported(None) is False
|
||||
|
||||
def test_override_env_var(self) -> None:
|
||||
"""Test HSA_OVERRIDE_GFX_VERSION allows any architecture."""
|
||||
with patch.dict("os.environ", {"HSA_OVERRIDE_GFX_VERSION": "10.3.0"}):
|
||||
# Even unsupported architecture should work
|
||||
assert is_rocm_architecture_supported("gfx803") is True
|
||||
|
||||
|
||||
class TestGetGpuInfo:
|
||||
"""Test GPU info retrieval."""
|
||||
|
||||
def test_no_gpu_returns_none(self) -> None:
|
||||
"""Test no GPU returns None."""
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
with patch(
|
||||
"noteflow.infrastructure.gpu.detection.detect_gpu_backend",
|
||||
return_value=GpuBackend.NONE,
|
||||
):
|
||||
result = get_gpu_info()
|
||||
assert result is None, "No GPU should return None"
|
||||
|
||||
def test_cuda_gpu_info(self) -> None:
|
||||
"""Test CUDA GPU info retrieval."""
|
||||
mock_props = MagicMock()
|
||||
mock_props.name = "NVIDIA GeForce RTX 4090"
|
||||
mock_props.total_memory = VRAM_24GB_BYTES
|
||||
mock_props.major = 8
|
||||
mock_props.minor = 9
|
||||
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.cuda.get_device_properties.return_value = mock_props
|
||||
mock_torch.version.cuda = "12.1"
|
||||
mock_torch.version.hip = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.gpu.detection.detect_gpu_backend",
|
||||
return_value=GpuBackend.CUDA,
|
||||
),
|
||||
patch.dict("sys.modules", {"torch": mock_torch}),
|
||||
):
|
||||
result = get_gpu_info()
|
||||
assert result is not None, "CUDA GPU info should not be None"
|
||||
assert result.backend == GpuBackend.CUDA, "Backend should be CUDA"
|
||||
assert result.device_name == "NVIDIA GeForce RTX 4090", "Device name mismatch"
|
||||
assert result.vram_total_mb == VRAM_24GB_MB, "VRAM should be 24GB in MB"
|
||||
assert result.architecture == "sm_89", "Architecture should be sm_89"
|
||||
|
||||
def test_mps_gpu_info(self) -> None:
|
||||
"""Test MPS GPU info retrieval."""
|
||||
with patch(
|
||||
"noteflow.infrastructure.gpu.detection.detect_gpu_backend",
|
||||
return_value=GpuBackend.MPS,
|
||||
):
|
||||
result = get_gpu_info()
|
||||
assert result is not None, "MPS GPU info should not be None"
|
||||
assert result.backend == GpuBackend.MPS, "Backend should be MPS"
|
||||
assert result.device_name == "Apple Metal", "Device should be Apple Metal"
|
||||
# MPS doesn't expose VRAM
|
||||
assert result.vram_total_mb == 0, "MPS doesn't expose VRAM"
|
||||
|
||||
def test_gpu_properties_error_raises(self) -> None:
|
||||
"""Test GPU properties retrieval error raises GpuDetectionError."""
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.cuda.get_device_properties.side_effect = RuntimeError("Device not found")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.gpu.detection.detect_gpu_backend",
|
||||
return_value=GpuBackend.CUDA,
|
||||
),
|
||||
patch.dict("sys.modules", {"torch": mock_torch}),
|
||||
pytest.raises(GpuDetectionError, match="Failed to get GPU properties"),
|
||||
):
|
||||
get_gpu_info()
|
||||
|
||||
|
||||
class TestIsCtranslate2RocmAvailable:
|
||||
"""Test CTranslate2-ROCm availability checking."""
|
||||
|
||||
def test_not_rocm_returns_false(self) -> None:
|
||||
"""Test non-ROCm backend returns False."""
|
||||
detect_gpu_backend.cache_clear()
|
||||
|
||||
with patch(
|
||||
"noteflow.infrastructure.gpu.detection.detect_gpu_backend",
|
||||
return_value=GpuBackend.CUDA,
|
||||
):
|
||||
assert is_ctranslate2_rocm_available() is False
|
||||
|
||||
def test_no_ctranslate2_returns_false(self) -> None:
|
||||
"""Test missing CTranslate2 returns False."""
|
||||
with (
|
||||
patch(
|
||||
"noteflow.infrastructure.gpu.detection.detect_gpu_backend",
|
||||
return_value=GpuBackend.ROCM,
|
||||
),
|
||||
patch("builtins.__import__", side_effect=ImportError),
|
||||
):
|
||||
assert is_ctranslate2_rocm_available() is False
|
||||
|
||||
|
||||
class TestGetRocmEnvironmentInfo:
|
||||
"""Test ROCm environment info retrieval."""
|
||||
|
||||
def test_empty_env(self) -> None:
|
||||
"""Test empty environment returns empty dict."""
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
result = get_rocm_environment_info()
|
||||
assert result == {}, "Empty env should return empty dict"
|
||||
|
||||
def test_rocm_vars_captured(self) -> None:
|
||||
"""Test ROCm environment variables are captured."""
|
||||
env_vars = {
|
||||
"HSA_OVERRIDE_GFX_VERSION": "10.3.0",
|
||||
"HIP_VISIBLE_DEVICES": "0,1",
|
||||
"ROCM_PATH": "/opt/rocm",
|
||||
}
|
||||
with patch.dict("os.environ", env_vars, clear=True):
|
||||
result = get_rocm_environment_info()
|
||||
assert result == env_vars, "ROCm env vars should be captured"
|
||||
|
||||
|
||||
class TestGpuBackendEnum:
|
||||
"""Test GpuBackend enum."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("backend", "expected_value"),
|
||||
[
|
||||
(GpuBackend.NONE, "none"),
|
||||
(GpuBackend.CUDA, "cuda"),
|
||||
(GpuBackend.ROCM, "rocm"),
|
||||
(GpuBackend.MPS, "mps"),
|
||||
],
|
||||
)
|
||||
def test_enum_values(self, backend: GpuBackend, expected_value: str) -> None:
|
||||
"""Test GpuBackend enum has expected values."""
|
||||
assert backend.value == expected_value, f"{backend} should have value {expected_value}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("backend", "string_value"),
|
||||
[
|
||||
(GpuBackend.CUDA, "cuda"),
|
||||
(GpuBackend.ROCM, "rocm"),
|
||||
],
|
||||
)
|
||||
def test_string_comparison(self, backend: GpuBackend, string_value: str) -> None:
|
||||
"""Test GpuBackend can be compared as string."""
|
||||
assert backend == string_value, f"{backend} should equal {string_value}"
|
||||
|
||||
|
||||
class TestGpuInfo:
|
||||
"""Test GpuInfo dataclass."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test GpuInfo creation."""
|
||||
info = GpuInfo(
|
||||
backend=GpuBackend.ROCM,
|
||||
device_name="AMD Radeon RX 7900 XTX",
|
||||
vram_total_mb=VRAM_RX7900_MB,
|
||||
driver_version="6.0",
|
||||
architecture="gfx1100",
|
||||
)
|
||||
assert info.backend == GpuBackend.ROCM, "Backend should be ROCM"
|
||||
assert info.device_name == "AMD Radeon RX 7900 XTX", "Device name mismatch"
|
||||
assert info.vram_total_mb == VRAM_RX7900_MB, "VRAM mismatch"
|
||||
assert info.driver_version == "6.0", "Driver version mismatch"
|
||||
assert info.architecture == "gfx1100", "Architecture mismatch"
|
||||
|
||||
def test_frozen(self) -> None:
|
||||
"""Test GpuInfo is immutable."""
|
||||
info = GpuInfo(
|
||||
backend=GpuBackend.CUDA,
|
||||
device_name="GPU",
|
||||
vram_total_mb=1024,
|
||||
driver_version="12.0",
|
||||
)
|
||||
with pytest.raises(AttributeError, match="cannot assign"):
|
||||
info.device_name = "New Name" # type: ignore[misc]
|
||||
@@ -4,7 +4,9 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
@@ -20,6 +22,17 @@ from support.db_utils import (
|
||||
stop_container,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
# ============================================================================
|
||||
# Audio Fixture Constants
|
||||
# ============================================================================
|
||||
|
||||
SAMPLE_RATE: Final[int] = 16000
|
||||
MAX_AUDIO_SECONDS: Final[float] = 10.0
|
||||
MAX_AUDIO_SAMPLES: Final[int] = int(MAX_AUDIO_SECONDS * SAMPLE_RATE)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
|
||||
@@ -104,3 +117,44 @@ async def stopped_meeting_with_segments(
|
||||
await uow.segments.add(meeting.id, segment_1)
|
||||
await uow.commit()
|
||||
return meeting.id
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio Fixtures (for ASR integration tests)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_fixture_path() -> Path:
|
||||
"""Path to the test audio fixture.
|
||||
|
||||
Returns path to tests/fixtures/sample_discord.wav (16kHz mono PCM).
|
||||
Skips test if fixture file is not found.
|
||||
"""
|
||||
path = Path(__file__).parent.parent / "fixtures" / "sample_discord.wav"
|
||||
if not path.exists():
|
||||
pytest.skip(f"Test audio fixture not found: {path}")
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_samples(audio_fixture_path: Path) -> NDArray[np.float32]:
|
||||
"""Load audio samples from fixture file.
|
||||
|
||||
Returns first 10 seconds as float32 array normalized to [-1, 1].
|
||||
"""
|
||||
import wave
|
||||
|
||||
with wave.open(str(audio_fixture_path), "rb") as wav:
|
||||
assert wav.getsampwidth() == 2, "Expected 16-bit audio"
|
||||
assert wav.getnchannels() == 1, "Expected mono audio"
|
||||
assert wav.getframerate() == SAMPLE_RATE, f"Expected {SAMPLE_RATE}Hz"
|
||||
|
||||
# Read limited samples for faster testing
|
||||
n_frames = min(wav.getnframes(), MAX_AUDIO_SAMPLES)
|
||||
raw_data = wav.readframes(n_frames)
|
||||
|
||||
# Convert to float32 normalized
|
||||
samples = np.frombuffer(raw_data, dtype=np.int16).astype(np.float32)
|
||||
samples /= 32768.0 # Normalize int16 to [-1, 1]
|
||||
return samples
|
||||
|
||||
639
tests/integration/test_asr_pytorch_engine.py
Normal file
639
tests/integration/test_asr_pytorch_engine.py
Normal file
@@ -0,0 +1,639 @@
|
||||
"""Integration tests for WhisperPyTorchEngine.
|
||||
|
||||
These tests verify that the PyTorch-based Whisper engine can actually
|
||||
load models and transcribe audio. Unlike mock-based unit tests, these
|
||||
tests exercise the real transcription pipeline.
|
||||
|
||||
Requirements:
|
||||
- openai-whisper package installed
|
||||
- CPU-only (no GPU required)
|
||||
- Internet connection for first model download
|
||||
|
||||
Test audio fixture:
|
||||
Uses tests/fixtures/sample_discord.wav (16kHz mono PCM)
|
||||
Fixtures defined in conftest.py: audio_fixture_path, audio_samples
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from .conftest import MAX_AUDIO_SECONDS, SAMPLE_RATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
# ============================================================================
|
||||
# Test Constants
|
||||
# ============================================================================
|
||||
|
||||
MODEL_SIZE_TINY: Final[str] = "tiny"
|
||||
DEVICE_CPU: Final[str] = "cpu"
|
||||
COMPUTE_TYPE_FLOAT32: Final[str] = "float32"
|
||||
|
||||
|
||||
def _check_whisper_available() -> bool:
|
||||
"""Check if openai-whisper is available.
|
||||
|
||||
Note: There's a package conflict with graphite's 'whisper' database package.
|
||||
We check for 'load_model' attribute to verify it's the correct whisper.
|
||||
"""
|
||||
try:
|
||||
import whisper
|
||||
|
||||
# Verify it's OpenAI's whisper, not graphite's whisper database
|
||||
return hasattr(whisper, "load_model")
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
# Provide informative skip message
|
||||
_WHISPER_SKIP_REASON = (
|
||||
"openai-whisper not installed (note: 'whisper' package exists but is "
|
||||
"graphite's database, not OpenAI's speech recognition)"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests - Core Functionality
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
|
||||
class TestWhisperPyTorchEngineIntegration:
|
||||
"""Integration tests for WhisperPyTorchEngine with real model loading."""
|
||||
|
||||
def test_engine_creation(self) -> None:
|
||||
"""Test engine can be created with CPU device."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(
|
||||
device=DEVICE_CPU,
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
|
||||
assert engine.device == DEVICE_CPU, "Expected CPU device"
|
||||
assert engine.compute_type == COMPUTE_TYPE_FLOAT32, "Expected float32 compute type"
|
||||
assert engine.model_size is None, "Expected model size to be unset before load_model"
|
||||
assert engine.is_loaded is False, "Expected engine to be unloaded initially"
|
||||
|
||||
def test_model_loading(self) -> None:
|
||||
"""Test tiny model can be loaded on CPU."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(
|
||||
device=DEVICE_CPU,
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
|
||||
# Load model with size
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
assert engine.is_loaded is True, "Expected model to be loaded"
|
||||
assert engine.model_size == MODEL_SIZE_TINY, "Expected model size to match"
|
||||
|
||||
# Unload model
|
||||
engine.unload()
|
||||
assert engine.is_loaded is False, "Expected engine to be unloaded"
|
||||
|
||||
def test_transcription_produces_text(
|
||||
self,
|
||||
audio_samples: NDArray[np.float32],
|
||||
) -> None:
|
||||
"""Test transcription produces non-empty text from real audio."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
results = list(engine.transcribe(audio_samples))
|
||||
assert len(results) > 0, "Expected at least one transcription segment"
|
||||
|
||||
first_result = results[0]
|
||||
assert hasattr(first_result, "text"), "Expected text attribute on result"
|
||||
assert hasattr(first_result, "start"), "Expected start attribute on result"
|
||||
assert hasattr(first_result, "end"), "Expected end attribute on result"
|
||||
assert hasattr(first_result, "language"), "Expected language attribute on result"
|
||||
assert (
|
||||
len(first_result.text.strip()) > 0
|
||||
), "Expected non-empty transcription text"
|
||||
assert first_result.start >= 0.0, "Expected non-negative start time"
|
||||
assert first_result.end > first_result.start, "Expected end > start"
|
||||
assert (
|
||||
first_result.end <= MAX_AUDIO_SECONDS + 1.0
|
||||
), "Expected end time within audio duration buffer"
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_transcription_with_word_timings(
|
||||
self,
|
||||
audio_samples: NDArray[np.float32],
|
||||
) -> None:
|
||||
"""Test transcription produces word-level timings."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(
|
||||
device=DEVICE_CPU,
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
results = list(engine.transcribe(audio_samples))
|
||||
assert len(results) > 0
|
||||
|
||||
# Get first result with word timings
|
||||
first_result = results[0]
|
||||
assert hasattr(first_result, "words"), "Expected words attribute in result"
|
||||
assert len(first_result.words) > 0, "Expected word-level timings in first result"
|
||||
|
||||
# Verify first word timing structure
|
||||
first_word = first_result.words[0]
|
||||
assert hasattr(first_word, "word"), "Expected word attribute"
|
||||
assert hasattr(first_word, "start"), "Expected start attribute"
|
||||
assert hasattr(first_word, "end"), "Expected end attribute"
|
||||
assert first_word.end >= first_word.start, "Expected end >= start"
|
||||
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_transcribe_file_helper(
|
||||
self,
|
||||
audio_fixture_path: Path,
|
||||
) -> None:
|
||||
"""Test transcribe_file helper method works."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(
|
||||
device=DEVICE_CPU,
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
# Use transcribe_file helper
|
||||
results = list(engine.transcribe_file(audio_fixture_path))
|
||||
|
||||
# Verify we got results
|
||||
assert len(results) > 0, "Expected transcription results from file"
|
||||
|
||||
# Verify text was produced in first result
|
||||
first_result = results[0]
|
||||
assert len(first_result.text.strip()) > 0, "Expected non-empty transcription text"
|
||||
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_language_detection(
|
||||
self,
|
||||
audio_samples: NDArray[np.float32],
|
||||
) -> None:
|
||||
"""Test language is detected from audio."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(
|
||||
device=DEVICE_CPU,
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
results = list(engine.transcribe(audio_samples))
|
||||
assert len(results) > 0, "Expected at least one transcription segment"
|
||||
|
||||
# Verify language was detected
|
||||
first_result = results[0]
|
||||
assert first_result.language is not None, "Expected detected language"
|
||||
assert len(first_result.language) == 2, "Expected 2-letter language code"
|
||||
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_transcribe_without_model_raises(self) -> None:
|
||||
"""Test transcribing without loading model raises RuntimeError."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(
|
||||
device=DEVICE_CPU,
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
|
||||
# Don't load model
|
||||
assert engine.is_loaded is False
|
||||
|
||||
# Attempt to transcribe should raise
|
||||
dummy_audio = np.zeros(SAMPLE_RATE, dtype=np.float32)
|
||||
with pytest.raises(RuntimeError, match="Model not loaded"):
|
||||
list(engine.transcribe(dummy_audio))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Case Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
|
||||
class TestWhisperPyTorchEngineEdgeCases:
|
||||
"""Edge case tests for WhisperPyTorchEngine."""
|
||||
|
||||
def test_empty_audio_returns_empty_list(self) -> None:
|
||||
"""Test transcribing empty audio returns empty list."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
empty_audio = np.array([], dtype=np.float32)
|
||||
|
||||
# Whisper handles empty audio gracefully by returning empty results
|
||||
results = list(engine.transcribe(empty_audio))
|
||||
assert results == [], "Expected empty list for empty audio"
|
||||
|
||||
engine.unload()
|
||||
|
||||
def test_very_short_audio_handled(self) -> None:
|
||||
"""Test transcribing very short audio (< 1 second) is handled."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
# 0.5 seconds of silence
|
||||
short_audio = np.zeros(SAMPLE_RATE // 2, dtype=np.float32)
|
||||
results = list(engine.transcribe(short_audio))
|
||||
|
||||
# Should handle without crashing
|
||||
assert isinstance(results, list)
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_silent_audio_produces_minimal_output(self) -> None:
|
||||
"""Test transcribing silent audio produces minimal/no speech output."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
# 3 seconds of silence
|
||||
silent_audio = np.zeros(SAMPLE_RATE * 3, dtype=np.float32)
|
||||
results = list(engine.transcribe(silent_audio))
|
||||
|
||||
# Silent audio should produce a valid (possibly empty) result list
|
||||
assert isinstance(results, list)
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_audio_with_clipping_handled(self) -> None:
|
||||
"""Test audio with extreme values (clipping) is handled."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
# Create clipped audio (values at ±1.0)
|
||||
clipped_audio = np.ones(SAMPLE_RATE * 2, dtype=np.float32)
|
||||
clipped_audio[::2] = -1.0 # Alternating +1/-1 (harsh noise)
|
||||
|
||||
results = list(engine.transcribe(clipped_audio))
|
||||
|
||||
# Should handle without crashing
|
||||
assert isinstance(results, list)
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_audio_outside_normal_range_handled(self) -> None:
|
||||
"""Test audio with values outside [-1, 1] range is handled."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
# Audio with values outside normal range
|
||||
rng = np.random.default_rng(42)
|
||||
loud_audio = rng.uniform(-5.0, 5.0, SAMPLE_RATE * 2).astype(np.float32)
|
||||
|
||||
results = list(engine.transcribe(loud_audio))
|
||||
|
||||
# Should handle without crashing (whisper normalizes internally)
|
||||
assert isinstance(results, list)
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_nan_values_in_audio_raises_error(self) -> None:
|
||||
"""Test audio containing NaN values raises ValueError during decoding."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
# Audio with NaN values causes invalid logits in whisper decoding
|
||||
audio_with_nan = np.zeros(SAMPLE_RATE, dtype=np.float32)
|
||||
audio_with_nan[100:200] = np.nan
|
||||
|
||||
with pytest.raises(ValueError, match="invalid values"):
|
||||
list(engine.transcribe(audio_with_nan))
|
||||
|
||||
engine.unload()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Functional Scenario Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
|
||||
class TestWhisperPyTorchEngineFunctionalScenarios:
|
||||
"""Functional scenario tests for WhisperPyTorchEngine."""
|
||||
|
||||
def test_multiple_sequential_transcriptions(
|
||||
self,
|
||||
audio_samples: NDArray[np.float32],
|
||||
) -> None:
|
||||
"""Test multiple transcriptions with same engine instance."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
# First transcription
|
||||
results1 = list(engine.transcribe(audio_samples))
|
||||
assert len(results1) > 0, "Expected results from first transcription"
|
||||
text1 = results1[0].text
|
||||
|
||||
# Second transcription (should produce consistent results)
|
||||
results2 = list(engine.transcribe(audio_samples))
|
||||
assert len(results2) > 0, "Expected results from second transcription"
|
||||
text2 = results2[0].text
|
||||
|
||||
# First result text should be identical for same input
|
||||
assert text1 == text2, "Same audio should produce same transcription"
|
||||
|
||||
# Third transcription with shorter audio (first 3 seconds)
|
||||
short_audio = audio_samples[: SAMPLE_RATE * 3]
|
||||
results3 = list(engine.transcribe(short_audio))
|
||||
|
||||
assert isinstance(results3, list), "Expected list result from short audio"
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_transcription_with_language_hint(
|
||||
self,
|
||||
audio_samples: NDArray[np.float32],
|
||||
) -> None:
|
||||
"""Test transcription with explicit language specification."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
# Transcribe with English hint
|
||||
results = list(engine.transcribe(audio_samples, language="en"))
|
||||
|
||||
assert len(results) > 0, "Expected transcription results with language hint"
|
||||
# Language should match hint
|
||||
assert results[0].language == "en", "Expected language to match hint"
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_model_reload_behavior(self) -> None:
|
||||
"""Test loading different model sizes sequentially."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
|
||||
# Load tiny
|
||||
engine.load_model("tiny")
|
||||
assert engine.model_size == "tiny", "Expected tiny model size"
|
||||
assert engine.is_loaded is True, "Expected model to be loaded"
|
||||
|
||||
# Unload
|
||||
engine.unload()
|
||||
assert engine.is_loaded is False, "Expected model to be unloaded"
|
||||
|
||||
# Load base (different size)
|
||||
engine.load_model("base")
|
||||
assert engine.model_size == "base", "Expected base model size"
|
||||
assert engine.is_loaded is True, "Expected model to be loaded"
|
||||
|
||||
engine.unload()
|
||||
|
||||
def test_multiple_load_unload_cycles(self) -> None:
|
||||
"""Test multiple load/unload cycles don't cause issues."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
|
||||
# Cycle 1
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
assert engine.is_loaded is True, "Expected model to be loaded (cycle 1)"
|
||||
engine.unload()
|
||||
assert engine.is_loaded is False, "Expected model to be unloaded (cycle 1)"
|
||||
|
||||
# Cycle 2
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
assert engine.is_loaded is True, "Expected model to be loaded (cycle 2)"
|
||||
engine.unload()
|
||||
assert engine.is_loaded is False, "Expected model to be unloaded (cycle 2)"
|
||||
|
||||
# Cycle 3
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
assert engine.is_loaded is True, "Expected model to be loaded (cycle 3)"
|
||||
engine.unload()
|
||||
assert engine.is_loaded is False, "Expected model to be unloaded (cycle 3)"
|
||||
|
||||
def test_unload_without_load_is_safe(self) -> None:
|
||||
"""Test calling unload without loading is safe."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
|
||||
# Should not raise
|
||||
engine.unload()
|
||||
engine.unload() # Multiple unloads should be safe
|
||||
|
||||
assert engine.is_loaded is False
|
||||
|
||||
def test_transcription_timing_accuracy(
|
||||
self,
|
||||
audio_samples: NDArray[np.float32],
|
||||
) -> None:
|
||||
"""Test that segment timings are accurate and non-overlapping."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
results = list(engine.transcribe(audio_samples))
|
||||
|
||||
# Verify at least one result
|
||||
assert len(results) >= 1, "Expected at least one transcription segment"
|
||||
|
||||
# Verify first segment has valid timing
|
||||
first_segment = results[0]
|
||||
assert first_segment.start >= 0.0, "Expected non-negative start time"
|
||||
assert first_segment.end > first_segment.start, "Expected end > start"
|
||||
assert first_segment.end <= MAX_AUDIO_SECONDS + 1.0, "Expected reasonable end time"
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Error Handling Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
|
||||
class TestWhisperPyTorchEngineErrorHandling:
|
||||
"""Error handling tests for WhisperPyTorchEngine."""
|
||||
|
||||
def test_invalid_model_size_raises(self) -> None:
|
||||
"""Test loading invalid model size raises ValueError."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid model size"):
|
||||
engine.load_model("nonexistent_model")
|
||||
|
||||
def test_transcribe_file_nonexistent_raises(self) -> None:
|
||||
"""Test transcribing nonexistent file raises appropriate error."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
nonexistent_path = Path("/nonexistent/path/audio.wav")
|
||||
|
||||
with pytest.raises((FileNotFoundError, RuntimeError, OSError), match=".*"):
|
||||
list(engine.transcribe_file(nonexistent_path))
|
||||
finally:
|
||||
engine.unload()
|
||||
|
||||
def test_double_load_overwrites_model(self) -> None:
|
||||
"""Test loading model twice overwrites previous model."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
|
||||
engine.load_model("tiny")
|
||||
assert engine.model_size == "tiny", "Expected tiny model size"
|
||||
|
||||
# Load again without unload
|
||||
engine.load_model("base")
|
||||
assert engine.model_size == "base", "Expected base model size"
|
||||
assert engine.is_loaded is True, "Expected model to be loaded"
|
||||
|
||||
engine.unload()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Compute Type Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
|
||||
class TestWhisperPyTorchEngineComputeTypes:
|
||||
"""Test different compute type configurations."""
|
||||
|
||||
def test_float32_compute_type(self) -> None:
|
||||
"""Test float32 compute type works on CPU."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type="float32")
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
assert engine.compute_type == "float32", "Expected float32 compute type"
|
||||
assert engine.is_loaded is True, "Expected model to be loaded"
|
||||
|
||||
engine.unload()
|
||||
|
||||
def test_int8_normalized_to_float32_on_cpu(self) -> None:
|
||||
"""Test int8 is normalized to float32 on CPU."""
|
||||
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
|
||||
|
||||
engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type="int8")
|
||||
|
||||
# int8 not supported on CPU, should normalize to float32
|
||||
assert engine.compute_type == "float32"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Factory Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
|
||||
class TestAsrFactoryIntegration:
|
||||
"""Integration tests for ASR factory with real engine creation."""
|
||||
|
||||
def test_factory_creates_cpu_engine(self) -> None:
|
||||
"""Test factory creates working CPU engine."""
|
||||
from noteflow.infrastructure.asr.factory import create_asr_engine
|
||||
|
||||
engine = create_asr_engine(
|
||||
device=DEVICE_CPU,
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
|
||||
# Factory should return a working engine
|
||||
assert engine is not None, "Expected engine instance"
|
||||
assert engine.device == DEVICE_CPU, "Expected CPU device"
|
||||
# model_size is None until load_model is called
|
||||
assert engine.model_size is None, "Expected model size to be unset before load_model"
|
||||
|
||||
def test_factory_auto_device_resolves_to_cpu(self) -> None:
|
||||
"""Test auto device resolves to CPU when no GPU available."""
|
||||
from noteflow.infrastructure.asr.factory import create_asr_engine
|
||||
|
||||
# In CI/test environment without GPU, should fall back to CPU
|
||||
engine = create_asr_engine(
|
||||
device="auto",
|
||||
compute_type=COMPUTE_TYPE_FLOAT32,
|
||||
)
|
||||
|
||||
assert engine is not None, "Expected engine instance"
|
||||
# Device should be resolved (not "auto")
|
||||
assert engine.device in ("cpu", "cuda", "rocm", "mps"), "Expected resolved device"
|
||||
|
||||
def test_factory_engine_can_transcribe(
|
||||
self,
|
||||
audio_samples: NDArray[np.float32],
|
||||
) -> None:
|
||||
"""Test factory-created engine can actually transcribe."""
|
||||
from noteflow.infrastructure.asr.factory import create_asr_engine
|
||||
|
||||
engine = create_asr_engine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
|
||||
engine.load_model(MODEL_SIZE_TINY)
|
||||
|
||||
try:
|
||||
results = list(engine.transcribe(audio_samples))
|
||||
|
||||
assert len(results) > 0, "Expected transcription results"
|
||||
first_result = results[0]
|
||||
assert len(first_result.text.strip()) > 0, "Expected non-empty transcription"
|
||||
finally:
|
||||
engine.unload()
|
||||
217
uv.lock
generated
217
uv.lock
generated
@@ -1859,6 +1859,26 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/de/73/3d757cb3fc16f0f9794dd289bcd0c4a031d9cf54d8137d6b984b2d02edf3/lightning_utilities-0.15.2-py3-none-any.whl", hash = "sha256:ad3ab1703775044bbf880dbf7ddaaac899396c96315f3aa1779cec9d618a9841", size = 29431, upload-time = "2025-08-06T13:57:38.046Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "llvmlite"
|
||||
version = "0.46.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/74/cd/08ae687ba099c7e3d21fe2ea536500563ef1943c5105bf6ab4ee3829f68e/llvmlite-0.46.0.tar.gz", hash = "sha256:227c9fd6d09dce2783c18b754b7cd9d9b3b3515210c46acc2d3c5badd9870ceb", size = 193456, upload-time = "2025-12-08T18:15:36.295Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/f8/4db016a5e547d4e054ff2f3b99203d63a497465f81ab78ec8eb2ff7b2304/llvmlite-0.46.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b9588ad4c63b4f0175a3984b85494f0c927c6b001e3a246a3a7fb3920d9a137", size = 37232767, upload-time = "2025-12-08T18:15:00.737Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/85/4890a7c14b4fa54400945cb52ac3cd88545bbdb973c440f98ca41591cdc5/llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3535bd2bb6a2d7ae4012681ac228e5132cdb75fefb1bcb24e33f2f3e0c865ed4", size = 56275176, upload-time = "2025-12-08T18:15:03.936Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/07/3d31d39c1a1a08cd5337e78299fca77e6aebc07c059fbd0033e3edfab45c/llvmlite-0.46.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4cbfd366e60ff87ea6cc62f50bc4cd800ebb13ed4c149466f50cf2163a473d1e", size = 55128630, upload-time = "2025-12-08T18:15:07.196Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/6b/d139535d7590a1bba1ceb68751bef22fadaa5b815bbdf0e858e3875726b2/llvmlite-0.46.0-cp312-cp312-win_amd64.whl", hash = "sha256:398b39db462c39563a97b912d4f2866cd37cba60537975a09679b28fbbc0fb38", size = 38138940, upload-time = "2025-12-08T18:15:10.162Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/ff/3eba7eb0aed4b6fca37125387cd417e8c458e750621fce56d2c541f67fa8/llvmlite-0.46.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:30b60892d034bc560e0ec6654737aaa74e5ca327bd8114d82136aa071d611172", size = 37232767, upload-time = "2025-12-08T18:15:13.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/54/737755c0a91558364b9200702c3c9c15d70ed63f9b98a2c32f1c2aa1f3ba/llvmlite-0.46.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6cc19b051753368a9c9f31dc041299059ee91aceec81bd57b0e385e5d5bf1a54", size = 56275176, upload-time = "2025-12-08T18:15:16.339Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/91/14f32e1d70905c1c0aa4e6609ab5d705c3183116ca02ac6df2091868413a/llvmlite-0.46.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bca185892908f9ede48c0acd547fe4dc1bafefb8a4967d47db6cf664f9332d12", size = 55128629, upload-time = "2025-12-08T18:15:19.493Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/a7/d526ae86708cea531935ae777b6dbcabe7db52718e6401e0fb9c5edea80e/llvmlite-0.46.0-cp313-cp313-win_amd64.whl", hash = "sha256:67438fd30e12349ebb054d86a5a1a57fd5e87d264d2451bcfafbbbaa25b82a35", size = 38138941, upload-time = "2025-12-08T18:15:22.536Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/ae/af0ffb724814cc2ea64445acad05f71cff5f799bb7efb22e47ee99340dbc/llvmlite-0.46.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:d252edfb9f4ac1fcf20652258e3f102b26b03eef738dc8a6ffdab7d7d341d547", size = 37232768, upload-time = "2025-12-08T18:15:25.055Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/19/5018e5352019be753b7b07f7759cdabb69ca5779fea2494be8839270df4c/llvmlite-0.46.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:379fdd1c59badeff8982cb47e4694a6143bec3bb49aa10a466e095410522064d", size = 56275173, upload-time = "2025-12-08T18:15:28.109Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/c9/d57877759d707e84c082163c543853245f91b70c804115a5010532890f18/llvmlite-0.46.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e8cbfff7f6db0fa2c771ad24154e2a7e457c2444d7673e6de06b8b698c3b269", size = 55128628, upload-time = "2025-12-08T18:15:31.098Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/a8/e61a8c2b3cc7a597073d9cde1fcbb567e9d827f1db30c93cf80422eac70d/llvmlite-0.46.0-cp314-cp314-win_amd64.whl", hash = "sha256:7821eda3ec1f18050f981819756631d60b6d7ab1a6cf806d9efefbe3f4082d61", size = 39153056, upload-time = "2025-12-08T18:15:33.938Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mako"
|
||||
version = "1.3.10"
|
||||
@@ -2255,6 +2275,7 @@ dependencies = [
|
||||
{ name = "grpcio-tools" },
|
||||
{ name = "httpx" },
|
||||
{ name = "keyring" },
|
||||
{ name = "openai-whisper" },
|
||||
{ name = "pgvector" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "psutil" },
|
||||
@@ -2266,6 +2287,7 @@ dependencies = [
|
||||
{ name = "sqlalchemy", extra = ["asyncio"] },
|
||||
{ name = "structlog" },
|
||||
{ name = "types-psutil" },
|
||||
{ name = "whisper" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -2362,6 +2384,12 @@ optional = [
|
||||
pdf = [
|
||||
{ name = "weasyprint" },
|
||||
]
|
||||
rocm = [
|
||||
{ name = "openai-whisper" },
|
||||
]
|
||||
rocm-ctranslate2 = [
|
||||
{ name = "faster-whisper" },
|
||||
]
|
||||
summarization = [
|
||||
{ name = "anthropic" },
|
||||
{ name = "ollama" },
|
||||
@@ -2398,6 +2426,7 @@ requires-dist = [
|
||||
{ name = "diart", marker = "extra == 'diarization'", specifier = ">=0.9.2" },
|
||||
{ name = "diart", marker = "extra == 'optional'", specifier = ">=0.9.2" },
|
||||
{ name = "faster-whisper", specifier = ">=1.0" },
|
||||
{ name = "faster-whisper", marker = "extra == 'rocm-ctranslate2'", specifier = ">=1.0" },
|
||||
{ name = "google-api-python-client", marker = "extra == 'calendar'", specifier = ">=2.100" },
|
||||
{ name = "google-api-python-client", marker = "extra == 'optional'", specifier = ">=2.100" },
|
||||
{ name = "google-auth", marker = "extra == 'calendar'", specifier = ">=2.23" },
|
||||
@@ -2417,6 +2446,8 @@ requires-dist = [
|
||||
{ name = "openai", marker = "extra == 'ollama'", specifier = ">=2.13.0" },
|
||||
{ name = "openai", marker = "extra == 'optional'", specifier = ">=2.13.0" },
|
||||
{ name = "openai", marker = "extra == 'summarization'", specifier = ">=2.13.0" },
|
||||
{ name = "openai-whisper", specifier = ">=20250625" },
|
||||
{ name = "openai-whisper", marker = "extra == 'rocm'", specifier = ">=20231117" },
|
||||
{ name = "opentelemetry-api", marker = "extra == 'observability'", specifier = ">=1.28" },
|
||||
{ name = "opentelemetry-api", marker = "extra == 'optional'", specifier = ">=1.28" },
|
||||
{ name = "opentelemetry-exporter-otlp", marker = "extra == 'observability'", specifier = ">=1.28" },
|
||||
@@ -2457,8 +2488,9 @@ requires-dist = [
|
||||
{ name = "types-psutil", specifier = ">=7.2.0.20251228" },
|
||||
{ name = "weasyprint", marker = "extra == 'optional'", specifier = ">=67.0" },
|
||||
{ name = "weasyprint", marker = "extra == 'pdf'", specifier = ">=67.0" },
|
||||
{ name = "whisper", specifier = ">=1.1.10" },
|
||||
]
|
||||
provides-extras = ["audio", "dev", "triggers", "summarization", "diarization", "pdf", "ner", "calendar", "observability", "optional", "all", "ollama"]
|
||||
provides-extras = ["audio", "dev", "triggers", "summarization", "diarization", "pdf", "ner", "calendar", "rocm", "rocm-ctranslate2", "observability", "optional", "all", "ollama"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
@@ -2474,6 +2506,30 @@ dev = [
|
||||
{ name = "watchfiles", specifier = ">=1.1.1" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numba"
|
||||
version = "0.63.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "llvmlite" },
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/dc/60/0145d479b2209bd8fdae5f44201eceb8ce5a23e0ed54c71f57db24618665/numba-0.63.1.tar.gz", hash = "sha256:b320aa675d0e3b17b40364935ea52a7b1c670c9037c39cf92c49502a75902f4b", size = 2761666, upload-time = "2025-12-10T02:57:39.002Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/14/9c/c0974cd3d00ff70d30e8ff90522ba5fbb2bcee168a867d2321d8d0457676/numba-0.63.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2819cd52afa5d8d04e057bdfd54367575105f8829350d8fb5e4066fb7591cc71", size = 2680981, upload-time = "2025-12-10T02:57:17.579Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/70/ea2bc45205f206b7a24ee68a159f5097c9ca7e6466806e7c213587e0c2b1/numba-0.63.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5cfd45dbd3d409e713b1ccfdc2ee72ca82006860254429f4ef01867fdba5845f", size = 3801656, upload-time = "2025-12-10T02:57:19.106Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/82/4f4ba4fd0f99825cbf3cdefd682ca3678be1702b63362011de6e5f71f831/numba-0.63.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:69a599df6976c03b7ecf15d05302696f79f7e6d10d620367407517943355bcb0", size = 3501857, upload-time = "2025-12-10T02:57:20.721Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/fd/6540456efa90b5f6604a86ff50dabefb187e43557e9081adcad3be44f048/numba-0.63.1-cp312-cp312-win_amd64.whl", hash = "sha256:bbad8c63e4fc7eb3cdb2c2da52178e180419f7969f9a685f283b313a70b92af3", size = 2750282, upload-time = "2025-12-10T02:57:22.474Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/57/f7/e19e6eff445bec52dde5bed1ebb162925a8e6f988164f1ae4b3475a73680/numba-0.63.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:0bd4fd820ef7442dcc07da184c3f54bb41d2bdb7b35bacf3448e73d081f730dc", size = 2680954, upload-time = "2025-12-10T02:57:24.145Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/6c/1e222edba1e20e6b113912caa9b1665b5809433cbcb042dfd133c6f1fd38/numba-0.63.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:53de693abe4be3bd4dee38e1c55f01c55ff644a6a3696a3670589e6e4c39cde2", size = 3809736, upload-time = "2025-12-10T02:57:25.836Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/0a/590bad11a8b3feeac30a24d01198d46bdb76ad15c70d3a530691ce3cae58/numba-0.63.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:81227821a72a763c3d4ac290abbb4371d855b59fdf85d5af22a47c0e86bf8c7e", size = 3508854, upload-time = "2025-12-10T02:57:27.438Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/f5/3800384a24eed1e4d524669cdbc0b9b8a628800bb1e90d7bd676e5f22581/numba-0.63.1-cp313-cp313-win_amd64.whl", hash = "sha256:eb227b07c2ac37b09432a9bda5142047a2d1055646e089d4a240a2643e508102", size = 2750228, upload-time = "2025-12-10T02:57:30.36Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/2f/53be2aa8a55ee2608ebe1231789cbb217f6ece7f5e1c685d2f0752e95a5b/numba-0.63.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:f180883e5508940cc83de8a8bea37fc6dd20fbe4e5558d4659b8b9bef5ff4731", size = 2681153, upload-time = "2025-12-10T02:57:32.016Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/91/53e59c86759a0648282368d42ba732c29524a745fd555ed1fb1df83febbe/numba-0.63.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f0938764afa82a47c0e895637a6c55547a42c9e1d35cac42285b1fa60a8b02bb", size = 3778718, upload-time = "2025-12-10T02:57:33.764Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/0c/2be19eba50b0b7636f6d1f69dfb2825530537708a234ba1ff34afc640138/numba-0.63.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f90a929fa5094e062d4e0368ede1f4497d5e40f800e80aa5222c4734236a2894", size = 3478712, upload-time = "2025-12-10T02:57:35.518Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/5f/4d0c9e756732577a52211f31da13a3d943d185f7fb90723f56d79c696caa/numba-0.63.1-cp314-cp314-win_amd64.whl", hash = "sha256:8d6d5ce85f572ed4e1a135dbb8c0114538f9dd0e3657eeb0bb64ab204cbe2a8f", size = 2752161, upload-time = "2025-12-10T02:57:37.12Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
@@ -2705,6 +2761,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/d5/eb52edff49d3d5ea116e225538c118699ddeb7c29fa17ec28af14bc10033/openai-2.13.0-py3-none-any.whl", hash = "sha256:746521065fed68df2f9c2d85613bb50844343ea81f60009b60e6a600c9352c79", size = 1066837, upload-time = "2025-12-16T18:19:43.124Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openai-whisper"
|
||||
version = "20250625"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "more-itertools" },
|
||||
{ name = "numba" },
|
||||
{ name = "numpy" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "torch" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'linux2'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/8e/d36f8880bcf18ec026a55807d02fe4c7357da9f25aebd92f85178000c0dc/openai_whisper-20250625.tar.gz", hash = "sha256:37a91a3921809d9f44748ffc73c0a55c9f366c85a3ef5c2ae0cc09540432eb96", size = 803191, upload-time = "2025-06-26T01:06:13.34Z" }
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.39.1"
|
||||
@@ -6466,6 +6537,94 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "2026.1.15"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/86/07d5056945f9ec4590b518171c4254a5925832eb727b56d3c38a7476f316/regex-2026.1.15.tar.gz", hash = "sha256:164759aa25575cbc0651bef59a0b18353e54300d79ace8084c818ad8ac72b7d5", size = 414811, upload-time = "2026-01-14T23:18:02.775Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/92/81/10d8cf43c807d0326efe874c1b79f22bfb0fb226027b0b19ebc26d301408/regex-2026.1.15-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:4c8fcc5793dde01641a35905d6731ee1548f02b956815f8f1cab89e515a5bdf1", size = 489398, upload-time = "2026-01-14T23:14:43.741Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/b0/7c2a74e74ef2a7c32de724658a69a862880e3e4155cba992ba04d1c70400/regex-2026.1.15-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bfd876041a956e6a90ad7cdb3f6a630c07d491280bfeed4544053cd434901681", size = 291339, upload-time = "2026-01-14T23:14:45.183Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/4d/16d0773d0c818417f4cc20aa0da90064b966d22cd62a8c46765b5bd2d643/regex-2026.1.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9250d087bc92b7d4899ccd5539a1b2334e44eee85d848c4c1aef8e221d3f8c8f", size = 289003, upload-time = "2026-01-14T23:14:47.25Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/e4/1fc4599450c9f0863d9406e944592d968b8d6dfd0d552a7d569e43bceada/regex-2026.1.15-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c8a154cf6537ebbc110e24dabe53095e714245c272da9c1be05734bdad4a61aa", size = 798656, upload-time = "2026-01-14T23:14:48.77Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/e6/59650d73a73fa8a60b3a590545bfcf1172b4384a7df2e7fe7b9aab4e2da9/regex-2026.1.15-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8050ba2e3ea1d8731a549e83c18d2f0999fbc99a5f6bd06b4c91449f55291804", size = 864252, upload-time = "2026-01-14T23:14:50.528Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/ab/1d0f4d50a1638849a97d731364c9a80fa304fec46325e48330c170ee8e80/regex-2026.1.15-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0bf065240704cb8951cc04972cf107063917022511273e0969bdb34fc173456c", size = 912268, upload-time = "2026-01-14T23:14:52.952Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/df/0d722c030c82faa1d331d1921ee268a4e8fb55ca8b9042c9341c352f17fa/regex-2026.1.15-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c32bef3e7aeee75746748643667668ef941d28b003bfc89994ecf09a10f7a1b5", size = 803589, upload-time = "2026-01-14T23:14:55.182Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/23/33289beba7ccb8b805c6610a8913d0131f834928afc555b241caabd422a9/regex-2026.1.15-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d5eaa4a4c5b1906bd0d2508d68927f15b81821f85092e06f1a34a4254b0e1af3", size = 775700, upload-time = "2026-01-14T23:14:56.707Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/65/bf3a42fa6897a0d3afa81acb25c42f4b71c274f698ceabd75523259f6688/regex-2026.1.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:86c1077a3cc60d453d4084d5b9649065f3bf1184e22992bd322e1f081d3117fb", size = 787928, upload-time = "2026-01-14T23:14:58.312Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/f5/13bf65864fc314f68cdd6d8ca94adcab064d4d39dbd0b10fef29a9da48fc/regex-2026.1.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:2b091aefc05c78d286657cd4db95f2e6313375ff65dcf085e42e4c04d9c8d410", size = 858607, upload-time = "2026-01-14T23:15:00.657Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/31/040e589834d7a439ee43fb0e1e902bc81bd58a5ba81acffe586bb3321d35/regex-2026.1.15-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:57e7d17f59f9ebfa9667e6e5a1c0127b96b87cb9cede8335482451ed00788ba4", size = 763729, upload-time = "2026-01-14T23:15:02.248Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/84/6921e8129687a427edf25a34a5594b588b6d88f491320b9de5b6339a4fcb/regex-2026.1.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:c6c4dcdfff2c08509faa15d36ba7e5ef5fcfab25f1e8f85a0c8f45bc3a30725d", size = 850697, upload-time = "2026-01-14T23:15:03.878Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8a/87/3d06143d4b128f4229158f2de5de6c8f2485170c7221e61bf381313314b2/regex-2026.1.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cf8ff04c642716a7f2048713ddc6278c5fd41faa3b9cab12607c7abecd012c22", size = 789849, upload-time = "2026-01-14T23:15:06.102Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/69/c50a63842b6bd48850ebc7ab22d46e7a2a32d824ad6c605b218441814639/regex-2026.1.15-cp312-cp312-win32.whl", hash = "sha256:82345326b1d8d56afbe41d881fdf62f1926d7264b2fc1537f99ae5da9aad7913", size = 266279, upload-time = "2026-01-14T23:15:07.678Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/36/39d0b29d087e2b11fd8191e15e81cce1b635fcc845297c67f11d0d19274d/regex-2026.1.15-cp312-cp312-win_amd64.whl", hash = "sha256:4def140aa6156bc64ee9912383d4038f3fdd18fee03a6f222abd4de6357ce42a", size = 277166, upload-time = "2026-01-14T23:15:09.257Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/32/5b8e476a12262748851fa8ab1b0be540360692325975b094e594dfebbb52/regex-2026.1.15-cp312-cp312-win_arm64.whl", hash = "sha256:c6c565d9a6e1a8d783c1948937ffc377dd5771e83bd56de8317c450a954d2056", size = 270415, upload-time = "2026-01-14T23:15:10.743Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/2e/6870bb16e982669b674cce3ee9ff2d1d46ab80528ee6bcc20fb2292efb60/regex-2026.1.15-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e69d0deeb977ffe7ed3d2e4439360089f9c3f217ada608f0f88ebd67afb6385e", size = 489164, upload-time = "2026-01-14T23:15:13.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/67/9774542e203849b0286badf67199970a44ebdb0cc5fb739f06e47ada72f8/regex-2026.1.15-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3601ffb5375de85a16f407854d11cca8fe3f5febbe3ac78fb2866bb220c74d10", size = 291218, upload-time = "2026-01-14T23:15:15.647Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/87/b0cda79f22b8dee05f774922a214da109f9a4c0eca5da2c9d72d77ea062c/regex-2026.1.15-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4c5ef43b5c2d4114eb8ea424bb8c9cec01d5d17f242af88b2448f5ee81caadbc", size = 288895, upload-time = "2026-01-14T23:15:17.788Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/6a/0041f0a2170d32be01ab981d6346c83a8934277d82c780d60b127331f264/regex-2026.1.15-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:968c14d4f03e10b2fd960f1d5168c1f0ac969381d3c1fcc973bc45fb06346599", size = 798680, upload-time = "2026-01-14T23:15:19.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/de/30e1cfcdbe3e891324aa7568b7c968771f82190df5524fabc1138cb2d45a/regex-2026.1.15-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:56a5595d0f892f214609c9f76b41b7428bed439d98dc961efafdd1354d42baae", size = 864210, upload-time = "2026-01-14T23:15:22.005Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/64/44/4db2f5c5ca0ccd40ff052ae7b1e9731352fcdad946c2b812285a7505ca75/regex-2026.1.15-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0bf650f26087363434c4e560011f8e4e738f6f3e029b85d4904c50135b86cfa5", size = 912358, upload-time = "2026-01-14T23:15:24.569Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/b6/e6a5665d43a7c42467138c8a2549be432bad22cbd206f5ec87162de74bd7/regex-2026.1.15-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18388a62989c72ac24de75f1449d0fb0b04dfccd0a1a7c1c43af5eb503d890f6", size = 803583, upload-time = "2026-01-14T23:15:26.526Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/53/7cd478222169d85d74d7437e74750005e993f52f335f7c04ff7adfda3310/regex-2026.1.15-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6d220a2517f5893f55daac983bfa9fe998a7dbcaee4f5d27a88500f8b7873788", size = 775782, upload-time = "2026-01-14T23:15:29.352Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/b5/75f9a9ee4b03a7c009fe60500fe550b45df94f0955ca29af16333ef557c5/regex-2026.1.15-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c9c08c2fbc6120e70abff5d7f28ffb4d969e14294fb2143b4b5c7d20e46d1714", size = 787978, upload-time = "2026-01-14T23:15:31.295Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/b3/79821c826245bbe9ccbb54f6eadb7879c722fd3e0248c17bfc90bf54e123/regex-2026.1.15-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:7ef7d5d4bd49ec7364315167a4134a015f61e8266c6d446fc116a9ac4456e10d", size = 858550, upload-time = "2026-01-14T23:15:33.558Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/85/2ab5f77a1c465745bfbfcb3ad63178a58337ae8d5274315e2cc623a822fa/regex-2026.1.15-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:6e42844ad64194fa08d5ccb75fe6a459b9b08e6d7296bd704460168d58a388f3", size = 763747, upload-time = "2026-01-14T23:15:35.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/84/c27df502d4bfe2873a3e3a7cf1bdb2b9cc10284d1a44797cf38bed790470/regex-2026.1.15-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:cfecdaa4b19f9ca534746eb3b55a5195d5c95b88cac32a205e981ec0a22b7d31", size = 850615, upload-time = "2026-01-14T23:15:37.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/b7/658a9782fb253680aa8ecb5ccbb51f69e088ed48142c46d9f0c99b46c575/regex-2026.1.15-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:08df9722d9b87834a3d701f3fca570b2be115654dbfd30179f30ab2f39d606d3", size = 789951, upload-time = "2026-01-14T23:15:39.582Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/2a/5928af114441e059f15b2f63e188bd00c6529b3051c974ade7444b85fcda/regex-2026.1.15-cp313-cp313-win32.whl", hash = "sha256:d426616dae0967ca225ab12c22274eb816558f2f99ccb4a1d52ca92e8baf180f", size = 266275, upload-time = "2026-01-14T23:15:42.108Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/16/5bfbb89e435897bff28cf0352a992ca719d9e55ebf8b629203c96b6ce4f7/regex-2026.1.15-cp313-cp313-win_amd64.whl", hash = "sha256:febd38857b09867d3ed3f4f1af7d241c5c50362e25ef43034995b77a50df494e", size = 277145, upload-time = "2026-01-14T23:15:44.244Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/c1/a09ff7392ef4233296e821aec5f78c51be5e91ffde0d163059e50fd75835/regex-2026.1.15-cp313-cp313-win_arm64.whl", hash = "sha256:8e32f7896f83774f91499d239e24cebfadbc07639c1494bb7213983842348337", size = 270411, upload-time = "2026-01-14T23:15:45.858Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/38/0cfd5a78e5c6db00e6782fdae70458f89850ce95baa5e8694ab91d89744f/regex-2026.1.15-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:ec94c04149b6a7b8120f9f44565722c7ae31b7a6d2275569d2eefa76b83da3be", size = 492068, upload-time = "2026-01-14T23:15:47.616Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/72/6c86acff16cb7c959c4355826bbf06aad670682d07c8f3998d9ef4fee7cd/regex-2026.1.15-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:40c86d8046915bb9aeb15d3f3f15b6fd500b8ea4485b30e1bbc799dab3fe29f8", size = 292756, upload-time = "2026-01-14T23:15:49.307Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/58/df7fb69eadfe76526ddfce28abdc0af09ffe65f20c2c90932e89d705153f/regex-2026.1.15-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:726ea4e727aba21643205edad8f2187ec682d3305d790f73b7a51c7587b64bdd", size = 291114, upload-time = "2026-01-14T23:15:51.484Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/6c/a4011cd1cf96b90d2cdc7e156f91efbd26531e822a7fbb82a43c1016678e/regex-2026.1.15-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1cb740d044aff31898804e7bf1181cc72c03d11dfd19932b9911ffc19a79070a", size = 807524, upload-time = "2026-01-14T23:15:53.102Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/25/a53ffb73183f69c3e9f4355c4922b76d2840aee160af6af5fac229b6201d/regex-2026.1.15-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:05d75a668e9ea16f832390d22131fe1e8acc8389a694c8febc3e340b0f810b93", size = 873455, upload-time = "2026-01-14T23:15:54.956Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/0b/8b47fc2e8f97d9b4a851736f3890a5f786443aa8901061c55f24c955f45b/regex-2026.1.15-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d991483606f3dbec93287b9f35596f41aa2e92b7c2ebbb935b63f409e243c9af", size = 915007, upload-time = "2026-01-14T23:15:57.041Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/fa/97de0d681e6d26fabe71968dbee06dd52819e9a22fdce5dac7256c31ed84/regex-2026.1.15-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:194312a14819d3e44628a44ed6fea6898fdbecb0550089d84c403475138d0a09", size = 812794, upload-time = "2026-01-14T23:15:58.916Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/38/e752f94e860d429654aa2b1c51880bff8dfe8f084268258adf9151cf1f53/regex-2026.1.15-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe2fda4110a3d0bc163c2e0664be44657431440722c5c5315c65155cab92f9e5", size = 781159, upload-time = "2026-01-14T23:16:00.817Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/a7/d739ffaef33c378fc888302a018d7f81080393d96c476b058b8c64fd2b0d/regex-2026.1.15-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:124dc36c85d34ef2d9164da41a53c1c8c122cfb1f6e1ec377a1f27ee81deb794", size = 795558, upload-time = "2026-01-14T23:16:03.267Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3e/c4/542876f9a0ac576100fc73e9c75b779f5c31e3527576cfc9cb3009dcc58a/regex-2026.1.15-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:a1774cd1981cd212506a23a14dba7fdeaee259f5deba2df6229966d9911e767a", size = 868427, upload-time = "2026-01-14T23:16:05.646Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/0f/d5655bea5b22069e32ae85a947aa564912f23758e112cdb74212848a1a1b/regex-2026.1.15-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:b5f7d8d2867152cdb625e72a530d2ccb48a3d199159144cbdd63870882fb6f80", size = 769939, upload-time = "2026-01-14T23:16:07.542Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/06/7e18a4fa9d326daeda46d471a44ef94201c46eaa26dbbb780b5d92cbfdda/regex-2026.1.15-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:492534a0ab925d1db998defc3c302dae3616a2fc3fe2e08db1472348f096ddf2", size = 854753, upload-time = "2026-01-14T23:16:10.395Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/67/dc8946ef3965e166f558ef3b47f492bc364e96a265eb4a2bb3ca765c8e46/regex-2026.1.15-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c661fc820cfb33e166bf2450d3dadbda47c8d8981898adb9b6fe24e5e582ba60", size = 799559, upload-time = "2026-01-14T23:16:12.347Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/61/1bba81ff6d50c86c65d9fd84ce9699dd106438ee4cdb105bf60374ee8412/regex-2026.1.15-cp313-cp313t-win32.whl", hash = "sha256:99ad739c3686085e614bf77a508e26954ff1b8f14da0e3765ff7abbf7799f952", size = 268879, upload-time = "2026-01-14T23:16:14.049Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/5e/cef7d4c5fb0ea3ac5c775fd37db5747f7378b29526cc83f572198924ff47/regex-2026.1.15-cp313-cp313t-win_amd64.whl", hash = "sha256:32655d17905e7ff8ba5c764c43cb124e34a9245e45b83c22e81041e1071aee10", size = 280317, upload-time = "2026-01-14T23:16:15.718Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/52/4317f7a5988544e34ab57b4bde0f04944c4786128c933fb09825924d3e82/regex-2026.1.15-cp313-cp313t-win_arm64.whl", hash = "sha256:b2a13dd6a95e95a489ca242319d18fc02e07ceb28fa9ad146385194d95b3c829", size = 271551, upload-time = "2026-01-14T23:16:17.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/0a/47fa888ec7cbbc7d62c5f2a6a888878e76169170ead271a35239edd8f0e8/regex-2026.1.15-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:d920392a6b1f353f4aa54328c867fec3320fa50657e25f64abf17af054fc97ac", size = 489170, upload-time = "2026-01-14T23:16:19.835Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/c4/d000e9b7296c15737c9301708e9e7fbdea009f8e93541b6b43bdb8219646/regex-2026.1.15-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b5a28980a926fa810dbbed059547b02783952e2efd9c636412345232ddb87ff6", size = 291146, upload-time = "2026-01-14T23:16:21.541Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/b6/921cc61982e538682bdf3bdf5b2c6ab6b34368da1f8e98a6c1ddc503c9cf/regex-2026.1.15-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:621f73a07595d83f28952d7bd1e91e9d1ed7625fb7af0064d3516674ec93a2a2", size = 288986, upload-time = "2026-01-14T23:16:23.381Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/33/eb7383dde0bbc93f4fb9d03453aab97e18ad4024ac7e26cef8d1f0a2cff0/regex-2026.1.15-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3d7d92495f47567a9b1669c51fc8d6d809821849063d168121ef801bbc213846", size = 799098, upload-time = "2026-01-14T23:16:25.088Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/56/b664dccae898fc8d8b4c23accd853f723bde0f026c747b6f6262b688029c/regex-2026.1.15-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8dd16fba2758db7a3780a051f245539c4451ca20910f5a5e6ea1c08d06d4a76b", size = 864980, upload-time = "2026-01-14T23:16:27.297Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/40/0999e064a170eddd237bae9ccfcd8f28b3aa98a38bf727a086425542a4fc/regex-2026.1.15-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:1e1808471fbe44c1a63e5f577a1d5f02fe5d66031dcbdf12f093ffc1305a858e", size = 911607, upload-time = "2026-01-14T23:16:29.235Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/78/c77f644b68ab054e5a674fb4da40ff7bffb2c88df58afa82dbf86573092d/regex-2026.1.15-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0751a26ad39d4f2ade8fe16c59b2bf5cb19eb3d2cd543e709e583d559bd9efde", size = 803358, upload-time = "2026-01-14T23:16:31.369Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/31/d4292ea8566eaa551fafc07797961c5963cf5235c797cc2ae19b85dfd04d/regex-2026.1.15-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0f0c7684c7f9ca241344ff95a1de964f257a5251968484270e91c25a755532c5", size = 775833, upload-time = "2026-01-14T23:16:33.141Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/b2/cff3bf2fea4133aa6fb0d1e370b37544d18c8350a2fa118c7e11d1db0e14/regex-2026.1.15-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:74f45d170a21df41508cb67165456538425185baaf686281fa210d7e729abc34", size = 788045, upload-time = "2026-01-14T23:16:35.005Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/99/2cb9b69045372ec877b6f5124bda4eb4253bc58b8fe5848c973f752bc52c/regex-2026.1.15-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:f1862739a1ffb50615c0fde6bae6569b5efbe08d98e59ce009f68a336f64da75", size = 859374, upload-time = "2026-01-14T23:16:36.919Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/09/16/710b0a5abe8e077b1729a562d2f297224ad079f3a66dce46844c193416c8/regex-2026.1.15-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:453078802f1b9e2b7303fb79222c054cb18e76f7bdc220f7530fdc85d319f99e", size = 763940, upload-time = "2026-01-14T23:16:38.685Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/d1/7585c8e744e40eb3d32f119191969b91de04c073fca98ec14299041f6e7e/regex-2026.1.15-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:a30a68e89e5a218b8b23a52292924c1f4b245cb0c68d1cce9aec9bbda6e2c160", size = 850112, upload-time = "2026-01-14T23:16:40.646Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/d6/43e1dd85df86c49a347aa57c1f69d12c652c7b60e37ec162e3096194a278/regex-2026.1.15-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9479cae874c81bf610d72b85bb681a94c95722c127b55445285fb0e2c82db8e1", size = 789586, upload-time = "2026-01-14T23:16:42.799Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/38/77142422f631e013f316aaae83234c629555729a9fbc952b8a63ac91462a/regex-2026.1.15-cp314-cp314-win32.whl", hash = "sha256:d639a750223132afbfb8f429c60d9d318aeba03281a5f1ab49f877456448dcf1", size = 271691, upload-time = "2026-01-14T23:16:44.671Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/a9/ab16b4649524ca9e05213c1cdbb7faa85cc2aa90a0230d2f796cbaf22736/regex-2026.1.15-cp314-cp314-win_amd64.whl", hash = "sha256:4161d87f85fa831e31469bfd82c186923070fc970b9de75339b68f0c75b51903", size = 280422, upload-time = "2026-01-14T23:16:46.607Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/2a/20fd057bf3521cb4791f69f869635f73e0aaf2b9ad2d260f728144f9047c/regex-2026.1.15-cp314-cp314-win_arm64.whl", hash = "sha256:91c5036ebb62663a6b3999bdd2e559fd8456d17e2b485bf509784cd31a8b1705", size = 273467, upload-time = "2026-01-14T23:16:48.967Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/77/0b1e81857060b92b9cad239104c46507dd481b3ff1fa79f8e7f865aae38a/regex-2026.1.15-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ee6854c9000a10938c79238de2379bea30c82e4925a371711af45387df35cab8", size = 492073, upload-time = "2026-01-14T23:16:51.154Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/f3/f8302b0c208b22c1e4f423147e1913fd475ddd6230565b299925353de644/regex-2026.1.15-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:2c2b80399a422348ce5de4fe40c418d6299a0fa2803dd61dc0b1a2f28e280fcf", size = 292757, upload-time = "2026-01-14T23:16:53.08Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/f0/ef55de2460f3b4a6da9d9e7daacd0cb79d4ef75c64a2af316e68447f0df0/regex-2026.1.15-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:dca3582bca82596609959ac39e12b7dad98385b4fefccb1151b937383cec547d", size = 291122, upload-time = "2026-01-14T23:16:55.383Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cf/55/bb8ccbacabbc3a11d863ee62a9f18b160a83084ea95cdfc5d207bfc3dd75/regex-2026.1.15-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef71d476caa6692eea743ae5ea23cde3260677f70122c4d258ca952e5c2d4e84", size = 807761, upload-time = "2026-01-14T23:16:57.251Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/84/f75d937f17f81e55679a0509e86176e29caa7298c38bd1db7ce9c0bf6075/regex-2026.1.15-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c243da3436354f4af6c3058a3f81a97d47ea52c9bd874b52fd30274853a1d5df", size = 873538, upload-time = "2026-01-14T23:16:59.349Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/d9/0da86327df70349aa8d86390da91171bd3ca4f0e7c1d1d453a9c10344da3/regex-2026.1.15-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8355ad842a7c7e9e5e55653eade3b7d1885ba86f124dd8ab1f722f9be6627434", size = 915066, upload-time = "2026-01-14T23:17:01.607Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/5e/f660fb23fc77baa2a61aa1f1fe3a4eea2bbb8a286ddec148030672e18834/regex-2026.1.15-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f192a831d9575271a22d804ff1a5355355723f94f31d9eef25f0d45a152fdc1a", size = 812938, upload-time = "2026-01-14T23:17:04.366Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/69/33/a47a29bfecebbbfd1e5cd3f26b28020a97e4820f1c5148e66e3b7d4b4992/regex-2026.1.15-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:166551807ec20d47ceaeec380081f843e88c8949780cd42c40f18d16168bed10", size = 781314, upload-time = "2026-01-14T23:17:06.378Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/65/ec/7ec2bbfd4c3f4e494a24dec4c6943a668e2030426b1b8b949a6462d2c17b/regex-2026.1.15-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9ca1cbdc0fbfe5e6e6f8221ef2309988db5bcede52443aeaee9a4ad555e0dac", size = 795652, upload-time = "2026-01-14T23:17:08.521Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/79/a5d8651ae131fe27d7c521ad300aa7f1c7be1dbeee4d446498af5411b8a9/regex-2026.1.15-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:b30bcbd1e1221783c721483953d9e4f3ab9c5d165aa709693d3f3946747b1aea", size = 868550, upload-time = "2026-01-14T23:17:10.573Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/06/b7/25635d2809664b79f183070786a5552dd4e627e5aedb0065f4e3cf8ee37d/regex-2026.1.15-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:2a8d7b50c34578d0d3bf7ad58cde9652b7d683691876f83aedc002862a35dc5e", size = 769981, upload-time = "2026-01-14T23:17:12.871Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/8b/fc3fcbb2393dcfa4a6c5ffad92dc498e842df4581ea9d14309fcd3c55fb9/regex-2026.1.15-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:9d787e3310c6a6425eb346be4ff2ccf6eece63017916fd77fe8328c57be83521", size = 854780, upload-time = "2026-01-14T23:17:14.837Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/38/dde117c76c624713c8a2842530be9c93ca8b606c0f6102d86e8cd1ce8bea/regex-2026.1.15-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:619843841e220adca114118533a574a9cd183ed8a28b85627d2844c500a2b0db", size = 799778, upload-time = "2026-01-14T23:17:17.369Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/0d/3a6cfa9ae99606afb612d8fb7a66b245a9d5ff0f29bb347c8a30b6ad561b/regex-2026.1.15-cp314-cp314t-win32.whl", hash = "sha256:e90b8db97f6f2c97eb045b51a6b2c5ed69cedd8392459e0642d4199b94fabd7e", size = 274667, upload-time = "2026-01-14T23:17:19.301Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/b2/297293bb0742fd06b8d8e2572db41a855cdf1cae0bf009b1cb74fe07e196/regex-2026.1.15-cp314-cp314t-win_amd64.whl", hash = "sha256:5ef19071f4ac9f0834793af85bd04a920b4407715624e40cb7a0631a11137cdf", size = 284386, upload-time = "2026-01-14T23:17:21.231Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/e4/a3b9480c78cf8ee86626cb06f8d931d74d775897d44201ccb813097ae697/regex-2026.1.15-cp314-cp314t-win_arm64.whl", hash = "sha256:ca89c5e596fc05b015f27561b3793dc2fa0917ea0d7507eebb448efd35274a70", size = 274837, upload-time = "2026-01-14T23:17:23.146Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.5"
|
||||
@@ -7163,6 +7322,53 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiktoken"
|
||||
version = "0.12.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "regex" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/85/be65d39d6b647c79800fd9d29241d081d4eeb06271f383bb87200d74cf76/tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8", size = 1050728, upload-time = "2025-10-06T20:21:52.756Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/42/6573e9129bc55c9bf7300b3a35bef2c6b9117018acca0dc760ac2d93dffe/tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b", size = 994049, upload-time = "2025-10-06T20:21:53.782Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37", size = 1129008, upload-time = "2025-10-06T20:21:54.832Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad", size = 1152665, upload-time = "2025-10-06T20:21:56.129Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5", size = 1194230, upload-time = "2025-10-06T20:21:57.546Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/32/45d02e2e0ea2be3a9ed22afc47d93741247e75018aac967b713b2941f8ea/tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697", size = 879117, upload-time = "2025-10-06T20:22:08.418Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188, upload-time = "2025-10-06T20:22:19.563Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978, upload-time = "2025-10-06T20:22:20.702Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/27/bf795595a2b897e271771cd31cb847d479073497344c637966bdf2853da1/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff", size = 1129271, upload-time = "2025-10-06T20:22:22.06Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216, upload-time = "2025-10-06T20:22:23.085Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/75/0d/881866647b8d1be4d67cb24e50d0c26f9f807f994aa1510cb9ba2fe5f612/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b", size = 1194860, upload-time = "2025-10-06T20:22:24.602Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567, upload-time = "2025-10-06T20:22:25.671Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/57/ce64fd16ac390fafde001268c364d559447ba09b509181b2808622420eec/tiktoken-0.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:399c3dd672a6406719d84442299a490420b458c44d3ae65516302a99675888f3", size = 921067, upload-time = "2025-10-06T20:22:26.753Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473, upload-time = "2025-10-06T20:22:27.775Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855, upload-time = "2025-10-06T20:22:28.799Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/77/4f268c41a3957c418b084dd576ea2fad2e95da0d8e1ab705372892c2ca22/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63", size = 1129022, upload-time = "2025-10-06T20:22:29.981Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736, upload-time = "2025-10-06T20:22:30.996Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/c0/3c7a39ff68022ddfd7d93f3337ad90389a342f761c4d71de99a3ccc57857/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a", size = 1194908, upload-time = "2025-10-06T20:22:32.073Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706, upload-time = "2025-10-06T20:22:33.385Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/df/c7891ef9d2712ad774777271d39fdef63941ffba0a9d59b7ad1fd2765e57/tiktoken-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f61c0aea5565ac82e2ec50a05e02a6c44734e91b51c10510b084ea1b8e633a71", size = 920667, upload-time = "2025-10-06T20:22:34.444Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinycss2"
|
||||
version = "1.5.1"
|
||||
@@ -7637,6 +7843,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/9e/510086a9ed0dee3830da838f9207f5c787487813d5eb74eb19fe306e6a3e/websocket_server-0.6.4-py3-none-any.whl", hash = "sha256:aca2d8f7569c82fe3e949cbae1f9d3f3035ae15f1d4048085431c94b7dfd35be", size = 7534, upload-time = "2021-12-19T16:34:34.597Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whisper"
|
||||
version = "1.1.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "six" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b4/c3/913cdd13ef3d882fa483981378a08cd0f018fd8dd95b6bf006b9bf1cfbc9/whisper-1.1.10.tar.gz", hash = "sha256:435b4fb843c4c752719bdf0511a652d5be710e9bb35ad9ebe3b133268ee31c44", size = 42835, upload-time = "2022-05-22T18:19:54.839Z" }
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.17.3"
|
||||
|
||||
Reference in New Issue
Block a user