Files
noteflow/client/src-tauri/src/audio/windows_loopback.rs

301 lines
9.2 KiB
Rust

//! Windows WASAPI loopback capture (system audio).
#[cfg(target_os = "windows")]
use std::collections::VecDeque;
#[cfg(target_os = "windows")]
use std::sync::mpsc as std_mpsc;
#[cfg(target_os = "windows")]
use std::thread;
#[cfg(target_os = "windows")]
use std::time::Duration;
#[cfg(target_os = "windows")]
use wasapi::{
deinitialize, initialize_mta, DeviceEnumerator, Direction, SampleType, StreamMode, WaveFormat,
};
#[cfg(target_os = "windows")]
use crate::error::{Error, Result};
/// Pseudo device ID used to select WASAPI loopback capture on Windows.
#[cfg(target_os = "windows")]
pub const WASAPI_LOOPBACK_DEVICE_ID: &str = "system:wasapi:default";
/// Display name for the WASAPI loopback pseudo device.
#[cfg(target_os = "windows")]
pub const WASAPI_LOOPBACK_DEVICE_NAME: &str = "System Audio (Windows loopback)";
/// Timeout, in milliseconds, for each wait on the WASAPI buffer event in the capture loop.
/// The stop channel is only re-checked between waits, so while no audio is arriving this
/// roughly bounds how long a stop request can go unnoticed.
#[cfg(target_os = "windows")]
const LOOPBACK_WAIT_TIMEOUT_MS: u32 = 200;
/// Fallback stream buffer duration, used when the device does not report a positive minimum
/// period. Presumably in 100-ns units (WASAPI `REFERENCE_TIME`, per the `_hns` suffix), i.e.
/// 20 ms — confirm against the wasapi crate docs if changing it.
#[cfg(target_os = "windows")]
const DEFAULT_LOOPBACK_BUFFER_HNS: i64 = 200_000;
/// Returns `true` when `device_id` is exactly [`WASAPI_LOOPBACK_DEVICE_ID`], i.e. when the
/// caller has selected the Windows loopback (system audio) pseudo device.
#[cfg(target_os = "windows")]
pub fn matches_wasapi_loopback_device_id(device_id: &str) -> bool {
    device_id.eq(WASAPI_LOOPBACK_DEVICE_ID)
}
/// Owner of a running WASAPI loopback capture thread.
///
/// Stopping capture, either explicitly through [`WasapiLoopbackHandle::stop`] or implicitly when
/// the handle is dropped, sends a stop signal to the thread and then blocks until it has exited.
#[cfg(target_os = "windows")]
pub struct WasapiLoopbackHandle {
    stop_tx: std_mpsc::Sender<()>,
    join: Option<std::thread::JoinHandle<()>>,
}

#[cfg(target_os = "windows")]
impl WasapiLoopbackHandle {
    /// Stops capture and waits for the background thread to finish.
    pub fn stop(mut self) {
        self.stop_internal();
    }

    /// Signals the capture thread and joins it. Only the first call actually joins.
    fn stop_internal(&mut self) {
        // A failed send only means the thread already exited and dropped its receiver.
        let _ = self.stop_tx.send(());
        // `take` leaves `None` behind, so a later call (e.g. `Drop` after `stop`) skips the join.
        // A panic payload from the thread is deliberately ignored.
        let _ = self.join.take().map(std::thread::JoinHandle::join);
    }
}

#[cfg(target_os = "windows")]
impl Drop for WasapiLoopbackHandle {
    /// Ensures the capture thread is signalled and joined even if `stop` was never called.
    fn drop(&mut self) {
        self.stop_internal();
    }
}
/// Starts WASAPI loopback capture (system audio) on a dedicated background thread.
///
/// Blocks until the capture thread reports whether the stream started. On success, the thread
/// keeps calling `on_samples` with chunks of `samples_per_chunk` 32-bit float samples until the
/// returned handle is stopped or dropped. `output_device_name` selects the render device to
/// capture from; `None`, or a name that cannot be resolved, uses the default render device.
/// `meeting_id` is used only as log context.
///
/// # Errors
///
/// Returns an error if the thread cannot be spawned, the error the thread reports if startup
/// fails, or [`Error::AudioCapture`] if the thread exits without reporting anything. In the
/// failure cases the thread is joined before this function returns.
#[cfg(target_os = "windows")]
pub fn start_wasapi_loopback_capture<F>(
    meeting_id: String,
    output_device_name: Option<String>,
    sample_rate: u32,
    channels: u16,
    samples_per_chunk: usize,
    mut on_samples: F,
) -> Result<WasapiLoopbackHandle>
where
    F: FnMut(&[f32]) + Send + 'static,
{
    let (stop_tx, stop_rx) = std_mpsc::channel::<()>();
    let (ready_tx, ready_rx) = std_mpsc::channel::<Result<()>>();

    let worker = thread::Builder::new()
        .name("noteflow-wasapi-loopback".to_string())
        .spawn(move || {
            // Startup success/failure travels over `ready_tx`; the return value is not needed.
            let _ = wasapi_loopback_thread_main(
                meeting_id,
                output_device_name,
                sample_rate,
                channels,
                samples_per_chunk,
                stop_rx,
                ready_tx,
                &mut on_samples,
            );
        })
        .map_err(|err| Error::AudioCapture(format!("Failed to spawn loopback thread: {err}")))?;

    // A receive error means the thread dropped `ready_tx` without reporting an outcome.
    let startup = ready_rx.recv().unwrap_or_else(|_| {
        Err(Error::AudioCapture(
            "WASAPI loopback thread failed to start".to_string(),
        ))
    });

    let handle = WasapiLoopbackHandle {
        stop_tx,
        join: Some(worker),
    };
    match startup {
        Ok(()) => Ok(handle),
        Err(err) => {
            // Make sure the thread has fully exited before surfacing the failure.
            handle.stop();
            Err(err)
        }
    }
}
/// Resolves the render (output) device whose audio should be captured via loopback.
///
/// With `device_name` set, the render device collection is searched for a device of that name.
/// If fetching the collection or finding the name fails, a warning is logged and the default
/// render device is used instead. Without a name, the default render device is used directly.
/// `meeting_id` only adds context to the log lines.
///
/// # Errors
///
/// Returns [`Error::AudioCapture`] if the default render device is needed but unavailable.
#[cfg(target_os = "windows")]
fn lookup_render_device(
    enumerator: &DeviceEnumerator,
    meeting_id: &str,
    device_name: Option<&str>,
) -> Result<wasapi::Device> {
    let name = match device_name {
        Some(name) => name,
        None => {
            tracing::info!(
                meeting_id = %meeting_id,
                "No output device specified - using default render device for WASAPI loopback"
            );
            return get_default_render_device(enumerator);
        }
    };

    tracing::info!(
        meeting_id = %meeting_id,
        requested_device_name = %name,
        "WASAPI loopback looking up output device by name"
    );

    enumerator
        .get_device_collection(&Direction::Render)
        .and_then(|collection| collection.get_device_with_name(name))
        .or_else(|lookup_err| {
            // Best effort: keep capturing, but make the possible device mismatch loud.
            tracing::warn!(
                meeting_id = %meeting_id,
                requested_name = %name,
                error = %lookup_err,
                "WASAPI device lookup failed - falling back to default device. \
                 This may capture from wrong device!"
            );
            get_default_render_device(enumerator)
        })
}
/// Returns the system's default render (output) device.
///
/// # Errors
///
/// Returns [`Error::AudioCapture`] if WASAPI cannot provide a default render device.
#[cfg(target_os = "windows")]
fn get_default_render_device(enumerator: &DeviceEnumerator) -> Result<wasapi::Device> {
    match enumerator.get_default_device(&Direction::Render) {
        Ok(device) => Ok(device),
        Err(err) => Err(Error::AudioCapture(format!(
            "WASAPI default render device error: {err}"
        ))),
    }
}
/// Body of the WASAPI loopback capture thread.
///
/// Initializes COM for the multithreaded apartment on this thread and resolves the render device
/// (`output_device_name`, falling back to the default). It then opens the device in event-driven
/// shared mode with `Direction::Capture`, requesting `channels` channels of 32-bit float at
/// `sample_rate` with automatic format conversion. The startup outcome is sent exactly once on
/// `ready_tx`.
///
/// Once running, captured bytes are decoded as little-endian `f32` values and passed to
/// `on_samples` in slices of exactly `samples_per_chunk` values. Any remainder stays queued for
/// the next chunk. The loop ends when a stop message arrives on `stop_rx`, when the stop sender is
/// dropped, or when a device read fails (logged). COM is uninitialized before returning whenever
/// initialization succeeded.
///
/// # Errors
///
/// Every returned error is a startup error, and it is also sent on `ready_tx`. The causes are:
/// - `samples_per_chunk` is zero, or too large to express in bytes;
/// - COM initialization fails;
/// - any step of opening or starting the stream fails.
///
/// Failures after a successful start end the loop, and the function still returns `Ok(())`.
#[cfg(target_os = "windows")]
fn wasapi_loopback_thread_main<F>(
    meeting_id: String,
    output_device_name: Option<String>,
    sample_rate: u32,
    channels: u16,
    samples_per_chunk: usize,
    stop_rx: std_mpsc::Receiver<()>,
    ready_tx: std_mpsc::Sender<Result<()>>,
    on_samples: &mut F,
) -> Result<()>
where
    F: FnMut(&[f32]),
{
    // Size of one requested sample: 32-bit float, delivered as little-endian bytes.
    const BYTES_PER_SAMPLE: usize = std::mem::size_of::<f32>();

    // Reports a startup failure to the waiting caller and returns the same failure. The error is
    // built twice from `message` because the channel consumes one instance.
    fn startup_failure(ready_tx: &std_mpsc::Sender<Result<()>>, message: String) -> Result<()> {
        let _ = ready_tx.send(Err(Error::AudioCapture(message.clone())));
        Err(Error::AudioCapture(message))
    }

    // With a zero chunk size, the drain loop below would spin forever on empty chunks and never
    // re-check `stop_rx`. Reject it (and byte-count overflow) before touching COM.
    let bytes_per_chunk = match samples_per_chunk.checked_mul(BYTES_PER_SAMPLE) {
        Some(bytes) if bytes > 0 => bytes,
        _ => {
            return startup_failure(
                &ready_tx,
                format!("Invalid WASAPI loopback chunk size: {samples_per_chunk} samples"),
            );
        }
    };

    // Report COM init failures through `ready_tx` too, so the caller sees the real cause rather
    // than a generic "failed to start" error. COM was not initialized, so skip `deinitialize`.
    if let Err(message) = initialize_mta().map_err(|err| format!("WASAPI init failed: {err}")) {
        return startup_failure(&ready_tx, message);
    }

    let result: Result<()> = (|| {
        let enumerator = DeviceEnumerator::new()
            .map_err(|err| Error::AudioCapture(format!("WASAPI enumerator error: {err}")))?;
        let device =
            lookup_render_device(&enumerator, &meeting_id, output_device_name.as_deref())?;
        let device_name = device
            .get_friendlyname()
            .unwrap_or_else(|_| "<unknown>".to_string());
        tracing::info!(
            meeting_id = %meeting_id,
            device_name = %device_name,
            "WASAPI loopback capture using render device"
        );

        let mut audio_client = device
            .get_iaudioclient()
            .map_err(|err| Error::AudioCapture(format!("WASAPI audio client error: {err}")))?;
        let desired_format = WaveFormat::new(
            32,
            32,
            &SampleType::Float,
            sample_rate as usize,
            channels as usize,
            None,
        );
        // Prefer the device's minimum period. Fall back when it is unavailable or reported as 0.
        let (_, min_time) = audio_client.get_device_period().unwrap_or((0, 0));
        let buffer_duration_hns = if min_time > 0 {
            min_time
        } else {
            DEFAULT_LOOPBACK_BUFFER_HNS
        };
        let mode = StreamMode::EventsShared {
            autoconvert: true,
            buffer_duration_hns,
        };
        // Capture direction on a render device: per the wasapi crate, this selects loopback.
        audio_client
            .initialize_client(&desired_format, &Direction::Capture, &mode)
            .map_err(|err| Error::AudioCapture(format!("WASAPI init client error: {err}")))?;
        let h_event = audio_client
            .set_get_eventhandle()
            .map_err(|err| Error::AudioCapture(format!("WASAPI event handle error: {err}")))?;
        let capture_client = audio_client
            .get_audiocaptureclient()
            .map_err(|err| Error::AudioCapture(format!("WASAPI capture client error: {err}")))?;
        audio_client
            .start_stream()
            .map_err(|err| Error::AudioCapture(format!("WASAPI start stream error: {err}")))?;
        let _ = ready_tx.send(Ok(()));

        let mut sample_queue: VecDeque<u8> =
            VecDeque::with_capacity(bytes_per_chunk.saturating_mul(4));
        // Scratch buffers reused for every chunk, to avoid per-chunk allocations.
        let mut chunk_bytes: Vec<u8> = Vec::with_capacity(bytes_per_chunk);
        let mut samples: Vec<f32> = Vec::with_capacity(samples_per_chunk);
        loop {
            match stop_rx.try_recv() {
                Err(std_mpsc::TryRecvError::Empty) => {}
                // An explicit stop request, or the handle's sender going away, ends capture.
                Ok(()) | Err(std_mpsc::TryRecvError::Disconnected) => break,
            }
            if let Err(err) = capture_client.read_from_device_to_deque(&mut sample_queue) {
                tracing::error!(
                    meeting_id = %meeting_id,
                    "WASAPI loopback capture error: {}",
                    err
                );
                break;
            }
            // Emit only full chunks; a partial tail stays queued for the next read.
            while sample_queue.len() >= bytes_per_chunk {
                chunk_bytes.clear();
                chunk_bytes.extend(sample_queue.drain(..bytes_per_chunk));
                samples.clear();
                samples.extend(
                    chunk_bytes
                        .chunks_exact(BYTES_PER_SAMPLE)
                        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
                );
                on_samples(&samples);
            }
            // An error here (e.g. a timeout with no new data) is not fatal. Back off briefly and
            // loop, so that the stop channel is re-checked.
            if h_event.wait_for_event(LOOPBACK_WAIT_TIMEOUT_MS).is_err() {
                std::thread::sleep(Duration::from_millis(5));
            }
        }
        let _ = audio_client.stop_stream();
        Ok(())
    })();

    let outcome = match result {
        Ok(()) => Ok(()),
        Err(err) => {
            // Hand the caller the original error value. Keep a textual copy as this thread's own
            // return value, since `Error` cannot be assumed to be `Clone`.
            let message = err.to_string();
            let _ = ready_tx.send(Err(err));
            Err(Error::AudioCapture(message))
        }
    };
    deinitialize();
    outcome
}